deepmd.pt_expt.model.spin_ener_model
Classes
SpinEnergyModel | A spin model for energy.
Module Contents
- class deepmd.pt_expt.model.spin_ener_model.SpinEnergyModel(backbone_model: deepmd.dpmodel.atomic_model.dp_atomic_model.DPAtomicModel, spin: deepmd.utils.spin.Spin)
Bases: deepmd.pt_expt.model.spin_model.SpinModel
A spin model for energy.
- translated_output_def() -> dict[str, Any]
Get the translated output definition.
Maps internal output names to user-facing names, e.g. energy -> atom_energy, energy_redu -> energy, energy_derv_r -> force, energy_derv_r_mag -> force_mag.
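The renaming described above is essentially a lookup table applied to the model's internal output dict. A minimal sketch follows, assuming the internal outputs are a plain Python dict and that keys without an entry in the table are copied through unchanged (deepmd may handle such keys differently). INTERNAL_TO_USER and translate_outputs are hypothetical names, not deepmd API; the four pairs are the examples given in the docstring.

```python
from typing import Any

# Internal -> user-facing names, taken from the examples in the docstring.
INTERNAL_TO_USER: dict[str, str] = {
    "energy": "atom_energy",
    "energy_redu": "energy",
    "energy_derv_r": "force",
    "energy_derv_r_mag": "force_mag",
}


def translate_outputs(internal: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict whose keys are renamed to the user-facing names.

    Assumption (hypothetical helper): keys absent from INTERNAL_TO_USER
    are copied under their original name.
    """
    return {INTERNAL_TO_USER.get(key, key): value for key, value in internal.items()}


raw = {"energy": 1.0, "energy_redu": 2.0, "energy_derv_r": 3.0, "other_key": 4.0}
print(translate_outputs(raw))
# prints {'atom_energy': 1.0, 'energy': 2.0, 'force': 3.0, 'other_key': 4.0}
```

Building a fresh dict rather than renaming in place avoids clobbering: internal "energy" becomes "atom_energy" while internal "energy_redu" takes over the "energy" slot.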
- forward(coord: torch.Tensor, atype: torch.Tensor, spin: torch.Tensor, box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None) -> dict[str, torch.Tensor]
- forward_lower(extended_coord: torch.Tensor, extended_atype: torch.Tensor, extended_spin: torch.Tensor, nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None) -> dict[str, torch.Tensor]
- forward_lower_exportable(extended_coord: torch.Tensor, extended_atype: torch.Tensor, extended_spin: torch.Tensor, nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None, **make_fx_kwargs: Any) -> torch.nn.Module
Trace forward_lower into an exportable module. Delegates to forward_common_lower_exportable for tracing, then translates the internal keys to the forward_lower convention.
- Parameters:
- extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam, do_atomic_virial
Sample inputs with representative shapes (used for tracing).
- **make_fx_kwargs
Extra keyword arguments forwarded to make_fx (e.g. tracing_mode="symbolic").
- Returns:
torch.nn.Module
A traced module whose forward accepts (extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam) and returns a dict with the same keys as forward_lower.
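The delegate-then-rename structure described for forward_lower_exportable can be illustrated without torch. In the sketch below, RenamingModule and fake_traced are hypothetical: fake_traced stands in for whatever traced module forward_common_lower_exportable returns, and the rename table is a placeholder, since the exact key mapping for the forward_lower convention is not listed here. Only the control flow is shown: call the traced callable with the seven inputs listed under Returns, then rename the keys of its output dict.

```python
from typing import Any, Callable, Optional

InnerFn = Callable[..., dict[str, Any]]


class RenamingModule:
    """Hypothetical stand-in for the module returned by forward_lower_exportable.

    Delegates to an already-traced callable, then renames the keys of the
    returned dict using a caller-supplied table.
    """

    def __init__(self, traced: InnerFn, rename: dict[str, str]) -> None:
        self.traced = traced
        self.rename = rename

    def forward(
        self,
        extended_coord: Any,
        extended_atype: Any,
        extended_spin: Any,
        nlist: Any,
        mapping: Optional[Any] = None,
        fparam: Optional[Any] = None,
        aparam: Optional[Any] = None,
    ) -> dict[str, Any]:
        # Positional order mirrors the signature documented under "Returns".
        internal = self.traced(
            extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam
        )
        # Assumption: keys missing from the table are kept as-is.
        return {self.rename.get(k, k): v for k, v in internal.items()}

    # Make instances callable like a module: module(...) runs forward(...).
    __call__ = forward


def fake_traced(coord, atype, spin, nlist, mapping, fparam, aparam):
    # Placeholder for a traced graph: returns dummy "internal" outputs.
    return {"internal_a": len(coord), "internal_b": sum(atype)}


module = RenamingModule(fake_traced, {"internal_a": "user_a"})
print(module([0.0, 1.0, 2.0], [0, 1, 1], [0.0] * 3, [[1], [0]]))
# prints {'user_a': 3, 'internal_b': 2}
```

In the documented method the inner callable comes from make_fx tracing and the returned object is a torch.nn.Module; a plain class is used here only to keep the example dependency-free.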