deepmd.pt.optimizer

Submodules

Classes

AdaMuonOptimizer

AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.

HybridMuonOptimizer

HybridMuon optimizer with 1D Adam path and matrix Muon path.

KFOptimizerWrapper

LKFOptimizer

Package Contents

class deepmd.pt.optimizer.AdaMuonOptimizer(params: collections.abc.Iterable[torch.Tensor] | collections.abc.Iterable[dict[str, Any]], lr: float = 0.001, momentum: float = 0.95, weight_decay: float = 0.001, ns_steps: int = 5, adam_betas: tuple[float, float] = (0.9, 0.95), adam_eps: float = 1e-08, nesterov: bool = True, lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, eps: float = 1e-08)

Bases: torch.optim.optimizer.Optimizer

AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.

This optimizer applies different update rules based on parameter dimensionality:

  • For 2D+ parameters (weight matrices): AdaMuon update with sign-stabilized Newton-Schulz orthogonalization and per-element v_buffer normalization.

  • For 1D parameters (biases, layer norms): standard Adam update.

Key AdaMuon features:

  • Sign-stabilized orthogonal direction: applies sign() before orthogonalization.

  • Per-element second-moment normalization using the momentum coefficient.

  • RMS-aligned global scaling: 0.2 * sqrt(m * n) / norm.
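The RMS-aligned global scaling can be illustrated with a minimal plain-Python sketch. It assumes that norm is the Frobenius norm of the m x n update and that eps enters the denominator, as the eps parameter suggests. Multiplying the update by 0.2 * sqrt(m * n) / ||O||_F sets its Frobenius norm to 0.2 * sqrt(m * n), so its per-element RMS becomes 0.2. The function names are hypothetical and not part of deepmd.

```python
import math


def rms_aligned_scale(update, target_rms=0.2, eps=1e-8):
    """Hypothetical helper: rescale a 2D update so its per-element RMS is ~target_rms.

    Assumes `norm` in "0.2 * sqrt(m * n) / norm" is the Frobenius norm.
    """
    m, n = len(update), len(update[0])
    fro = math.sqrt(sum(x * x for row in update for x in row))
    scale = target_rms * math.sqrt(m * n) / (fro + eps)
    return [[scale * x for x in row] for row in update]


def rms(matrix):
    """Root-mean-square over all elements of a 2D list."""
    values = [x for row in matrix for x in row]
    return math.sqrt(sum(x * x for x in values) / len(values))


scaled = rms_aligned_scale([[1.0, -2.0, 0.5], [3.0, 0.0, -1.0]])
print(rms(scaled))  # approximately 0.2
```

Because the rule depends only on the matrix size and its norm, the resulting update magnitude is independent of the raw gradient scale.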

Parameters:
params : iterable

Iterable of parameters to optimize.

lr : float

Learning rate with default 1e-3.

momentum : float

Momentum coefficient for AdaMuon with default 0.95.

weight_decay : float

Weight decay coefficient (applied only to >=2D params) with default 0.001.

ns_steps : int

Number of Newton-Schulz iterations with default 5.

adam_betas : tuple[float, float]

Adam beta coefficients with default (0.9, 0.95).

adam_eps : float

Adam epsilon with default 1e-8.

nesterov : bool

Whether to use Nesterov momentum for AdaMuon with default True.

lr_adjust : float

Learning rate adjustment factor. It selects the scaling rule for the AdaMuon update and the learning rate used by Adam (1D params):

  • If lr_adjust <= 0: use match-RMS scaling for the AdaMuon update, scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.

  • If lr_adjust > 0: use rectangular correction for the AdaMuon update, scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as its learning rate.

Default is 10.0 (Adam lr = lr/10).

lr_adjust_coeff : float

Coefficient for match-RMS scaling with default 0.2. Only effective when lr_adjust <= 0.

eps : float

Epsilon for the v_buffer sqrt and the global scaling normalization with default 1e-8.
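The two lr_adjust modes can be summarized by a small pure-Python function. It is a sketch of the documented rule, not the optimizer's code. The function name is hypothetical, and treating m and n as the row and column counts of the matrix being updated is an assumption.

```python
import math


def lr_adjust_rule(lr, m, n, lr_adjust=10.0, lr_adjust_coeff=0.2):
    """Hypothetical helper returning (matrix_update_scale, adam_lr).

    m and n are assumed to be the row and column counts of the updated matrix.
    """
    if lr_adjust <= 0:
        # Match-RMS scaling for the matrix update; Adam keeps the base lr.
        return lr_adjust_coeff * math.sqrt(max(m, n)), lr
    # Rectangular correction for the matrix update; Adam uses lr / lr_adjust.
    return math.sqrt(max(1.0, m / n)), lr / lr_adjust


# Default AdaMuon settings on a tall 256 x 64 matrix: scale 2.0, Adam lr = lr / 10.
print(lr_adjust_rule(1e-3, 256, 64))
# Match-RMS mode: scale = 0.2 * sqrt(256) = 3.2, Adam lr = lr.
print(lr_adjust_rule(1e-3, 256, 64, lr_adjust=0.0))
```

HybridMuonOptimizer documents the same rule but defaults to lr_adjust=0.0 and lr_adjust_coeff=0.18, so it uses match-RMS scaling out of the box.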

Examples

>>> from deepmd.pt.optimizer import AdaMuonOptimizer
>>> optimizer = AdaMuonOptimizer(model.parameters(), lr=1e-3)
>>> for epoch in range(epochs):
...     optimizer.zero_grad()
...     loss = loss_fn(model(inputs), labels)  # recompute the loss every step
...     loss.backward()
...     optimizer.step()

step(closure: collections.abc.Callable[[], torch.Tensor] | None = None) → torch.Tensor | None

Perform a single optimization step.

Parameters:
closure : callable, optional

A closure that reevaluates the model and returns the loss.

Returns:
loss : torch.Tensor, optional

The loss value if closure is provided, otherwise None.

class deepmd.pt.optimizer.HybridMuonOptimizer(params: collections.abc.Iterable[torch.Tensor] | collections.abc.Iterable[dict[str, Any]], lr: float = 0.0005, momentum: float = 0.95, weight_decay: float = 0.001, adam_betas: tuple[float, float] = (0.9, 0.95), lr_adjust: float = 0.0, lr_adjust_coeff: float = 0.18, muon_mode: str = 'slice', named_parameters: collections.abc.Iterable[tuple[str, torch.Tensor]] | None = None, enable_gram: bool = True, flash_muon: bool = True, magma_muon: bool = True, use_foreach: bool | None = None)

Bases: torch.optim.optimizer.Optimizer

HybridMuon optimizer with 1D Adam path and matrix Muon path.

This optimizer applies different update rules based on parameter dimensionality, parameter names, and muon_mode:

  • Parameters whose final effective name segment contains bias (case-insensitive) or starts with adam_ (case-insensitive): standard Adam update.

  • Parameters whose final effective name segment starts with adamw_ (case-insensitive): Adam with decoupled weight decay (AdamW-style).

  • 1D parameters: standard Adam update.

  • All other parameters are routed by effective shape (singleton dimensions removed):

      ◦ muon_mode="2d": effective rank 2 parameters use Muon; effective rank >2 parameters use Adam with AdamW-style decoupled weight decay.

      ◦ muon_mode="flat": effective rank >=2 parameters use flattened matrix-view Muon.

      ◦ muon_mode="slice": effective rank 2 parameters use Muon; effective rank >=3 parameters apply Muon independently to each trailing (m, n) slice.

Naming convention for explicit Adam routing:

  • Parameters representing bias terms should include bias in their final effective name segment (case-insensitive).

  • Parameters that are not semantic bias but should still use Adam should use an adam_ prefix in their final effective name segment (case-insensitive).

  • Parameters that should use Adam with decoupled weight decay should use an adamw_ prefix in their final effective name segment (case-insensitive).

This hybrid approach is effective because Muon's orthogonalization is designed for weight matrices, while Adam is better suited to biases and normalization parameters.

Parameters:
params : iterable

Iterable of parameters to optimize.

lr : float

Learning rate with default 5e-4.

momentum : float

Momentum coefficient for Muon with default 0.95.

weight_decay : float

Weight decay coefficient with default 0.001. Applied to Muon-routed parameters and to >=2D Adam-routed parameters (AdamW-style decoupled decay). Not applied to 1D Adam parameters.

adam_betas : tuple[float, float]

Adam beta coefficients with default (0.9, 0.95).

lr_adjust : float

Learning rate adjustment mode for Muon scaling and the Adam learning rate:

  • If lr_adjust <= 0: use match-RMS scaling for Muon, scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.

  • If lr_adjust > 0: use rectangular correction for Muon, scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust.

Default is 0.0 (match-RMS scaling).

lr_adjust_coeff : float

Coefficient for match-RMS scaling when lr_adjust <= 0, with default 0.18: scale = lr_adjust_coeff * sqrt(max(m, n)). The value 0.18 was calibrated by DeepSeek-V4 so that Muon's per-element update RMS matches AdamW's typical RMS, which allows the same AdamW learning rate to be reused on both paths. The Moonlight reference uses 0.2; both values are empirically viable.

muon_mode : str

Muon routing mode with default "slice":

  • "2d": only 2D parameters are Muon candidates.

  • "flat": >=2D parameters use flattened matrix-view routing.

  • "slice": >=3D parameters use per-slice Muon routing on the last two dims.

named_parameters : iterable[tuple[str, torch.Tensor]] | None

Optional named-parameter iterable used for name-based routing. Parameters whose final effective name segment contains bias (case-insensitive) or starts with adam_ (case-insensitive) are forced to Adam (no weight decay). Parameters whose final effective name segment starts with adamw_ (case-insensitive) are forced onto the AdamW-style decoupled-decay path.

enable_gram : bool

Enable the compiled Gram Newton-Schulz path for rectangular Muon matrices. Square matrices still use the standard Newton-Schulz implementation. Default is True.

flash_muon : bool

Enable Triton-accelerated Newton-Schulz orthogonalization. Requires Triton and CUDA; falls back to the PyTorch implementation when Triton is unavailable or when running on CPU. Ignored when enable_gram=True. Default is True.

magma_muon : bool

Enable Magma-lite damping on Muon updates with default True. This computes momentum-gradient cosine alignment per Muon block, applies EMA smoothing, and rescales Muon updates by a factor in [0.1, 1.0]. Adam/AdamW paths are unchanged. Empirically beneficial for MLIP / SeZM training under heavy-tailed gradient noise from conservative-force (second-order) autograd.

use_foreach : bool | None

Whether to use fused torch._foreach_* kernels. None (the default) resolves to True; pass False explicitly if DTensor dispatch errors occur under FSDP2.
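The calibration of lr_adjust_coeff follows from a short calculation. It assumes that the orthogonalized Muon update O (m x n) is approximately semi-orthogonal, with all min(m, n) singular values equal to 1; c denotes lr_adjust_coeff.

```latex
\begin{aligned}
\|O\|_F^2 &= \sum_{i=1}^{\min(m,n)} \sigma_i^2 = \min(m,n), \\
\mathrm{RMS}(O) &= \frac{\|O\|_F}{\sqrt{mn}} = \sqrt{\frac{\min(m,n)}{mn}} = \frac{1}{\sqrt{\max(m,n)}}, \\
\mathrm{RMS}\!\left(c\,\sqrt{\max(m,n)}\;O\right) &= c .
\end{aligned}
```

Under this assumption the match-RMS scale makes the per-element update RMS equal to c regardless of matrix shape, so c = 0.18 directly sets the RMS that is meant to match AdamW's, and the Moonlight value c = 0.2 targets 0.2.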

Examples

>>> from deepmd.pt.optimizer import HybridMuonOptimizer
>>> optimizer = HybridMuonOptimizer(model.parameters(), lr=5e-4)
>>> for epoch in range(epochs):
...     optimizer.zero_grad()
...     loss = loss_fn(model(inputs), labels)  # recompute the loss every step
...     loss.backward()
...     optimizer.step()

_param_name_map: dict[int, str]
_routing_built = False
_routing: list[dict[str, Any]] = []
_use_flash = False
_ns_buffers: dict[tuple[int, torch.device], tuple[torch.Tensor, torch.Tensor]]
_gram_orthogonalizer: _GramNewtonSchulzOrthogonalizer | None = None
_use_foreach
set_param_names(named_parameters: collections.abc.Iterable[tuple[str, torch.Tensor]]) → None

Set runtime-only parameter names used for name-based routing.

The mapping intentionally stays outside optimizer defaults and param_groups so optimizer checkpoints do not persist full (name, Parameter) tuples. Under ZeRO-1 this avoids gathering a duplicate model-sized object graph during consolidate_state_dict.

static _resolve_foreach(use_foreach: bool | None) → bool

Resolve the use_foreach flag for torch._foreach_* kernels.

Foreach fuses per-parameter loops into single kernel launches, eliminating Python overhead. When use_foreach is None the default is True because plain torch.Tensor (single-GPU, DDP, ZeRO-1) always supports these ops; callers that hit DTensor dispatch errors under FSDP2 must pass use_foreach=False explicitly.

_compute_magma_scales_merged(bucket_entries: list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]], rows: int, cols: int) → list[torch.Tensor]

Compute Magma-lite scales for a merged bucket with variable batch_sizes.

Like _compute_magma_scales_for_bucket but handles entries whose batch_size may differ (produced by the merged-bucket strategy that keys on (rows, cols) instead of (batch_size, rows, cols)).

_compute_magma_scale(param: torch.Tensor, grad: torch.Tensor, momentum_buffer: torch.Tensor, batch_size: int, rows: int, cols: int) → torch.Tensor

Compute Magma-lite Muon damping scales from momentum-gradient alignment.

Implements a stabilized version of Magma (Momentum-Aligned Gradient Masking) adapted for MLIP force-field training. Computes block-wise alignment scores between Muon momentum and current gradients, applies EMA smoothing, and rescales Muon updates to improve stability under heavy-tailed gradient noise.

Parameters:
param : torch.Tensor

Parameter updated by Muon.

grad : torch.Tensor

Current gradient tensor with shape compatible with (batch_size, rows, cols).

momentum_buffer : torch.Tensor

Muon momentum buffer (updated m_t) with the same shape as grad.

batch_size : int

Number of Muon blocks: 1 in 2d/flat mode; in slice mode, one per trailing (rows, cols) slice, so typically >1.

rows : int

Matrix row count per block.

cols : int

Matrix column count per block.

Returns:
torch.Tensor

Damping scales with shape (batch_size,) in [MAGMA_MIN_SCALE, 1.0].

Notes

For each Muon block b:

  1. Compute cosine similarity between momentum and gradient:

    cos(b) = <μ_t^(b), g_t^(b)> / (||μ_t^(b)|| * ||g_t^(b)||)

  2. Apply sigmoid with range stretching to [0, 1]:

    s_raw^(b) = (sigmoid(cos(b) / τ) - s_min) / (s_max - s_min)

    where τ=2.0, s_min=sigmoid(-1/τ), s_max=sigmoid(1/τ). This stretches the narrow sigmoid range [0.38, 0.62] to [0, 1].

  3. Apply EMA smoothing:

    s̃_t^(b) = a * s̃_{t-1}^(b) + (1-a) * s_raw^(b)

    where a=0.9 (MAGMA_EMA_DECAY).

  4. Map to damping scale in [s_min_scale, 1.0]:

    scale^(b) = s_min_scale + (1 - s_min_scale) * s̃_t^(b)

    where s_min_scale=0.1 (MAGMA_MIN_SCALE).

  5. Apply damping to Muon update:

    Δ̃^(b) = scale^(b) * Δ^(b) (soft scaling, no Bernoulli masking)

Key differences from the original Magma paper:

  • Sigmoid range stretching: Paper uses raw sigmoid with narrow range [0.38, 0.62]. We stretch to [0, 1] for better discrimination between aligned/misaligned blocks.

  • Soft scaling: Paper uses Bernoulli masking (50% skip probability). We use continuous soft scaling [0.1, 1.0] for stability in MLIP training.

  • Minimum scale: Paper allows scale=0 (complete skip). We enforce scale >= 0.1 to guarantee a minimum effective learning rate.
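Steps 1 to 5 in the Notes above can be sketched for a single Muon block in plain Python. This is a hedged illustration, not the optimizer's batched implementation: the function name is hypothetical, TAU, S_MIN and S_MAX are local names for τ, s_min and s_max, block tensors are passed as flat lists, the small eps guarding zero norms is an assumption, and the caller supplies the previous EMA state because the document does not say how it is initialized.

```python
import math

TAU = 2.0              # temperature tau from the Notes
MAGMA_EMA_DECAY = 0.9  # EMA coefficient a
MAGMA_MIN_SCALE = 0.1  # lower bound s_min_scale


def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


S_MIN = _sigmoid(-1.0 / TAU)  # ~0.38
S_MAX = _sigmoid(1.0 / TAU)   # ~0.62


def magma_block_scale(momentum, grad, prev_ema, eps=1e-12):
    """Return (scale, new_ema) for one Muon block given flattened tensors."""
    # Step 1: cosine similarity between momentum and gradient.
    dot = sum(mu * g for mu, g in zip(momentum, grad))
    norm_mu = math.sqrt(sum(mu * mu for mu in momentum))
    norm_g = math.sqrt(sum(g * g for g in grad))
    cos = dot / (norm_mu * norm_g + eps)
    # Step 2: stretch the sigmoid output from [S_MIN, S_MAX] to [0, 1].
    s_raw = (_sigmoid(cos / TAU) - S_MIN) / (S_MAX - S_MIN)
    s_raw = min(max(s_raw, 0.0), 1.0)  # guard against rounding outside [0, 1]
    # Step 3: EMA smoothing across optimizer steps.
    ema = MAGMA_EMA_DECAY * prev_ema + (1.0 - MAGMA_EMA_DECAY) * s_raw
    # Step 4: map to a damping scale in [MAGMA_MIN_SCALE, 1.0].
    scale = MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * ema
    return scale, ema


# Step 5: the Muon update of this block would then be multiplied by `scale`.
aligned_scale, _ = magma_block_scale([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], prev_ema=1.0)
opposed_scale, _ = magma_block_scale([1.0, 0.0], [-1.0, 0.0], prev_ema=0.0)
print(aligned_scale, opposed_scale)  # close to 1.0 and 0.1
```

An orthogonal momentum/gradient pair gives cos = 0 and therefore s_raw = 0.5, the midpoint of the stretched range, since s_min + s_max = 1.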

_compute_magma_scales_for_bucket(bucket_entries: list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]], batch_size: int, rows: int, cols: int) → list[torch.Tensor]

Compute Magma-lite damping scales for one Muon bucket in a batched way.

Parameters:
bucket_entries : list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]

Bucket entries as (entry, update_tensor, grad, momentum_buffer).

batch_size : int

Number of Muon blocks per parameter in this bucket.

rows : int

Matrix row count for this bucket.

cols : int

Matrix column count for this bucket.

Returns:
list[torch.Tensor]

Magma scales for each bucket entry. Each tensor has shape (batch_size,).

_get_ns_buffers(M: int, device: torch.device) → tuple[torch.Tensor, torch.Tensor]

Get or lazily allocate pre-allocated buffers for flash Newton-Schulz.

Parameters:
M : int

Square buffer dimension (= min(rows, cols) of the update matrix).

device : torch.device

Target CUDA device.

Returns:
tuple[torch.Tensor, torch.Tensor]

(buf1, buf2), each with shape (M, M) in bfloat16.

_get_gram_orthogonalizer() → _GramNewtonSchulzOrthogonalizer

Lazily initialize the compiled Gram orthogonalizer.

Returns:
_GramNewtonSchulzOrthogonalizer

Shared Gram orthogonalizer instance for the optimizer.

_process_merged_gram_buckets(gram_buckets: dict[tuple[int, int, torch.device, torch.dtype], list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]], lr: float, lr_adjust: float, lr_adjust_coeff: float, magma_scales_map: dict[int, torch.Tensor]) → None

Column-pad merge across rectangular buckets sharing the same min_dim.

Rectangular Muon matrices with the same min(rows, cols) can be fused into a single Gram Newton-Schulz call by zero-padding the column (large) dimension to the group maximum. This reduces the number of compiled Gram NS dispatches and improves GPU occupancy.

Mathematical equivalence proof for column padding: both standard NS and Gram NS operate on the wide orientation X (m x n), m <= n. The Gram matrix is R = X @ X^T (m x m).

Let X_pad = [X | 0]  (m x (n+p)) where the last p columns are zero. Then:

  1. Frobenius norm is unchanged: ||X_pad||_F = ||X||_F because the zero columns contribute nothing.

  2. Gram matrix is unchanged: R_pad = X_pad @ X_pad^T = X @ X^T + 0 @ 0^T = R

  3. Since all NS iterations (both standard quintic and Gram/Polar-Express) depend only on R (which is m x m regardless of n), every intermediate Q_k is identical.

  4. The restart step X_new = Q @ X_pad = [Q @ X | 0] also preserves the invariant R_new = Q @ R @ Q^T, so subsequent iterations remain identical.

  5. The final output is Q_last @ X_pad = [Q_last @ X | 0]. Truncating to the first n columns exactly recovers the unpadded result.

Constraint: Only the column (large) dimension may be padded. Padding rows would change the size of R and break equivalence.

Per-entry scale and Magma damping are applied after unpadding, since different original shapes have different max(rows, cols).
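The column-padding argument can be checked numerically with a plain-Python sketch of the standard quintic Newton-Schulz iteration (the Gram/Polar-Express variant is not reproduced). The coefficients (3.4445, -4.7750, 2.0315) and the 1e-7 normalization epsilon come from the public Muon reference implementation and are assumptions here, as are all helper names.

```python
import math

# Quintic Newton-Schulz coefficients from the public Muon reference
# implementation (assumed here; this document does not list them).
NS_A, NS_B, NS_C = 3.4445, -4.7750, 2.0315


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def transpose(x):
    return [list(col) for col in zip(*x)]


def newton_schulz(x, steps=5, eps=1e-7):
    """Standard quintic NS on a wide (m <= n) matrix; returns X after `steps`."""
    norm = math.sqrt(sum(v * v for row in x for v in row))  # Frobenius norm
    x = [[v / (norm + eps) for v in row] for row in x]
    for _ in range(steps):
        gram = matmul(x, transpose(x))  # R = X @ X^T, shape (m, m)
        gram2 = matmul(gram, gram)
        # Q - a*I = b*R + c*R^2 depends only on R.
        poly = [[NS_B * g + NS_C * g2 for g, g2 in zip(r1, r2)]
                for r1, r2 in zip(gram, gram2)]
        px = matmul(poly, x)
        x = [[NS_A * v + w for v, w in zip(r1, r2)] for r1, r2 in zip(x, px)]
    return x


def pad_columns(x, total_cols):
    """Append zero columns so every row has `total_cols` entries."""
    return [row + [0.0] * (total_cols - len(row)) for row in x]


x = [[1.0, 2.0, 0.5], [0.0, -1.0, 3.0]]    # 2 x 3, wide orientation
plain = newton_schulz(x)
padded = newton_schulz(pad_columns(x, 5))  # 2 x 5 with two zero columns
unpadded = [row[:3] for row in padded]     # truncate back to n = 3 columns
print(max(abs(a - b) for ra, rb in zip(plain, unpadded) for a, b in zip(ra, rb)))
```

The padded columns stay zero throughout, and the first n columns match the unpadded run, which is what allows buckets with different column counts but the same min_dim to share one call.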

_build_param_routing() → None

Classify parameters into Muon, Adam, and AdamW routes (static routing).

Routing logic:

  • name-based adam_ prefix or contains bias → Adam (no decay)

  • name-based adamw_ prefix → AdamW (decoupled weight decay)

  • effective shape rank <2 → Adam (no decay)

  • non-matrix effective shape for current muon_mode → AdamW (decoupled)

  • remaining eligible matrix params → Muon path
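The routing rules can be sketched as a pure function of a parameter's name and shape. This is a hypothetical illustration rather than the optimizer's code: it assumes the "final effective name segment" is the last dot-separated component of the name (the real optimizer may strip additional wrapper prefixes), and it returns labels instead of building parameter groups.

```python
def effective_shape(shape):
    """Drop singleton dimensions, as the routing description specifies."""
    return tuple(d for d in shape if d != 1)


def route(name, shape, muon_mode="slice"):
    """Hypothetical router returning "adam", "adamw", or "muon"."""
    if muon_mode not in ("2d", "flat", "slice"):
        raise ValueError(f"unknown muon_mode: {muon_mode!r}")
    # Assumption: the final effective name segment is the last dot-separated part.
    leaf = name.rsplit(".", 1)[-1].lower()
    if leaf.startswith("adam_") or "bias" in leaf:
        return "adam"   # Adam, no weight decay
    if leaf.startswith("adamw_"):
        return "adamw"  # AdamW-style decoupled weight decay
    rank = len(effective_shape(shape))
    if rank < 2:
        return "adam"   # scalars and vectors: Adam, no weight decay
    if muon_mode == "2d" and rank > 2:
        return "adamw"  # not a matrix under "2d" mode
    return "muon"       # "flat" and "slice" accept any effective rank >= 2


print(route("fitting.layer0.bias", (64,)))          # adam
print(route("embed.adamw_gain", (8, 8)))            # adamw
print(route("w", (4, 64, 128), muon_mode="2d"))     # adamw
print(route("w", (4, 64, 128), muon_mode="slice"))  # muon
```

Checking adam_ before adamw_ is safe because "adamw_..." does not start with "adam_".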

_adam_update_moments(exp_avgs: list[torch.Tensor], exp_avg_sqs: list[torch.Tensor], grads_fp32: list[torch.Tensor], beta1: float, beta2: float) → None

Update Adam first/second moment estimates, foreach-accelerated when safe.

exp_avg = beta1 * exp_avg + (1 - beta1) * grad

exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
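Written out element by element, the moment update is the nested loop below. This plain-Python sketch (hypothetical function name, nested lists standing in for tensors) only shows the list-of-parameters interface and the two formulas; the documented method instead fuses these per-parameter loops into torch._foreach_* kernels when that is safe.

```python
def adam_update_moments(exp_avgs, exp_avg_sqs, grads, beta1, beta2):
    """In-place Adam moment update over parallel lists of per-parameter values."""
    for exp_avg, exp_avg_sq, grad in zip(exp_avgs, exp_avg_sqs, grads):
        for i, g in enumerate(grad):
            exp_avg[i] = beta1 * exp_avg[i] + (1.0 - beta1) * g             # first moment
            exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0 - beta2) * g * g  # second moment


exp_avgs, exp_avg_sqs = [[0.0, 0.0]], [[0.0, 0.0]]
adam_update_moments(exp_avgs, exp_avg_sqs, [[2.0, -1.0]], beta1=0.9, beta2=0.95)
print(exp_avgs, exp_avg_sqs)  # approximately [[0.2, -0.1]] and [[0.2, 0.05]]
```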

_weight_decay_inplace(params: list[torch.Tensor], factor: float) → None

Apply multiplicative weight decay, foreach-accelerated when safe.

step(closure: collections.abc.Callable[[], torch.Tensor] | None = None) → torch.Tensor | None

Perform a single optimization step.

Parameters:
closure : callable, optional

A closure that reevaluates the model and returns the loss.

Returns:
torch.Tensor | None

The loss value if closure is provided, otherwise None.

class deepmd.pt.optimizer.KFOptimizerWrapper(model: torch.nn.Module, optimizer: torch.optim.optimizer.Optimizer, atoms_selected: int, atoms_per_group: int, is_distributed: bool = False)
model
optimizer
atoms_selected
atoms_per_group
is_distributed = False
update_energy(inputs: dict, Etot_label: torch.Tensor, update_prefactor: float = 1) → None
update_force(inputs: dict, Force_label: torch.Tensor, update_prefactor: float = 1) → None
update_denoise_coord(inputs: dict, clean_coord: torch.Tensor, update_prefactor: float = 1, mask_loss_coord: bool = True, coord_mask: torch.Tensor = None) → None
__sample(atoms_selected: int, atoms_per_group: int, natoms: int) → numpy.ndarray

class deepmd.pt.optimizer.LKFOptimizer(params: Any, kalman_lambda: float = 0.98, kalman_nue: float = 0.9987, block_size: int = 5120)

Bases: torch.optim.optimizer.Optimizer

_params
_state
dist_init
rank
dindex = []
remainder = 0
__init_P() → None
__get_blocksize() → int
__get_nue() → float
__split_weights(weight: torch.Tensor) → list[torch.Tensor]
__update(H: torch.Tensor, error: torch.Tensor, weights: torch.Tensor) → None
set_grad_prefactor(grad_prefactor: float) → None
step(error: torch.Tensor) → None
get_device_id(index: int) → int | None