Source code for nequip.model.param_groups
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
def _normalize_weight_index_slices(weight_index_slices):
normalized = []
for entry in weight_index_slices:
index_slice = getattr(entry, "slice_1D", None)
shape_2d = getattr(entry, "shape_2D", None)
if index_slice is None or shape_2d is None:
index_slice, shape_2d = entry
if isinstance(index_slice, slice):
index_slice = (index_slice.start, index_slice.stop, index_slice.step)
else:
index_slice = tuple(index_slice)
assert len(index_slice) == 3
shape_2d = tuple(shape_2d)
assert len(shape_2d) == 2
normalized.append((index_slice, shape_2d))
return normalized
[docs]
def MuonParamGroups(
model: torch.nn.Module,
muon: dict,
adam: dict,
):
"""
Build optimizer parameter groups, splitting parameters between a Muon-based optimizer
and Adam (or Adam-like) optimizer.
Assigned to Adam group:
- Any parameter whose name does **not** contain the substring ``"layer"``.
- Any parameter not matching the Muon-specific rules below.
Assigned to Muon group:
- Edge MLP weights: parameters whose name contains ``"edge_mlp"`` and that are
2D tensors (i.e., matrix weights).
- e3nn convolution linear weights: parameters whose name contains ``"conv.linear"``.
For e3nn ``Linear`` layers, the returned Muon parameter group includes an
``e3nn_reshaping`` dictionary mapping the index of the parameter within the Muon
group to the module's ``weight_index_slices``. This metadata will be used by the
to reshape or operate on corresponding matrix weights.
Args:
model (torch.nn.Module): The model to optimize.
muon (dict): Muon config parameters.
adam (dict): Adam config parameters.
"""
muon_weights = []
adam_weights = []
e3nn_reshaping = {}
modules = dict(model.named_modules())
for name, param in model.named_parameters():
# Assumes all input and output layers are
# not called layers.
if "layer" not in name:
adam_weights.append(param)
continue
# First, all edge_mlps should be muon
if "edge_mlp" in name and param.ndim == 2:
muon_weights.append(param)
continue
if "conv.linear" in name:
# e3nn conv layers.
# Find the e3nn Linear module this represents
module_name, _, _ = name.rpartition(".")
module = modules[module_name]
# use Muon only when reshape metadata is available
weight_index_slices = getattr(module, "weight_index_slices", None)
if weight_index_slices is None:
adam_weights.append(param)
continue
# store plain tuples to keep optimizer state picklable
index = len(muon_weights)
e3nn_reshaping[index] = _normalize_weight_index_slices(weight_index_slices)
muon_weights.append(param)
continue
adam_weights.append(param)
param_groups = [
dict(params=muon_weights, use_muon=True, e3nn_reshaping=e3nn_reshaping, **muon),
dict(params=adam_weights, use_muon=False, **adam),
]
return param_groups