deepmd.pt_expt.model.spin_ener_model

Classes

SpinEnergyModel

A spin model for energy.

Module Contents

class deepmd.pt_expt.model.spin_ener_model.SpinEnergyModel(backbone_model: deepmd.dpmodel.atomic_model.dp_atomic_model.DPAtomicModel, spin: deepmd.utils.spin.Spin)

Bases: deepmd.pt_expt.model.spin_model.SpinModel

A spin model for energy.

model_type = 'ener'
translated_output_def() -> dict[str, Any]

Get the translated output definition.

Maps internal output names to user-facing names, e.g. energy -> atom_energy, energy_redu -> energy, energy_derv_r -> force, energy_derv_r_mag -> force_mag.
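The renaming above amounts to a key-for-key substitution on the output dict. The sketch below is illustrative only: `INTERNAL_TO_USER` and `translate_outputs` are hypothetical names, the table holds just the four example pairs listed above (the real output definition may cover more keys), and passing unlisted keys through unchanged is an assumption, not documented behavior.

```python
# Hypothetical mapping built only from the four example pairs documented
# above; the real translated_output_def() may contain more entries.
INTERNAL_TO_USER: dict[str, str] = {
    "energy": "atom_energy",
    "energy_redu": "energy",
    "energy_derv_r": "force",
    "energy_derv_r_mag": "force_mag",
}


def translate_outputs(internal: dict[str, object]) -> dict[str, object]:
    """Rename internal output keys to user-facing ones (illustrative sketch).

    Keys missing from INTERNAL_TO_USER are kept unchanged (an assumption).
    A new dict is built in a single pass, so "energy" can safely be both
    a source key (per-atom) and a target key (reduced).
    """
    return {INTERNAL_TO_USER.get(key, key): value for key, value in internal.items()}


print(translate_outputs({"energy": [0.5, 1.5], "energy_redu": 2.0}))
```

Note that `energy` appears on both sides of the mapping: the internal per-atom `energy` becomes `atom_energy`, while the reduced `energy_redu` becomes the user-facing `energy`. Renaming keys one at a time in place could clobber a value depending on the order; building a fresh dict from the original avoids that.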

forward(coord: torch.Tensor, atype: torch.Tensor, spin: torch.Tensor, box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None) -> dict[str, torch.Tensor]
forward_lower(extended_coord: torch.Tensor, extended_atype: torch.Tensor, extended_spin: torch.Tensor, nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None) -> dict[str, torch.Tensor]
forward_lower_exportable(extended_coord: torch.Tensor, extended_atype: torch.Tensor, extended_spin: torch.Tensor, nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None, **make_fx_kwargs: Any) -> torch.nn.Module

Trace forward_lower into an exportable module.

Delegates to forward_common_lower_exportable for tracing, then translates the internal keys to the forward_lower convention.

Parameters:
extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam, do_atomic_virial

Sample inputs with representative shapes (used for tracing).

**make_fx_kwargs

Extra keyword arguments forwarded to make_fx (e.g. tracing_mode="symbolic").

Returns:
torch.nn.Module

A traced module whose forward accepts (extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam) and returns a dict with the same keys as forward_lower.
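The trace-then-translate pattern described for forward_lower_exportable can be sketched with a toy function. This is not DeePMD-kit code: `toy_common_lower`, `KEY_MAP`, `KeyTranslatingModule`, and `trace_exportable` are hypothetical stand-ins, the toy function takes only a coordinate tensor, and wrapping the traced graph in a separate renaming module is just one possible design. The only real PyTorch API used is `make_fx` from `torch.fx.experimental.proxy_tensor`; extra keyword arguments are forwarded to it in the same spirit as `**make_fx_kwargs` above.

```python
from typing import Any

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical internal -> user-facing key map (a subset of the pairs
# listed under translated_output_def).
KEY_MAP = {"energy_redu": "energy", "energy_derv_r": "force"}


def toy_common_lower(coord: torch.Tensor) -> dict[str, torch.Tensor]:
    # Toy stand-in for an internal routine: a quadratic energy
    # E = sum(coord**2) per frame and its analytic force -dE/dcoord = -2*coord,
    # returned under internal key names.
    energy = (coord**2).sum(dim=(-2, -1))
    return {"energy_redu": energy, "energy_derv_r": -2.0 * coord}


class KeyTranslatingModule(torch.nn.Module):
    """Wraps a traced graph and renames its output keys (illustrative)."""

    def __init__(self, traced: torch.nn.Module) -> None:
        super().__init__()
        self.traced = traced

    def forward(self, coord: torch.Tensor) -> dict[str, torch.Tensor]:
        out = self.traced(coord)
        return {KEY_MAP.get(k, k): v for k, v in out.items()}


def trace_exportable(sample_coord: torch.Tensor, **make_fx_kwargs: Any) -> torch.nn.Module:
    # Trace once with a representative sample input; any make_fx_kwargs
    # (e.g. tracing_mode="symbolic") are passed straight through to make_fx.
    traced = make_fx(toy_common_lower, **make_fx_kwargs)(sample_coord)
    return KeyTranslatingModule(traced)
```

Usage: `trace_exportable(torch.randn(1, 4, 3))` returns an `nn.Module` whose output dict uses the user-facing keys `energy` and `force` instead of the internal ones. With make_fx's default `tracing_mode="real"`, the graph is recorded using the sample's concrete tensors; passing `tracing_mode="symbolic"` asks make_fx to trace with symbolic sizes instead. Keeping the renaming outside the traced graph leaves the traced core identical to the internal convention, at the cost of one extra wrapper module.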