# Source code for nequip.train.simple_ddp

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from lightning.pytorch.strategies import DDPStrategy


class SimpleDDPStrategy(DDPStrategy):
    """Effectively Lightning's :class:`~lightning.pytorch.strategies.DDPStrategy`, but doing manual gradient syncs instead of using PyTorch's :class:`~torch.nn.parallel.DistributedDataParallel` wrapper.

    .. note::
        To use train-time compilation with multi-rank training, this strategy must be used in place of PyTorch Lightning's :class:`~lightning.pytorch.strategies.DDPStrategy`.

    Example use in the config file:

    .. code-block:: yaml

        trainer:
          _target_: lightning.Trainer
          # other trainer arguments
          strategy:
            _target_: nequip.train.SimpleDDPStrategy
    """

    def configure_ddp(self) -> None:
        # Intentionally a no-op: this strategy does not wrap the model in
        # ``DistributedDataParallel``; gradients are instead synchronized
        # manually in ``post_backward``.
        pass

    def post_backward(self, closure_loss: torch.Tensor) -> None:
        """Manually synchronize (average) gradients across ranks after the backward pass.

        Every trainable parameter that currently holds a gradient has that
        gradient replaced, in place, by its mean over all ranks.

        Args:
            closure_loss: the loss from the training step. It is not used
                here; presumably it is part of the Lightning strategy hook
                signature being overridden.
        """
        # Gather the gradients to sync exactly once, so the flatten and the
        # copy-back below are guaranteed to walk the same tensors in the same
        # order.
        # NOTE(review): this assumes every rank ends up with gradients for the
        # same set of parameters; otherwise the flattened buffers would differ
        # in length between ranks and the collective would not line up.
        grads = [
            param.grad.data
            for param in self.model.parameters()
            if param.requires_grad and param.grad is not None
        ]
        if not grads:
            return

        # Concatenate all gradients into a single flat tensor so one collective
        # call handles everything (more efficient than one per parameter).
        # ``reshape`` rather than ``view``: a gradient may be non-contiguous
        # (e.g. when it mirrors a non-contiguous parameter's strides), and
        # ``view(-1)`` raises on such tensors while ``reshape`` copies if needed.
        flat_grads = torch.cat([grad.reshape(-1) for grad in grads])

        # NOTE: averaging (i.e. summing and dividing by number of ranks) is consistent with PyTorch Lightning's `DDPStrategy`
        # in the training loop, we account for this by multiplying the loss by the number of ranks before the backwards call
        if torch.distributed.get_backend() == "gloo":
            # ReduceOp.AVG is not available on the gloo backend, so sum and
            # divide by the world size explicitly.
            torch.distributed.all_reduce(flat_grads, op=torch.distributed.ReduceOp.SUM)
            flat_grads /= torch.distributed.get_world_size()
        else:
            torch.distributed.all_reduce(flat_grads, op=torch.distributed.ReduceOp.AVG)

        # Copy the reduced values back into the original gradient tensors.
        # Each chunk of the 1-D buffer is contiguous, so ``view_as`` is safe;
        # ``copy_`` then writes it into the (possibly strided) gradient.
        chunks = flat_grads.split([grad.numel() for grad in grads])
        for grad, chunk in zip(grads, chunks):
            grad.copy_(chunk.view_as(grad))