# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Dict,
    List,
    Optional,
    Tuple,
)
import numpy as np
import torch
from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.pt.model.descriptor import (
prod_env_mat,
)
from deepmd.pt.model.network.mlp import (
EmbeddingNet,
NetworkCollection,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
from .base_descriptor import (
BaseDescriptor,
)
@BaseDescriptor.register("se_e2_r")
@BaseDescriptor.register("se_r")
class DescrptSeR(BaseDescriptor, torch.nn.Module):
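    r"""DeepPot-SE descriptor in its radial (``se_r``) variant.

    Only the radial (distance) part of the environment matrix is embedded,
    so each neighbor contributes a single scalar channel and the resulting
    descriptor is invariant under rotation and permutation.

    Parameters
    ----------
    rcut : float
        The cut-off radius.
    rcut_smth : float
        The smoothing starts from this radius and the descriptor decays
        smoothly to zero at rcut.
    sel : List[int]
        The maximum number of selected neighbors for each atom type.
    neuron : List[int]
        The sizes of the hidden layers of the embedding net.
    set_davg_zero : bool
        Set the average of the environment matrix to zero.
    activation_function : str
        The activation function of the embedding net.
    precision : str
        The precision of the computation.
    resnet_dt : bool
        Whether to use a "timestep" in the ResNet skip connections.
    exclude_types : List[Tuple[int, int]]
        The excluded pairs of atom types.
    env_protection : float
        Protection parameter to prevent division-by-zero errors during
        environment matrix calculations.
    trainable : bool
        Whether the parameters are trainable.
    """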
def __init__(
self,
rcut,
rcut_smth,
sel,
neuron=[25, 50, 100],
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = "float64",
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
old_impl: bool = False,
trainable: bool = True,
**kwargs,
):
super().__init__()
self.rcut = rcut
self.rcut_smth = rcut_smth
self.neuron = neuron
self.filter_neuron = self.neuron
self.set_davg_zero = set_davg_zero
self.activation_function = activation_function
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.resnet_dt = resnet_dt
        self.old_impl = False  # this descriptor does not support the old implementation
self.exclude_types = exclude_types
self.ntypes = len(sel)
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
self.sel = sel
self.sec = torch.tensor(
np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE
)
self.split_sel = self.sel
self.nnei = sum(sel)
        self.ndescrpt = self.nnei * 1  # one radial channel per neighbor
wanted_shape = (self.ntypes, self.nnei, 1)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.filter_layers_old = None
self.filter_layers = None
filter_layers = NetworkCollection(
ndim=1, ntypes=len(sel), network_type="embedding_network"
)
# TODO: ndim=2 if type_one_side=False
for ii in range(self.ntypes):
filter_layers[(ii,)] = EmbeddingNet(
1,
self.filter_neuron,
activation_function=self.activation_function,
precision=self.precision,
resnet_dt=self.resnet_dt,
)
self.filter_layers = filter_layers
self.stats = None
# set trainable
for param in self.parameters():
param.requires_grad = trainable
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut
def get_nsel(self) -> int:
"""Returns the number of selected atoms in the cut-off radius."""
return sum(self.sel)
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
return self.sel
def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
def get_dim_out(self) -> int:
"""Returns the output dimension."""
return self.neuron[-1]
def get_dim_emb(self) -> int:
"""Returns the output dimension."""
raise NotImplementedError
def get_dim_in(self) -> int:
"""Returns the input dimension."""
return 0
def mixed_types(self) -> bool:
"""If true, the discriptor
1. assumes total number of atoms aligned across frames;
2. requires a neighbor list that does not distinguish different atomic types.
If false, the discriptor
1. assumes total number of atoms of each atom type aligned across frames;
2. requires a neighbor list that distinguishes different atomic types.
"""
return False
def share_params(self, base_class, shared_level, resume=False):
"""
        Share the parameters of self to the base_class with shared_level during multitask training.
        If not starting from a checkpoint (resume is False),
        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.
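
        Examples
        --------
        A hypothetical two-task setup that shares one ``se_r`` descriptor;
        the constructor arguments here are illustrative only:

        >>> base = DescrptSeR(rcut=6.0, rcut_smth=0.5, sel=[10])
        >>> other = DescrptSeR(rcut=6.0, rcut_smth=0.5, sel=[10])
        >>> other.share_params(base, shared_level=0, resume=True)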
"""
assert (
self.__class__ == base_class.__class__
), "Only descriptors of the same type can share params!"
# For SeR descriptors, the user-defined share-level
# shared_level: 0
if shared_level == 0:
# link buffers
if hasattr(self, "mean") and not resume:
# in case of change params during resume
base_env = EnvMatStatSe(base_class)
base_env.stats = base_class.stats
for kk in base_class.get_stats():
base_env.stats[kk] += self.get_stats()[kk]
mean, stddev = base_env()
if not base_class.set_davg_zero:
base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
# the following will successfully link all the params except buffers
for item in self._modules:
self._modules[item] = base_class._modules[item]
# Other shared levels
else:
raise NotImplementedError
def get_stats(self) -> Dict[str, StatItem]:
"""Get the statistics of the descriptor."""
if self.stats is None:
raise RuntimeError(
"The statistics of the descriptor has not been computed."
)
return self.stats
def __setitem__(self, key, value):
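        """Set the mean ("davg") or standard deviation ("dstd") of the environment matrix."""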
if key in ("avg", "data_avg", "davg"):
self.mean = value
elif key in ("std", "data_std", "dstd"):
self.stddev = value
else:
raise KeyError(key)
def __getitem__(self, key):
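        """Get the mean ("davg") or standard deviation ("dstd") of the environment matrix."""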
if key in ("avg", "data_avg", "davg"):
return self.mean
elif key in ("std", "data_std", "dstd"):
return self.stddev
else:
raise KeyError(key)
def reinit_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
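        """Update the exclusion mask for the given excluded atom-type pairs."""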
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
def forward(
self,
coord_ext: torch.Tensor,
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
):
"""Compute the descriptor.
Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function.
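
        Examples
        --------
        A minimal sketch of a single call, assuming CPU execution and the
        default settings; the shapes and values below are illustrative only
        (one frame, three local atoms of one type, no ghost atoms, and a
        neighbor list padded with -1):

        >>> import torch
        >>> d = DescrptSeR(rcut=6.0, rcut_smth=0.5, sel=[10])
        >>> coord_ext = torch.rand(1, 9, dtype=torch.float64)  # nf x (nall x 3)
        >>> atype_ext = torch.zeros(1, 3, dtype=torch.long)  # nf x nall
        >>> nlist = torch.full((1, 3, 10), -1, dtype=torch.long)  # nf x nloc x nnei
        >>> result, _, _, _, sw = d(coord_ext, atype_ext, nlist)
        >>> tuple(result.shape)  # nf x nloc x ng, with ng = neuron[-1] = 100
        (1, 3, 100)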
"""
del mapping
nloc = nlist.shape[1]
atype = atype_ext[:, :nloc]
dmatrix, diff, sw = prod_env_mat(
coord_ext,
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
True,
protection=self.env_protection,
)
assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 1)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
        # pre-allocate the result tensor with a fixed shape so that JIT scripting works
xyz_scatter = torch.zeros(
[nfnl, 1, self.filter_neuron[-1]], dtype=self.prec, device=coord_ext.device
)
# nfnl x nnei
exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1)
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
# nfnl x nt x 1
ss = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
ss = ss * mm[:, :, None]
# nfnl x nt x ng
gg = ll.forward(ss)
gg = torch.mean(gg, dim=1).unsqueeze(1)
xyz_scatter += gg * (self.sel[ii] / self.nnei)
        res_rescale = 1.0 / 5.0  # constant rescaling factor applied to the pooled embedding
result = xyz_scatter * res_rescale
result = result.view(-1, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw,
)
def set_stat_mean_and_stddev(
self,
mean: torch.Tensor,
stddev: torch.Tensor,
) -> None:
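        """Update mean and stddev for descriptor."""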
self.mean = mean
self.stddev = stddev
def serialize(self) -> dict:
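        """Serialize the descriptor to a dict.

        Returns
        -------
        dict
            The serialized data of the descriptor.
        """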
return {
"@class": "Descriptor",
"type": "se_r",
"@version": 1,
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"set_davg_zero": self.set_davg_zero,
"activation_function": self.activation_function,
# make deterministic
"precision": RESERVED_PRECISON_DICT[self.prec],
"embeddings": self.filter_layers.serialize(),
"env_mat": DPEnvMat(self.rcut, self.rcut_smth).serialize(),
"exclude_types": self.exclude_types,
"env_protection": self.env_protection,
"@variables": {
"davg": self["davg"].detach().cpu().numpy(),
"dstd": self["dstd"].detach().cpu().numpy(),
},
## to be updated when the options are supported.
"trainable": True,
"type_one_side": True,
"spin": None,
}
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeR":
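        """Deserialize from the serialized data.

        Parameters
        ----------
        data : dict
            The serialized data.

        Returns
        -------
        DescrptSeR
            The deserialized descriptor object.
        """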
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")
obj = cls(**data)
def t_cvt(xx):
return torch.tensor(xx, dtype=obj.prec, device=env.DEVICE)
obj["davg"] = t_cvt(variables["davg"])
obj["dstd"] = t_cvt(variables["dstd"])
obj.filter_layers = NetworkCollection.deserialize(embeddings)
return obj
@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)