# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Tuple,
Type,
)
import torch
from deepmd.dpmodel import (
ModelOutputDef,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableCategory,
OutputVariableOperation,
check_operation_applied,
)
from deepmd.pt.model.atomic_model.base_atomic_model import (
BaseAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.pt.model.model.transform_output import (
communicate_extended_output,
fit_output_to_model_output,
)
from deepmd.pt.utils.env import (
GLOBAL_PT_ENER_FLOAT_PRECISION,
GLOBAL_PT_FLOAT_PRECISION,
PRECISION_DICT,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
nlist_distinguish_types,
)
from deepmd.utils.path import (
DPPath,
)
[docs]
def make_model(T_AtomicModel: Type[BaseAtomicModel]):
"""Make a model as a derived class of an atomic model.
The model provide two interfaces.
1. the `forward_common_lower`, that takes extended coordinates, atyps and neighbor list,
and outputs the atomic and property and derivatives (if required) on the extended region.
2. the `forward_common`, that takes coordinates, atypes and cell and predicts
the atomic and reduced property, and derivatives (if required) on the local region.
Parameters
----------
T_AtomicModel
The atomic model.
Returns
-------
CM
The model.
"""
class CM(BaseModel):
def __init__(
self,
*args,
# underscore to prevent conflict with normal inputs
atomic_model_: Optional[T_AtomicModel] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
if atomic_model_ is not None:
self.atomic_model: T_AtomicModel = atomic_model_
else:
self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISON_DICT
self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION
self.global_pt_ener_float_precision = GLOBAL_PT_ENER_FLOAT_PRECISION
def model_output_def(self):
"""Get the output def for the model."""
return ModelOutputDef(self.atomic_output_def())
@torch.jit.export
def model_output_type(self) -> List[str]:
"""Get the output type for the model."""
output_def = self.model_output_def()
var_defs = output_def.var_defs
# jit: Comprehension ifs are not supported yet
# type hint is critical for JIT
vars: List[str] = []
for kk, vv in var_defs.items():
# .value is critical for JIT
if vv.category == OutputVariableCategory.OUT.value:
vars.append(kk)
return vars
# cannot use the name forward. torch script does not work
def forward_common(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
"""Return model prediction.
Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
Returns
-------
ret_dict
The result dict of type Dict[str,torch.Tensor].
The keys are defined by the `ModelOutputDef`.
"""
cc, bb, fp, ap, input_prec = self.input_type_cast(
coord, box=box, fparam=fparam, aparam=aparam
)
del coord, box, fparam, aparam
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
cc,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=bb,
)
model_predict_lower = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
do_atomic_virial=do_atomic_virial,
fparam=fp,
aparam=ap,
)
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
mapping,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict
def get_out_bias(self) -> torch.Tensor:
return self.atomic_model.get_out_bias()
def change_out_bias(
self,
merged,
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias of atomic model according to the input data and the pretrained model.
Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
bias_adjust_mode : str
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on labels of target dataset,
and do least square on the errors to obtain the target shift as bias.
'set-by-statistic' : directly use the statistic output bias in the target dataset.
"""
self.atomic_model.change_out_bias(
merged,
bias_adjust_mode=bias_adjust_mode,
)
def forward_common_lower(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
as input, and returns the predictions on the extended region.
The predictions are not reduced.
Parameters
----------
extended_coord
coodinates in extended region. nf x (nall x 3)
extended_atype
atomic type in extended region. nf x nall
nlist
neighbor list. nf x nloc x nsel.
mapping
mapps the extended indices to local indices. nf x nall.
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
whether calculate atomic virial.
comm_dict
The data needed for communication for parallel inference.
Returns
-------
result_dict
the result dict, defined by the `FittingOutputDef`.
"""
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.view(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fp,
aparam=ap,
comm_dict=comm_dict,
)
model_predict = fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict
def input_type_cast(
self,
coord: torch.Tensor,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
str,
]:
"""Cast the input data to global float type."""
input_prec = self.reverse_precision_dict[coord.dtype]
###
### type checking would not pass jit, convert to coord prec anyway
###
# for vv, kk in zip([fparam, aparam], ["frame", "atomic"]):
# if vv is not None and self.reverse_precision_dict[vv.dtype] != input_prec:
# log.warning(
# f"type of {kk} parameter {self.reverse_precision_dict[vv.dtype]}"
# " does not match"
# f" that of the coordinate {input_prec}"
# )
_lst: List[Optional[torch.Tensor]] = [
vv.to(coord.dtype) if vv is not None else None
for vv in [box, fparam, aparam]
]
box, fparam, aparam = _lst
if (
input_prec
== self.reverse_precision_dict[self.global_pt_float_precision]
):
return coord, box, fparam, aparam, input_prec
else:
pp = self.global_pt_float_precision
return (
coord.to(pp),
box.to(pp) if box is not None else None,
fparam.to(pp) if fparam is not None else None,
aparam.to(pp) if aparam is not None else None,
input_prec,
)
def output_type_cast(
self,
model_ret: Dict[str, torch.Tensor],
input_prec: str,
) -> Dict[str, torch.Tensor]:
"""Convert the model output to the input prec."""
do_cast = (
input_prec
!= self.reverse_precision_dict[self.global_pt_float_precision]
)
pp = self.precision_dict[input_prec]
odef = self.model_output_def()
for kk in odef.keys():
if kk not in model_ret.keys():
# do not return energy_derv_c if not do_atomic_virial
continue
if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
model_ret[kk] = (
model_ret[kk].to(self.global_pt_ener_float_precision)
if model_ret[kk] is not None
else None
)
elif do_cast:
model_ret[kk] = (
model_ret[kk].to(pp) if model_ret[kk] is not None else None
)
return model_ret
def format_nlist(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
):
"""Format the neighbor list.
1. If the number of neighbors in the `nlist` is equal to sum(self.sel),
it does nothong
2. If the number of neighbors in the `nlist` is smaller than sum(self.sel),
the `nlist` is pad with -1.
3. If the number of neighbors in the `nlist` is larger than sum(self.sel),
the nearest sum(sel) neighbors will be preseved.
Known limitations:
In the case of not self.mixed_types, the nlist is always formatted.
May have side effact on the efficiency.
Parameters
----------
extended_coord
coodinates in extended region. nf x nall x 3
extended_atype
atomic type in extended region. nf x nall
nlist
neighbor list. nf x nloc x nsel
Returns
-------
formated_nlist
the formated nlist.
"""
mixed_types = self.mixed_types()
nlist = self._format_nlist(extended_coord, nlist, sum(self.get_sel()))
if not mixed_types:
nlist = nlist_distinguish_types(nlist, extended_atype, self.get_sel())
return nlist
def _format_nlist(
self,
extended_coord: torch.Tensor,
nlist: torch.Tensor,
nnei: int,
):
n_nf, n_nloc, n_nnei = nlist.shape
# nf x nall x 3
extended_coord = extended_coord.view([n_nf, -1, 3])
rcut = self.get_rcut()
if n_nnei < nnei:
nlist = torch.cat(
[
nlist,
-1
* torch.ones(
[n_nf, n_nloc, nnei - n_nnei],
dtype=nlist.dtype,
device=nlist.device,
),
],
dim=-1,
)
elif n_nnei > nnei:
m_real_nei = nlist >= 0
nlist = torch.where(m_real_nei, nlist, 0)
# nf x nloc x 3
coord0 = extended_coord[:, :n_nloc, :]
# nf x (nloc x nnei) x 3
index = nlist.view(n_nf, n_nloc * n_nnei, 1).expand(-1, -1, 3)
coord1 = torch.gather(extended_coord, 1, index)
# nf x nloc x nnei x 3
coord1 = coord1.view(n_nf, n_nloc, n_nnei, 3)
# nf x nloc x nnei
rr = torch.linalg.norm(coord0[:, :, None, :] - coord1, dim=-1)
rr = torch.where(m_real_nei, rr, float("inf"))
rr, nlist_mapping = torch.sort(rr, dim=-1)
nlist = torch.gather(nlist, 2, nlist_mapping)
nlist = torch.where(rr > rcut, -1, nlist)
nlist = nlist[..., :nnei]
else: # n_nnei == nnei:
pass # great!
assert nlist.shape[-1] == nnei
return nlist
def do_grad_r(
self,
var_name: Optional[str] = None,
) -> bool:
"""Tell if the output variable `var_name` is r_differentiable.
if var_name is None, returns if any of the variable is r_differentiable.
"""
return self.atomic_model.do_grad_r(var_name)
def do_grad_c(
self,
var_name: Optional[str] = None,
) -> bool:
"""Tell if the output variable `var_name` is c_differentiable.
if var_name is None, returns if any of the variable is c_differentiable.
"""
return self.atomic_model.do_grad_c(var_name)
def serialize(self) -> dict:
return self.atomic_model.serialize()
@classmethod
def deserialize(cls, data) -> "CM":
return cls(atomic_model_=T_AtomicModel.deserialize(data))
@torch.jit.export
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.atomic_model.get_dim_fparam()
@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.atomic_model.get_dim_aparam()
@torch.jit.export
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.
Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""
return self.atomic_model.get_sel_type()
@torch.jit.export
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
If False, the shape is (nframes, nloc, ndim).
"""
return self.atomic_model.is_aparam_nall()
@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return self.atomic_model.get_rcut()
@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.atomic_model.get_type_map()
@torch.jit.export
def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.atomic_model.get_nsel()
@torch.jit.export
def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.atomic_model.get_nnei()
def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
):
"""Compute or load the statistics."""
return self.atomic_model.compute_or_load_stat(sampled_func, stat_file_path)
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
return self.atomic_model.get_sel()
def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
2. uses a neighbor list that does not distinguish different atomic types.
If false, the model
1. assumes total number of atoms of each atom type aligned across frames;
2. uses a neighbor list that distinguishes different atomic types.
"""
return self.atomic_model.mixed_types()
def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
# directly call the forward_common method when no specific transform rule
return self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
return CM