# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)
import torch
from deepmd.dpmodel import (
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.model.task.fitting import (
GeneralFitting,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
log = logging.getLogger(__name__)
@GeneralFitting.register("dipole")
class DipoleFittingNet(GeneralFitting):
"""Construct a dipole fitting net.
Parameters
----------
ntypes : int
Element count.
dim_descrpt : int
Embedding width per atom.
    embedding_width : int
        The dimension of the rotation matrix, m1.
    neuron : List[int]
        Number of neurons in each hidden layer of the fitting net.
resnet_dt : bool
Using time-step in the ResNet construction.
numb_fparam : int
Number of frame parameters.
numb_aparam : int
Number of atomic parameters.
activation_function : str
Activation function.
precision : str
Numerical precision.
mixed_types : bool
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
rcond : float, optional
The condition number for the regression of atomic energy.
seed : int, optional
Random seed.
    exclude_types : List[int]
        Atomic contributions of the excluded atom types are set to zero.
    r_differentiable : bool
        If the variable is differentiated with respect to the coordinates of the atoms.
        Only reducible variables are differentiable.
    c_differentiable : bool
        If the variable is differentiated with respect to the cell tensor (PBC case).
        Only reducible variables are differentiable.
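
    Examples
    --------
    A minimal construction sketch. The ``dim_descrpt`` and ``embedding_width``
    values below are illustrative placeholders; in a real model they must match
    the output width and rotation-matrix width of the descriptor, whose rotation
    matrix is also required by ``forward``.

    >>> fitting = DipoleFittingNet(
    ...     ntypes=2,
    ...     dim_descrpt=64,
    ...     embedding_width=32,
    ...     neuron=[128, 128, 128],
    ... )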
"""
def __init__(
self,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
neuron: List[int] = [128, 128, 128],
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = True,
rcond: Optional[float] = None,
seed: Optional[int] = None,
exclude_types: List[int] = [],
r_differentiable: bool = True,
c_differentiable: bool = True,
**kwargs,
):
self.embedding_width = embedding_width
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name="dipole",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
activation_function=activation_function,
precision=precision,
mixed_types=mixed_types,
rcond=rcond,
seed=seed,
exclude_types=exclude_types,
**kwargs,
)
        self.old_impl = False  # this only supports the new implementation.
def _net_out_dim(self):
"""Set the FittingNet output dim."""
return self.embedding_width
def serialize(self) -> dict:
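        """Serialize the fitting net into a plain dict."""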
data = super().serialize()
data["type"] = "dipole"
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["r_differentiable"] = self.r_differentiable
data["c_differentiable"] = self.c_differentiable
return data
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
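        """Deserialize the fitting net from a dict produced by ``serialize``."""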
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name", None)
return super().deserialize(data)
def output_def(self) -> FittingOutputDef:
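        """Define the fitting output: a per-atom, 3-component dipole that is reducible over atoms."""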
return FittingOutputDef(
[
OutputVariableDef(
self.var_name,
[3],
reduciable=True,
r_differentiable=self.r_differentiable,
c_differentiable=self.c_differentiable,
),
]
)
def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

        Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
            - List[dict]: A list of data samples from various data systems.
              Each element, `merged[i]`, is a data dictionary whose keys map to
              `torch.Tensor` values originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
pass
def forward(
self,
descriptor: torch.Tensor,
atype: torch.Tensor,
gr: Optional[torch.Tensor] = None,
g2: Optional[torch.Tensor] = None,
h2: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
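        """Compute the per-atom dipole from the descriptor and its rotation matrix.

        The shapes below mirror the reshapes performed in the body.

        Parameters
        ----------
        descriptor : torch.Tensor
            Atomic descriptors of shape (nframes, nloc, dim_descrpt).
        atype : torch.Tensor
            Atom type indices of shape (nframes, nloc).
        gr : torch.Tensor
            Rotation matrix produced by the descriptor; viewed internally as
            (nframes * nloc, m1, 3). Required for dipole fitting.
        g2, h2 : torch.Tensor, optional
            Additional descriptor outputs passed through to the common fitting pass.
        fparam : torch.Tensor, optional
            Frame parameters.
        aparam : torch.Tensor, optional
            Atomic parameters.

        Returns
        -------
        dict
            ``{self.var_name: dipole}`` where ``dipole`` has shape (nframes, nloc, 3),
            obtained by contracting the (nframes, nloc, m1) fitting output with ``gr``.
        """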
nframes, nloc, _ = descriptor.shape
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
# (nframes, nloc, m1)
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.embedding_width)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, -1, 3)
# (nframes, nloc, 3)
out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3)
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
# make jit happy with torch 2.0.0
exclude_types: List[int]