nequip.train.SimpleDDPStrategy

class nequip.train.SimpleDDPStrategy(accelerator: Accelerator | None = None, parallel_devices: list[device] | None = None, cluster_environment: ClusterEnvironment | None = None, checkpoint_io: CheckpointIO | None = None, precision_plugin: Precision | None = None, ddp_comm_state: object | None = None, ddp_comm_hook: Callable | None = None, ddp_comm_wrapper: Callable | None = None, model_averaging_period: int | None = None, process_group_backend: str | None = None, timeout: timedelta | None = datetime.timedelta(seconds=1800), start_method: Literal['popen', 'spawn', 'fork', 'forkserver'] = 'popen', **kwargs: Any)

Functionally equivalent to Lightning’s DDPStrategy, except that gradients are synchronized across ranks manually after the backward pass instead of by wrapping the model in PyTorch’s DistributedDataParallel.

Note

To use train-time compilation with multi-rank training, this strategy must be used in place of PyTorch Lightning’s DDPStrategy.

Example usage in a config file:

trainer:
  _target_: lightning.Trainer
  # other trainer arguments
  strategy:
    _target_: nequip.train.SimpleDDPStrategy
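A fuller variant might also set Lightning’s standard accelerator, devices, and num_nodes trainer arguments and pass process_group_backend, one of the strategy’s constructor arguments listed in the signature above. The sketch below assumes a single node with four GPUs; the device counts and the choice of the nccl backend are illustrative assumptions, not nequip defaults.

```yaml
trainer:
  _target_: lightning.Trainer
  accelerator: gpu      # assumption: GPU training
  devices: 4            # assumption: four GPUs on this node
  num_nodes: 1
  strategy:
    _target_: nequip.train.SimpleDDPStrategy
    process_group_backend: nccl   # assumption: NCCL for GPU collectives
```

Any keys not given fall back to the defaults shown in the class signature.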

post_backward(closure_loss: Tensor) → None

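Manually synchronizes gradients across ranks after the backward pass, taking the place of the synchronization that DistributedDataParallel would otherwise perform.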
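The manual gradient synchronization described above can be sketched generically with torch.distributed: after loss.backward(), each parameter’s gradient is summed over all ranks with all_reduce and divided by the world size, which reproduces the gradient averaging that DistributedDataParallel performs. This is an illustrative sketch, not nequip’s actual implementation; the helper names sync_gradients and single_process_demo are hypothetical, and the demo uses a one-process gloo group so it runs on CPU without a launcher.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def sync_gradients(module: torch.nn.Module) -> None:
    """Average every parameter's gradient across the default process group.

    Hypothetical helper: a generic sketch of manual DDP-style gradient
    syncing, not nequip's post_backward implementation.
    """
    world_size = dist.get_world_size()
    for param in module.parameters():
        if param.grad is None:
            continue  # parameter received no gradient this step
        # Sum the gradient over all ranks in place, then turn the sum into a mean.
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad.div_(world_size)


def single_process_demo() -> torch.Tensor:
    """Run one backward pass plus sync in a one-rank gloo group (CPU only)."""
    init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
    dist.init_process_group(
        backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        model = torch.nn.Linear(3, 1)
        model(torch.ones(2, 3)).sum().backward()
        sync_gradients(model)
        return model.weight.grad.clone()
    finally:
        dist.destroy_process_group()
```

With a single rank the averaged gradient equals the local one (here a 1x3 tensor of 2.0). Issuing one all_reduce per parameter keeps the sketch simple; a production implementation might instead flatten gradients into larger buckets to reduce the number of collective calls.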