deepmd.dpmodel.descriptor.dpa4_nn.lora

LoRA low-rank fine-tuning support for DPA4/SeZM.

This module adds two things:

  • LoRASO3 and LoRASO2 subclasses that wrap the corresponding base equivariant linear operators (SO3Linear / SO2Linear). Each one freezes the pre-trained weights and registers rank-R adapter parameters A/B whose shapes share the base's batch layout (per-l for SO(3), per-|m|-group for SO(2)). The LoRA delta is folded into the effective weight before the single large einsum that the base module already performs, so that einsum keeps the same shape and FLOPs as in the base. The only overhead is the weight-side product B·A, whose cost grows linearly with R and with the weight size but does not depend on the number of edges or nodes.

  • apply_lora_to_sezm, merge_lora_into_base and a few helpers that drive the fine-tune policy (which submodules stay trainable, which ones remain frozen) and the merged-checkpoint export used by Trainer.save_model_merged.
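Folding the delta into the weight is exact because the layer is linear in its weight. A minimal NumPy sketch, with plain 2D matrices standing in for the batched SO(3)/SO(2) weights (all names here are illustrative, not the module's API):

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows, c_in, c_out, rank = 1000, 16, 24, 4   # n_rows plays the role of edges/nodes
scaling = 1.0                                   # lora_alpha / lora_rank

x = rng.standard_normal((n_rows, c_in))
w = rng.standard_normal((c_in, c_out))          # frozen base weight
a = rng.standard_normal((rank, c_in))           # LoRA A: (rank, C_in)
b = rng.standard_normal((c_out, rank))          # LoRA B: (C_out, rank)

# Weight-side delta: costs O(rank * c_in * c_out), independent of n_rows.
delta_w = scaling * (b @ a).T                   # (c_in, c_out)

# Folded path: one large matmul with the same shape as the base forward.
y_folded = x @ (w + delta_w)

# Unfolded reference: base output plus the adapter branch x -> A -> B.
y_reference = x @ w + scaling * (x @ a.T) @ b.T

assert np.allclose(y_folded, y_reference)
```

The unfolded branch would add two extra matmuls over all n_rows; folding moves that work onto the small weight tensors instead.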

Naming convention: the LoRA parameter names (A_by_l, B_by_l, A_m0, B_m0, A_m, B_m) intentionally do not start with adam_ / adamw_ and do not contain bias. HybridMuon.get_adam_route therefore classifies them as muon, and because the adapters mirror the structure of the corresponding base weights (3D per-l tensors for SO(3), one 2D matrix per |m| group for SO(2)), the slice-mode matrix view gives per-l / per-|m|-group Newton-Schulz updates that match the base training recipe.
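To illustrate what a "slice-mode" matrix view means for a 3D tensor such as A_by_l, the sketch below orthogonalizes each leading-axis slice independently. It uses the classic cubic Newton-Schulz iteration as a stand-in; Muon-style optimizers typically use a tuned higher-order polynomial, and `newton_schulz_orthogonalize` / `slice_mode_update` are hypothetical helpers, not HybridMuon internals:

```python
import numpy as np


def newton_schulz_orthogonalize(g: np.ndarray, steps: int = 50) -> np.ndarray:
    """Approximate the orthogonal polar factor of a 2D matrix (cubic iteration)."""
    x = g / np.linalg.norm(g)          # Frobenius scaling puts singular values in (0, 1]
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x
    return x


def slice_mode_update(grad: np.ndarray) -> np.ndarray:
    """Treat a (batch, rows, cols) tensor as independent matrices, one per leading index."""
    return np.stack([newton_schulz_orthogonalize(g) for g in grad])


rng = np.random.default_rng(0)
lmax, rank, c_in = 2, 4, 8
grad_a_by_l = rng.standard_normal((lmax + 1, rank, c_in))   # same layout as A_by_l
update = slice_mode_update(grad_a_by_l)

for block in update:   # every l-block ends up with (approximately) orthonormal rows
    assert np.allclose(block @ block.T, np.eye(rank), atol=1e-6)
```

Because each slice is processed on its own, one l-block's gradient never influences another block's update, which is the property the shared batch layout is meant to preserve.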

This module is the dpmodel (array-API) port of deepmd.pt.model.descriptor.sezm_nn.lora.

Attributes

Classes

LoRASO3

Per-l ELoRA adapter for SO3Linear.

LoRASO2

Per-|m|-group LoRA adapter for SO2Linear.

Functions

_iter_named_modules(...)

Yield (dotted_name, module) for root and every nested NativeOP.

_iter_named_parameters(...)

Yield (dotted_name, owner, array) for every numpy-array parameter.

_module_state_dict(→ dict[str, numpy.ndarray])

Flat dotted {name: array} dict over the whole module tree.

_leaf_name(→ str)

Return the trailing non-numeric segment of a parameter name.

_get_submodule_or_none(→ Any)

_clear_sezm_compile_cache(→ None)

No-op retained for parity with the PyTorch backend.

_swap_submodule(→ None)

Replace parent.attr with new_module.

has_lora(→ bool)

Return True iff any submodule is a LoRA adapter.

apply_lora_to_sezm(→ deepmd.dpmodel.NativeOP)

Inject LoRA adapters into every SO3Linear / SO2Linear of a SeZM model and apply the fine-tune freeze/unfreeze policy in place.

fold_lora_state_dict_keys(→ None)

Fold LoRA adapter keys into base weight keys in state_dict (in-place).

build_merged_state_dict(→ dict[str, numpy.ndarray])

Produce a plain (LoRA-free) state dict from a LoRA-augmented module.

strip_lora_from_extra_state(→ dict[str, Any])

Drop any lora entry from _extra_state["model_params"].

merge_lora_into_base(→ deepmd.dpmodel.NativeOP)

Destructively replace every LoRASO3 / LoRASO2 with its merged plain base module.

Module Contents

class deepmd.dpmodel.descriptor.dpa4_nn.lora.LoRASO3(*, lmax: int, in_channels: int, out_channels: int, n_focus: int = 1, precision: str = DEFAULT_PRECISION, mlp_bias: bool = False, trainable: bool = False, seed: int | list[int] | None = None, lora_rank: int, lora_alpha: float | None = None)

Bases: deepmd.dpmodel.descriptor.dpa4_nn.so3.SO3Linear

Per-l ELoRA adapter for SO3Linear.

The pre-trained weight self.weight, of shape (lmax+1, C_in, F*C_out), is frozen. Two new 3D parameters, A_by_l of shape (lmax+1, rank, C_in) and B_by_l of shape (lmax+1, F*C_out, rank), share the base's leading lmax+1 batch axis, so muon_mode="slice" updates every l-block independently. SO(3) equivariance is preserved because the per-l delta acts only on the channel axis: within each l-block the same matrix is applied to all 2l+1 components of degree l, and no term mixes different l.
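The per-l delta can be written as a batched product over the shared l axis. A minimal NumPy sketch under the shapes stated above; `so3_lora_delta` is a hypothetical helper, and the transpose convention (B·A transposed into the (C_in, F*C_out) weight layout) is an assumption chosen to match those shapes:

```python
from typing import Optional

import numpy as np


def so3_lora_delta(a_by_l: np.ndarray, b_by_l: np.ndarray, lora_rank: int,
                   lora_alpha: Optional[float] = None) -> np.ndarray:
    """Return scaling * (B[l] @ A[l]).T for every l, shape (lmax+1, C_in, F*C_out)."""
    if lora_rank < 1:
        raise ValueError("lora_rank must be >= 1")
    alpha = lora_rank if lora_alpha is None else lora_alpha   # None -> scaling 1.0
    scaling = alpha / lora_rank
    # B_by_l: (L, F*C_out, R), A_by_l: (L, R, C_in); l is a pure batch axis.
    return scaling * np.einsum("lor,lri->lio", b_by_l, a_by_l)


rng = np.random.default_rng(0)
lmax, c_in, f, c_out, rank = 2, 8, 2, 6, 3
a_by_l = rng.standard_normal((lmax + 1, rank, c_in))
b_by_l = rng.standard_normal((lmax + 1, f * c_out, rank))
delta = so3_lora_delta(a_by_l, b_by_l, lora_rank=rank, lora_alpha=6.0)  # scaling 2.0

assert delta.shape == (lmax + 1, c_in, f * c_out)
for l in range(lmax + 1):   # identical to an explicit per-l loop: no cross-l terms
    assert np.allclose(delta[l], 2.0 * (b_by_l[l] @ a_by_l[l]).T)
```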

Parameters:
lmax, in_channels, out_channels, n_focus, precision, mlp_bias, trainable, seed

Forwarded to SO3Linear to build the frozen base weight.

lora_rank

LoRA rank. Must satisfy lora_rank >= 1.

lora_alpha

Scaling numerator; the effective scaling is lora_alpha / lora_rank. None defaults to lora_alpha = lora_rank (scaling 1.0).

trainable = False
lora_rank
lora_alpha
scaling
lora_scaling
A_by_l
B_by_l
_compute_delta_weight(xp: Any, device: Any) → deepmd.dpmodel.array_api.Array

Return ΔW with shape (lmax+1, C_in, F*C_out).

call(x: deepmd.dpmodel.array_api.Array) → deepmd.dpmodel.array_api.Array
Parameters:
x

Input features with shape (N, D, F, C_in) where D=(lmax+1)^2.

Returns:
Array

Output features with shape (N, D, F, C_out).
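A sketch of the forward contraction with a per-l weight expanded to the D = (lmax+1)^2 rows, plus a numerical check of the equivariance argument: any orthogonal matrix acting only inside each degree-l block (a Wigner-D rotation is one such matrix) commutes with the layer. The function name is hypothetical, and it assumes the F*C_out axis reshapes to (F, C_out), i.e. one C_in to C_out map per focus:

```python
import numpy as np


def so3_linear_forward(x, weight, lmax, f, c_out):
    """x: (N, D, F, C_in) with D = (lmax+1)**2; weight: (lmax+1, C_in, F*C_out)."""
    c_in = weight.shape[1]
    w = weight.reshape(lmax + 1, c_in, f, c_out)            # assumed layout
    # Degree of each D row: l repeated 2l+1 times.
    l_of_d = np.repeat(np.arange(lmax + 1), [2 * l + 1 for l in range(lmax + 1)])
    return np.einsum("ndfc,dcfo->ndfo", x, w[l_of_d])


rng = np.random.default_rng(0)
lmax, n, f, c_in, c_out = 2, 5, 2, 4, 3
d = (lmax + 1) ** 2
x = rng.standard_normal((n, d, f, c_in))
weight = rng.standard_normal((lmax + 1, c_in, f * c_out))  # base weight with ΔW folded in

# Random orthogonal matrix that acts only inside each degree-l block.
rot = np.zeros((d, d))
start = 0
for l in range(lmax + 1):
    size = 2 * l + 1
    q, _ = np.linalg.qr(rng.standard_normal((size, size)))
    rot[start:start + size, start:start + size] = q
    start += size

y = so3_linear_forward(x, weight, lmax, f, c_out)
y_from_rotated = so3_linear_forward(np.einsum("de,nefc->ndfc", rot, x), weight, lmax, f, c_out)
assert np.allclose(y_from_rotated, np.einsum("de,nefo->ndfo", rot, y))
```

Since the check holds for every block-orthogonal matrix, it holds in particular for real rotations, whatever per-l delta has been folded into weight.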

merge_into_base() → deepmd.dpmodel.descriptor.dpa4_nn.so3.SO3Linear

Build a plain SO3Linear whose weight has absorbed the LoRA delta.

serialize() → dict[str, Any]

Serialize the LoRASO3 to a dict.

classmethod deserialize(data: dict[str, Any]) → LoRASO3

Deserialize a LoRASO3 from a dict.

class deepmd.dpmodel.descriptor.dpa4_nn.lora.LoRASO2(*, lmax: int, mmax: int | None = None, in_channels: int, out_channels: int, n_focus: int = 1, precision: str = DEFAULT_PRECISION, mlp_bias: bool = False, seed: int | list[int] | None = None, trainable: bool = False, lora_rank: int, lora_alpha: float | None = None)

Bases: deepmd.dpmodel.descriptor.dpa4_nn.so2.SO2Linear

Per-|m|-group LoRA adapter for SO2Linear.

weight_m0, of shape (num_in_m0, F*num_out_m0), and each weight_m[i], of shape (num_in_m, F*2*num_out_m), get an independent 2D LoRA pair A/B. SO(2) equivariance is preserved because ΔW_m is added in the concatenated [W_u | W_v] layout before _build_so2_weight splits it, so the |m|>0 2×2 complex block [[W_u, -W_v], [W_v, W_u]] keeps its structure. Splitting B along its output axis into B_u and B_v, which share the input basis A, gives ΔW_u = B_u A and ΔW_v = B_v A.
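A NumPy check of this argument with F = 1. It assumes weight_m stores [W_u | W_v] along its output axis in x @ W convention and that ΔW_m = scaling · (B A)^T; the helper names are hypothetical. The merged block still commutes with an in-plane rotation of each (+m, -m) pair:

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, rank, scaling = 4, 3, 2, 1.0

weight_m = rng.standard_normal((n_in, 2 * n_out))   # assumed layout: [W_u | W_v]
a_m = rng.standard_normal((rank, n_in))             # shared input basis A
b_m = rng.standard_normal((2 * n_out, rank))        # B stacks B_u over B_v

delta_m = scaling * (b_m @ a_m).T                   # same layout as weight_m
b_u, b_v = b_m[:n_out], b_m[n_out:]
assert np.allclose(delta_m[:, :n_out], scaling * (b_u @ a_m).T)   # ΔW_u = B_u A
assert np.allclose(delta_m[:, n_out:], scaling * (b_v @ a_m).T)   # ΔW_v = B_v A


def complex_block(w):
    """Split [W_u | W_v] and build [[U, -V], [V, U]] acting as y = M x."""
    u, v = w[:, :n_out].T, w[:, n_out:].T           # each (n_out, n_in)
    return np.block([[u, -v], [v, u]])


def rotation(angle, n):
    """Rotate n (+m, -m) pairs by the same angle."""
    c, s, eye = np.cos(angle), np.sin(angle), np.eye(n)
    return np.block([[c * eye, -s * eye], [s * eye, c * eye]])


m_merged = complex_block(weight_m + delta_m)
# Folding is linear: merged block = base block + delta block.
assert np.allclose(m_merged, complex_block(weight_m) + complex_block(delta_m))
theta = 0.7
assert np.allclose(rotation(theta, n_out) @ m_merged, m_merged @ rotation(theta, n_in))
```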

The base call logic is inherited unchanged; only _build_so2_weight is overridden, folding the LoRA delta into each base block before the block-diagonal weight is assembled. Building ΔW_m does not depend on the edge count E, so the edge-wise forward contraction has the same FLOPs as in the base and the extra weight-side cost is independent of E.

Parameters:
lmax, mmax, in_channels, out_channels, n_focus, precision, mlp_bias, trainable, seed

Forwarded to SO2Linear to build the frozen base weights.

lora_rank

LoRA rank.

lora_alpha

Scaling numerator; scaling is lora_alpha / lora_rank. None defaults to lora_alpha = lora_rank (scaling 1.0).

trainable = False
lora_rank
lora_alpha
scaling
lora_scaling
A_m0
B_m0
A_m: list[numpy.ndarray] = []
B_m: list[numpy.ndarray] = []
_compute_delta_m0(xp: Any, device: Any) → deepmd.dpmodel.array_api.Array

Return ΔW_m0 with shape (num_in_m0, F*num_out_m0).

_compute_delta_m(m_idx: int, xp: Any, device: Any) → deepmd.dpmodel.array_api.Array

Return ΔW_m[m_idx] with the same shape as weight_m[m_idx].

_build_so2_weight(xp: Any, device: Any) → deepmd.dpmodel.array_api.Array

Assemble the block-diagonal weight with LoRA delta folded in.

merge_into_base() → deepmd.dpmodel.descriptor.dpa4_nn.so2.SO2Linear

Build a plain SO2Linear whose weights have absorbed every LoRA delta.

serialize() → dict[str, Any]
classmethod deserialize(data: dict[str, Any]) → LoRASO2
deepmd.dpmodel.descriptor.dpa4_nn.lora._UNFREEZE_LEAF_NAMES: frozenset[str]
deepmd.dpmodel.descriptor.dpa4_nn.lora._OVERRIDE_FREEZE_LEAF_NAMES: frozenset[str]
deepmd.dpmodel.descriptor.dpa4_nn.lora._UNFREEZE_SUBMODULE_PATHS: tuple[str, ...] = ('atomic_model.fitting_net', 'atomic_model.dens_fitting_net',...
deepmd.dpmodel.descriptor.dpa4_nn.lora._UNFREEZE_PER_BLOCK_SUBPATHS: tuple[str, ...] = ('full_attn_res_so2', 'full_attn_res_ffns', 'block_attn_res_so2', 'block_attn_res_ffns',...
deepmd.dpmodel.descriptor.dpa4_nn.lora._BLOCKS_PATH: str = 'atomic_model.descriptor.blocks'
deepmd.dpmodel.descriptor.dpa4_nn.lora._iter_named_modules(root: deepmd.dpmodel.NativeOP, prefix: str = '', memo: set[int] | None = None) → collections.abc.Iterator[tuple[str, deepmd.dpmodel.NativeOP]]

Yield (dotted_name, module) for root and every nested NativeOP.

root is yielded first under prefix; the walk then descends into every attribute value that is a NativeOP and into every NativeOP element of a list/tuple attribute, building dotted paths (attr and attr.{i}). memo records the ids of visited modules, so a module reachable through several references is yielded only once.
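A minimal sketch of this walk. `Node` is a hypothetical stand-in for NativeOP (any object whose attributes may hold further Nodes), and the memo is assumed to store `id()` values, as the `set[int]` annotation suggests:

```python
from collections.abc import Iterator


class Node:
    """Hypothetical stand-in for NativeOP."""

    def __init__(self, **children):
        self.__dict__.update(children)


def iter_named_modules(root: Node, prefix: str = "", memo=None) -> Iterator[tuple[str, Node]]:
    if memo is None:
        memo = set()
    if id(root) in memo:          # shared module already yielded under another path
        return
    memo.add(id(root))
    yield prefix, root
    for attr, value in vars(root).items():
        child_prefix = f"{prefix}.{attr}" if prefix else attr
        if isinstance(value, Node):
            yield from iter_named_modules(value, child_prefix, memo)
        elif isinstance(value, (list, tuple)):
            for i, item in enumerate(value):
                if isinstance(item, Node):
                    yield from iter_named_modules(item, f"{child_prefix}.{i}", memo)


shared = Node()
root = Node(descriptor=Node(blocks=[Node(so2=shared), Node(so2=shared)]))
names = [name for name, _ in iter_named_modules(root)]
# ['', 'descriptor', 'descriptor.blocks.0', 'descriptor.blocks.0.so2', 'descriptor.blocks.1']
```

The shared `so2` module appears only under its first path, which is what makes the walk safe for de-duplicated parameter counting.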

deepmd.dpmodel.descriptor.dpa4_nn.lora._iter_named_parameters(root: deepmd.dpmodel.NativeOP) → collections.abc.Iterator[tuple[str, deepmd.dpmodel.NativeOP, numpy.ndarray]]

Yield (dotted_name, owner, array) for every numpy-array parameter.

A dpmodel “parameter” is a numpy array stored as a module attribute (or a numpy element of a list/tuple attribute, the equivalent of an nn.ParameterList). owner is the module holding the array; because the dpmodel tracks trainability per module (module.trainable) rather than per tensor, callers toggle owner.trainable where the PyTorch code toggles param.requires_grad.

deepmd.dpmodel.descriptor.dpa4_nn.lora._module_state_dict(root: deepmd.dpmodel.NativeOP) → dict[str, numpy.ndarray]

Flat dotted {name: array} dict over the whole module tree.
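The parameter walk and the flat state dict can be sketched as follows, again with a hypothetical `Node` standing in for NativeOP and with de-duplication of shared modules omitted for brevity. The loop at the end shows the per-module trainability toggle described above:

```python
import numpy as np


class Node:
    """Hypothetical stand-in for NativeOP."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def iter_named_parameters(root: Node, prefix: str = ""):
    """Yield (dotted_name, owner, array) for arrays held directly or in lists/tuples."""
    for attr, value in vars(root).items():
        name = f"{prefix}.{attr}" if prefix else attr
        if isinstance(value, np.ndarray):
            yield name, root, value
        elif isinstance(value, Node):
            yield from iter_named_parameters(value, name)
        elif isinstance(value, (list, tuple)):
            for i, item in enumerate(value):
                if isinstance(item, np.ndarray):      # ParameterList-like element
                    yield f"{name}.{i}", root, item
                elif isinstance(item, Node):          # ModuleList-like element
                    yield from iter_named_parameters(item, f"{name}.{i}")


def module_state_dict(root: Node) -> dict:
    return {name: arr for name, _, arr in iter_named_parameters(root)}


so2 = Node(trainable=False, weight_m0=np.zeros((2, 3)),
           weight_m=[np.zeros((2, 6)), np.zeros((1, 6))])
root = Node(blocks=[Node(so2=so2)])

keys = sorted(module_state_dict(root))
# ['blocks.0.so2.weight_m.0', 'blocks.0.so2.weight_m.1', 'blocks.0.so2.weight_m0']
for _, owner, _ in iter_named_parameters(root):
    owner.trainable = True        # trainability lives on the owning module
```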

deepmd.dpmodel.descriptor.dpa4_nn.lora._leaf_name(param_name: str) → str

Return the trailing non-numeric segment of a parameter name.

nn.ParameterList children show up as foo.0, foo.1, …; get_adam_route strips those numeric indices before routing, so this helper keeps the policy in sync.
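A minimal sketch of the leaf-name rule; `leaf_name` is a hypothetical re-implementation that skips purely numeric path segments from the right:

```python
def leaf_name(param_name: str) -> str:
    """Return the last dotted segment that is not a pure integer index."""
    for segment in reversed(param_name.split(".")):
        if not segment.isdigit():
            return segment
    return param_name


assert leaf_name("blocks.0.so2.A_m.1") == "A_m"        # list element index stripped
assert leaf_name("blocks.3.so3.A_by_l") == "A_by_l"    # plain attribute
```

Routing on the leaf keeps every element of a list such as A_m under the same policy as the list itself.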

deepmd.dpmodel.descriptor.dpa4_nn.lora._get_submodule_or_none(root: deepmd.dpmodel.NativeOP, path: str) → Any
deepmd.dpmodel.descriptor.dpa4_nn.lora._clear_sezm_compile_cache(model: deepmd.dpmodel.NativeOP) → None

No-op retained for parity with the PyTorch backend.

In PyTorch, LoRA injection or merge replaces submodules and therefore invalidates any torch.compile / inductor callable captured on the module graph, which must be cleared before the next forward. The dpmodel (array-API) backend compiles nothing, so there is no cache to clear and this function intentionally does nothing.

deepmd.dpmodel.descriptor.dpa4_nn.lora._swap_submodule(parent: Any, attr: str, new_module: deepmd.dpmodel.NativeOP) → None

Replace parent.attr with new_module.

Numeric attribute names address list/tuple children (the dpmodel analogue of nn.ModuleList / nn.ParameterList elements) and are assigned by index; every other name is a plain attribute assignment.
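A sketch of the swap rule together with a hypothetical `resolve_parent` helper that turns a dotted module path into the (parent, last segment) pair the swap needs. Plain strings stand in for modules, and the sketch assumes list containers, since a Python tuple does not support item assignment:

```python
from types import SimpleNamespace
from typing import Any


def resolve_parent(root: Any, dotted: str) -> tuple[Any, str]:
    """Walk all but the last segment of a dotted path; return (parent, last_segment)."""
    *head, last = dotted.split(".")
    obj = root
    for seg in head:
        obj = obj[int(seg)] if seg.isdigit() else getattr(obj, seg)
    return obj, last


def swap_submodule(parent: Any, attr: str, new_module: Any) -> None:
    if attr.isdigit():
        parent[int(attr)] = new_module     # list element, assigned by index
    else:
        setattr(parent, attr, new_module)  # ordinary attribute


root = SimpleNamespace(blocks=[SimpleNamespace(so3="old_so3"), "old_block"])
swap_submodule(*resolve_parent(root, "blocks.0.so3"), "new_so3")
swap_submodule(*resolve_parent(root, "blocks.1"), "new_block")
```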

deepmd.dpmodel.descriptor.dpa4_nn.lora.has_lora(module: deepmd.dpmodel.NativeOP) → bool

Return True iff any submodule is a LoRA adapter.

deepmd.dpmodel.descriptor.dpa4_nn.lora.apply_lora_to_sezm(model: deepmd.dpmodel.NativeOP, *, rank: int, alpha: float | None = None) → deepmd.dpmodel.NativeOP

Inject LoRA adapters into every SO3Linear / SO2Linear of a SeZM model and apply the SeZM fine-tune freeze/unfreeze policy in place.

Calling this function again is safe: the exact-type test type(mod) is SO3Linear (rather than isinstance, which also matches the LoRASO3 subclass) prevents re-wrapping a LoRASO3 that is already present.
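Why the exact-type test matters, shown with hypothetical stub classes in place of SO3Linear and LoRASO3:

```python
class SO3LinearStub:
    """Hypothetical stand-in for the base SO3Linear."""


class LoRASO3Stub(SO3LinearStub):
    """Hypothetical stand-in for LoRASO3 (a subclass of the base)."""

    def __init__(self, base: SO3LinearStub):
        self.base = base


def wrap_once(mod):
    # isinstance(mod, SO3LinearStub) is also True for an adapter, so it would re-wrap;
    # the exact-type test only matches the plain base class.
    if type(mod) is SO3LinearStub:
        return LoRASO3Stub(mod)
    return mod


base = SO3LinearStub()
once = wrap_once(base)
twice = wrap_once(once)   # second pass leaves the adapter unchanged
```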

Parameters:
model

A SeZMModel instance (or any NativeOP containing SeZM SO3Linear / SO2Linear submodules).

rank

LoRA rank applied uniformly to every adapter.

alpha

LoRA scaling numerator; scaling is alpha / rank. None defaults to alpha = rank (scaling 1.0).

Returns:
NativeOP

The same model after injection (returned for chaining).

deepmd.dpmodel.descriptor.dpa4_nn.lora.fold_lora_state_dict_keys(state_dict: dict[str, numpy.ndarray], prefix: str) → None

Fold LoRA adapter keys into base weight keys in state_dict (in-place).

Scans for SO3-style A_by_l/B_by_l pairs and SO2-style A_m0/B_m0/A_m.*/B_m.* groups under prefix. For each pair whose corresponding base weight key also exists, the delta einsum(B, A) * scaling is added to the weight and the adapter keys are popped. lora_scaling is read from state_dict when present; otherwise 1.0 is assumed (the default when alpha == rank).

Called by DescrptSeZM._load_from_state_dict so that a LoRA-trained checkpoint can be loaded into a plain (non-LoRA) descriptor transparently.

Parameters:
state_dict

Flat state dict to mutate in place.

prefix

Key prefix that scopes the scan (e.g. "model.Default.atomic_model.descriptor.").
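A minimal sketch of the SO3 half of this folding, assuming the adapter sits next to a `<module>.weight` key and that scaling is stored as `<module>.lora_scaling` (an assumed key name; this sketch also removes it so the result matches a plain descriptor). `fold_so3_lora_keys` is hypothetical, and the SO2 A_m0/B_m0/A_m.*/B_m.* groups are omitted:

```python
import numpy as np


def fold_so3_lora_keys(state_dict: dict, prefix: str) -> None:
    """Fold <mod>.A_by_l / <mod>.B_by_l into <mod>.weight in place (SO3 keys only)."""
    a_keys = [k for k in state_dict if k.startswith(prefix) and k.endswith(".A_by_l")]
    for a_key in a_keys:
        mod = a_key[: -len("A_by_l")]                 # keeps the trailing dot
        b_key, w_key = mod + "B_by_l", mod + "weight"
        if b_key not in state_dict or w_key not in state_dict:
            continue                                   # nothing to fold into
        scaling = float(state_dict.pop(mod + "lora_scaling", 1.0))  # assumed key name
        a, b = state_dict.pop(a_key), state_dict.pop(b_key)
        # B: (L, F*C_out, R), A: (L, R, C_in) -> ΔW: (L, C_in, F*C_out)
        state_dict[w_key] = state_dict[w_key] + scaling * np.einsum("lor,lri->lio", b, a)


prefix = "model.Default.atomic_model.descriptor."
mod = prefix + "blocks.0.so3."
sd = {
    mod + "weight": np.zeros((2, 3, 4)),            # (lmax+1, C_in, F*C_out)
    mod + "A_by_l": np.ones((2, 1, 3)),             # (lmax+1, R, C_in)
    mod + "B_by_l": np.ones((2, 4, 1)),             # (lmax+1, F*C_out, R)
    mod + "lora_scaling": np.array(0.5),
    prefix + "orphan.A_by_l": np.ones((2, 1, 3)),   # no matching B/weight: left as-is
}
fold_so3_lora_keys(sd, prefix)
```

Iterating over a snapshot of the matching keys lets the loop pop entries from the dict safely.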

deepmd.dpmodel.descriptor.dpa4_nn.lora.build_merged_state_dict(module: deepmd.dpmodel.NativeOP, state_dict: dict[str, numpy.ndarray] | None = None, *, prefix: str = '') → dict[str, numpy.ndarray]

Produce a plain (LoRA-free) state dict from a LoRA-augmented module.

Walks module.named_modules() and, for every LoRASO3 / LoRASO2 submodule, folds ΔW = BA·scaling into the base weight key and removes the A/B keys. The returned dict has the same key set as a same-topology SeZM that has never been LoRA-wrapped, and is suitable for loading into a plain SeZM model with strict=True.

Non-destructive with respect to module: when state_dict is None, a deep copy of module.state_dict() is taken; when the caller provides a state_dict, it is assumed to already be a detached copy (e.g. the full-gathered state dict from FSDP2) and is mutated in place for efficiency.

Parameters:
module

The LoRA-augmented module tree. Only used for structural information (LoRA submodule prefixes, scaling, weight_m length); its parameters are not modified.

state_dict

Optional pre-collected state dict (e.g. gathered from FSDP2). If None, deepcopy(module.state_dict()) is used.

prefix

Prefix to prepend to every LoRA submodule name when looking keys up in state_dict. Use this when the caller has state keyed under an outer wrapper (for example "model.Default.").

Returns:
dict

Flat state dict with LoRA adapters folded into base weights.

deepmd.dpmodel.descriptor.dpa4_nn.lora.strip_lora_from_extra_state(extra_state: dict[str, Any]) → dict[str, Any]

Drop any lora entry from _extra_state["model_params"].

Handles both single-task (model_params is the model config) and multi-task (model_params["model_dict"][<branch>] is each branch’s config). Returns a deep-copied dict; the input is not mutated.
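A sketch of the single-task / multi-task handling, assuming the lora entry is a top-level key of each model config; `strip_lora` is a hypothetical re-implementation and the sample configs are made up for illustration:

```python
import copy
from typing import Any


def strip_lora(extra_state: dict[str, Any]) -> dict[str, Any]:
    out = copy.deepcopy(extra_state)                # never mutate the caller's dict
    params = out.get("model_params")
    if not isinstance(params, dict):
        return out
    if "model_dict" in params:                      # multi-task: one config per branch
        for branch_cfg in params["model_dict"].values():
            if isinstance(branch_cfg, dict):
                branch_cfg.pop("lora", None)
    else:                                           # single-task: params is the config
        params.pop("lora", None)
    return out


single = {"model_params": {"type_map": ["O", "H"], "lora": {"rank": 8}}}
multi = {"model_params": {"model_dict": {"a": {"lora": {"rank": 4}}, "b": {"descriptor": {}}}}}
s, m = strip_lora(single), strip_lora(multi)
```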

deepmd.dpmodel.descriptor.dpa4_nn.lora.merge_lora_into_base(model: deepmd.dpmodel.NativeOP) → deepmd.dpmodel.NativeOP

Destructively replace every LoRASO3 / LoRASO2 with its merged plain base module.

After this call the model no longer contains LoRA submodules: the optimizer, EMA state, and any compiled callables that reference the old submodules become invalid. Prefer build_merged_state_dict() for non-destructive checkpoint export during or after training; this function is primarily useful in tests and offline scripts.