# Source code for deepmd.pt.model.task.dipole

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from collections.abc import (
    Callable,
)
from typing import (
    Any,
)

import torch

from deepmd.dpmodel import (
    FittingOutputDef,
    OutputVariableDef,
)
from deepmd.pt.model.task.fitting import (
    GeneralFitting,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.env import (
    DEFAULT_PRECISION,
)
from deepmd.utils.path import (
    DPPath,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
@GeneralFitting.register("dipole")
class DipoleFittingNet(GeneralFitting):
    """Construct a dipole fitting net.

    The fitting network outputs an ``embedding_width``-dimensional vector per
    atom. ``forward`` contracts it with the rotation matrix ``gr`` to produce a
    3-component vector per atom, stored under the ``"dipole"`` output key.

    Parameters
    ----------
    ntypes : int
        Element count.
    dim_descrpt : int
        Embedding width per atom.
    embedding_width : int
        The dimension of rotation matrix, m1.
    neuron : list[int], optional
        Number of neurons in each hidden layer of the fitting net.
        Defaults to ``[128, 128, 128]``.
    resnet_dt : bool
        Using time-step in the ResNet construction.
    numb_fparam : int
        Number of frame parameters.
    numb_aparam : int
        Number of atomic parameters.
    dim_case_embd : int
        Dimension of case specific embedding.
    activation_function : str
        Activation function.
    precision : str
        Numerical precision.
    mixed_types : bool
        If true, use a uniform fitting net for all atom types, otherwise use
        different fitting nets for different atom types.
    rcond : float, optional
        The condition number for the regression of atomic energy.
    seed : int or list[int], optional
        Random seed.
    exclude_types : list[int], optional
        Atom types to exclude; forwarded to :class:`GeneralFitting`.
        Defaults to an empty list.
    r_differentiable : bool
        If the variable is differentiated with respect to coordinates of
        atoms. Only reducible variables are differentiable.
    c_differentiable : bool
        If the variable is differentiated with respect to the cell tensor
        (pbc case). Only reducible variables are differentiable.
    type_map : list[str], optional
        A list of strings. Give the name to each type of atoms.
    default_fparam : list[float], optional
        The default frame parameter. If set, when `fparam.npy` files are not
        included in the data system, this value will be used as the default
        value for the frame parameter in the fitting net.
    **kwargs
        Additional keyword arguments forwarded to :class:`GeneralFitting`.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        embedding_width: int,
        neuron: list[int] | None = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        dim_case_embd: int = 0,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = True,
        rcond: float | None = None,
        seed: int | list[int] | None = None,
        exclude_types: list[int] | None = None,
        r_differentiable: bool = True,
        c_differentiable: bool = True,
        type_map: list[str] | None = None,
        default_fparam: list | None = None,
        **kwargs: Any,
    ) -> None:
        # None sentinels avoid sharing one mutable default list across
        # instances; the effective defaults are unchanged.
        if neuron is None:
            neuron = [128, 128, 128]
        if exclude_types is None:
            exclude_types = []
        # NOTE: these are assigned before super().__init__(), presumably
        # because the base constructor calls _net_out_dim(), which reads
        # embedding_width. Keep this ordering.
        self.embedding_width = embedding_width
        self.r_differentiable = r_differentiable
        self.c_differentiable = c_differentiable
        super().__init__(
            var_name="dipole",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            neuron=neuron,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            dim_case_embd=dim_case_embd,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            rcond=rcond,
            seed=seed,
            exclude_types=exclude_types,
            type_map=type_map,
            default_fparam=default_fparam,
            **kwargs,
        )

    def _net_out_dim(self) -> int:
        """Set the FittingNet output dim."""
        return self.embedding_width

    def serialize(self) -> dict:
        """Serialize the fitting net, adding the dipole-specific fields."""
        data = super().serialize()
        data["type"] = "dipole"
        data["embedding_width"] = self.embedding_width
        data["r_differentiable"] = self.r_differentiable
        data["c_differentiable"] = self.c_differentiable
        return data

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        """Rebuild a dipole fitting net from ``serialize`` output.

        The input dict is copied, so the caller's dict is not modified.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 4, 1)
        # __init__ always passes var_name="dipole" to the base class; a
        # serialized var_name would otherwise arrive through **kwargs as a
        # duplicate keyword argument.
        data.pop("var_name", None)
        return super().deserialize(data)

    def output_def(self) -> FittingOutputDef:
        """Declare a single reducible output of shape [3] named ``var_name``."""
        return FittingOutputDef(
            [
                OutputVariableDef(
                    self.var_name,
                    [3],
                    reducible=True,
                    r_differentiable=self.r_differentiable,
                    c_differentiable=self.c_differentiable,
                ),
            ]
        )

    def compute_output_stats(
        self,
        merged: Callable[[], list[dict]] | list[dict],
        stat_file_path: DPPath | None = None,
    ) -> None:
        """
        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

        This implementation is a no-op: the dipole fitting computes no output
        statistics and ignores both arguments.

        Parameters
        ----------
        merged : Union[Callable[[], list[dict]], list[dict]]
            - list[dict]: A list of data samples from various data systems.
                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
                originating from the `i`-th data system.
            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
                only when needed. Since the sampling process can be slow and memory-intensive,
                the lazy function helps by only sampling once.
        stat_file_path : Optional[DPPath]
            The path to the stat file.
        """

    def forward(
        self,
        descriptor: torch.Tensor,
        atype: torch.Tensor,
        gr: torch.Tensor | None = None,
        g2: torch.Tensor | None = None,
        h2: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Predict a 3-component vector per atom.

        Parameters
        ----------
        descriptor : torch.Tensor
            Rank-3 descriptor tensor; its first two dimensions give
            (nframes, nloc).
        atype : torch.Tensor
            Atom types, forwarded to the common fitting routine.
        gr : torch.Tensor
            Rotation matrix. Required; must be viewable as
            (nframes * nloc, embedding_width, 3).
        g2, h2, fparam, aparam : torch.Tensor, optional
            Forwarded unchanged to the common fitting routine.

        Returns
        -------
        dict[str, torch.Tensor]
            ``{self.var_name: out}`` where ``out`` has shape
            (nframes, nloc, 3) and dtype ``env.GLOBAL_PT_FLOAT_PRECISION``.
        """
        nframes, nloc, _ = descriptor.shape
        # Kept as an assert: TorchScript refines Optional[Tensor] to Tensor
        # after `assert x is not None`.
        assert gr is not None, "Must provide the rotation matrix for dipole fitting."
        # cast the input to internal precision
        gr = gr.to(self.prec)
        # (nframes, nloc, m1)
        out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
            self.var_name
        ]
        # (nframes * nloc, 1, m1)
        out = out.view(-1, 1, self.embedding_width)
        # (nframes * nloc, m1, 3)
        gr = gr.view(nframes * nloc, self.embedding_width, 3)
        # (nframes * nloc, 1, 3) -> (nframes, nloc, 3)
        out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3)
        return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

    # make jit happy with torch 2.0.0
    exclude_types: list[int]