# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
Dict,
List,
Optional,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)
from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)
def make_base_atomic_model(
    t_tensor,
    fwd_method_name: str = "forward_atomic",
):
    """Make the base class for the atomic model.

    Parameters
    ----------
    t_tensor
        The type of the tensor. used in the type hint.
    fwd_method_name
        Name of the forward method. For dpmodels, it should be "call".
        For torch models, it should be "forward".

    Returns
    -------
    type
        The abstract base atomic-model class. Its abstract ``fwd`` method
        is re-exposed under ``fwd_method_name`` and ``fwd`` itself removed.
    """

    class BAM(ABC, PluginVariant, make_plugin_registry("atomic model")):
        """Base Atomic Model provides the interfaces of an atomic model."""

        @abstractmethod
        def fitting_output_def(self) -> FittingOutputDef:
            """Get the output def of developer implemented atomic models."""
            pass

        def atomic_output_def(self) -> FittingOutputDef:
            """Get the output def of the atomic model.

            By default it is the same as FittingOutputDef, but it
            allows model level wrapper of the output defined by the developer.
            """
            return self.fitting_output_def()

        @abstractmethod
        def get_rcut(self) -> float:
            """Get the cut-off radius."""
            pass

        @abstractmethod
        def get_type_map(self) -> List[str]:
            """Get the type map."""
            pass

        def get_ntypes(self) -> int:
            """Get the number of atom types."""
            return len(self.get_type_map())

        @abstractmethod
        def get_sel(self) -> List[int]:
            """Returns the number of selected atoms for each type."""
            pass

        def get_nsel(self) -> int:
            """Returns the total number of selected neighboring atoms in the cut-off radius."""
            return sum(self.get_sel())

        def get_nnei(self) -> int:
            """Returns the total number of selected neighboring atoms in the cut-off radius.

            Alias of :meth:`get_nsel`.
            """
            return self.get_nsel()

        @abstractmethod
        def get_dim_fparam(self) -> int:
            """Get the number (dimension) of frame parameters of this atomic model."""

        @abstractmethod
        def get_dim_aparam(self) -> int:
            """Get the number (dimension) of atomic parameters of this atomic model."""

        @abstractmethod
        def get_sel_type(self) -> List[int]:
            """Get the selected atom types of this model.

            Only atoms with selected atom types have atomic contribution
            to the result of the model.
            If returning an empty list, all atom types are selected.
            """

        @abstractmethod
        def is_aparam_nall(self) -> bool:
            """Check whether the shape of atomic parameters is (nframes, nall, ndim).

            If False, the shape is (nframes, nloc, ndim).
            """

        @abstractmethod
        def mixed_types(self) -> bool:
            """If true, the model
            1. assumes total number of atoms aligned across frames;
            2. uses a neighbor list that does not distinguish different atomic types.

            If false, the model
            1. assumes total number of atoms of each atom type aligned across frames;
            2. uses a neighbor list that distinguishes different atomic types.
            """
            pass

        @abstractmethod
        def fwd(
            self,
            extended_coord: t_tensor,
            extended_atype: t_tensor,
            nlist: t_tensor,
            mapping: Optional[t_tensor] = None,
            fparam: Optional[t_tensor] = None,
            aparam: Optional[t_tensor] = None,
        ) -> Dict[str, t_tensor]:
            """Compute the atomic contributions on the extended system.

            Re-exposed on the returned class under ``fwd_method_name``
            (e.g. ``call`` or ``forward_atomic``); implementations return a
            dict mapping output-variable names to tensors.
            """
            pass

        @abstractmethod
        def serialize(self) -> dict:
            """Serialize the atomic model to a plain dict."""
            pass

        @classmethod
        @abstractmethod
        def deserialize(cls, data: dict):
            """Construct an atomic model from the dict produced by :meth:`serialize`."""
            pass

        def make_atom_mask(
            self,
            atype: t_tensor,
        ) -> t_tensor:
            """The atoms with type < 0 are treated as virtual atoms,
            which serves as place-holders for multi-frame calculations
            with different number of atoms in different frames.

            Parameters
            ----------
            atype
                Atom types. >= 0 for real atoms <0 for virtual atoms.

            Returns
            -------
            mask
                True for real atoms and False for virtual atoms.
            """
            # supposed to be supported by all backends
            return atype >= 0

        def do_grad_r(
            self,
            var_name: Optional[str] = None,
        ) -> bool:
            """Tell if the output variable `var_name` is r_differentiable.

            if var_name is None, returns if any of the variable is r_differentiable.
            """
            odef = self.fitting_output_def()
            if var_name is None:
                # NOTE(review): explicit List[bool] accumulation (rather than a
                # generator passed to any()) looks intentional, presumably for
                # TorchScript compatibility — confirm before simplifying.
                require: List[bool] = []
                for vv in odef.keys():
                    require.append(self.do_grad_(vv, "r"))
                return any(require)
            else:
                return self.do_grad_(var_name, "r")

        def do_grad_c(
            self,
            var_name: Optional[str] = None,
        ) -> bool:
            """Tell if the output variable `var_name` is c_differentiable.

            if var_name is None, returns if any of the variable is c_differentiable.
            """
            odef = self.fitting_output_def()
            if var_name is None:
                # NOTE(review): see do_grad_r on why this is an explicit loop.
                require: List[bool] = []
                for vv in odef.keys():
                    require.append(self.do_grad_(vv, "c"))
                return any(require)
            else:
                return self.do_grad_(var_name, "c")

        def do_grad_(self, var_name: str, base: str) -> bool:
            """Tell if the output variable `var_name` is differentiable.

            `base` selects the kind of differentiability queried:
            "r" for coordinates, "c" for the cell.
            """
            assert var_name is not None
            assert base in ["c", "r"]
            if base == "c":
                return self.fitting_output_def()[var_name].c_differentiable
            return self.fitting_output_def()[var_name].r_differentiable

    # Publish the abstract forward method under the backend-specific name
    # and drop the generic alias so only one spelling exists on the class.
    setattr(BAM, fwd_method_name, BAM.fwd)
    delattr(BAM, "fwd")
    return BAM