# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Union,
)
import torch
from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.pt.model.network.mlp import (
NetworkCollection,
)
from deepmd.pt.model.network.network import (
TypeEmbedNet,
TypeEmbedNetConsistent,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
from .base_descriptor import (
BaseDescriptor,
)
from .se_atten import (
DescrptBlockSeAtten,
NeighborGatedAttention,
)
@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
[docs]
class DescrptDPA1(BaseDescriptor, torch.nn.Module):
r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model.
This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by
.. math::
\mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<,
where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix :math:`\mathcal{G}^i`
after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor.
Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor.
To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained:
.. math::
\left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),
.. math::
\left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),
.. math::
\left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),
where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations
that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l`
is the index of the attention layer.
The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`,
is chosen as the two-body embedding matrix.
Then the scaled dot-product attention method is adopted:
.. math::
A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l},
where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights.
In the original attention method,
one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`,
with :math:`\sqrt{d_{k}}` being the normalization temperature.
This is slightly modified to incorporate the angular information:
.. math::
\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T},
where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates,
:math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \rVert}`
and :math:`\odot` means element-wise multiplication.
Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix
:math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1]
.. math::
\mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})).
Parameters
----------
rcut: float
The cut-off radius :math:`r_c`
rcut_smth: float
From where the environment matrix should be smoothed :math:`r_s`
sel : list[int], int
list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius
int: the total maximum number of atoms in the cut-off radius
ntypes : int
Number of element types
neuron : list[int]
Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
axis_neuron: int
Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
tebd_dim: int
Dimension of the type embedding
tebd_input_mode: str
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output.
resnet_dt: bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
trainable: bool
Whether the weights of this descriptor are trainable.
trainable_ln: bool
Whether to use trainable shift and scale weights in layer normalization.
ln_eps: float, Optional
The epsilon value for layer normalization.
type_one_side: bool
If 'False', type embeddings of both neighbor and central atoms are considered.
If 'True', only type embeddings of neighbor atoms are considered.
Default is 'False'.
attn: int
Hidden dimension of the attention vectors
attn_layer: int
Number of attention layers
attn_dotr: bool
If dot the angular gate to the attention weights
attn_mask: bool
(Only support False to keep consistent with other backend references.)
(Not used in this version. True option is not implemented.)
If mask the diagonal of attention weights
exclude_types : List[List[int]]
The excluded pairs of types which have no interaction with each other.
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
env_protection: float
Protection parameter to prevent division by zero errors during environment matrix calculations.
set_davg_zero: bool
Set the shift of embedding net input to zero.
activation_function: str
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision: str
The precision of the embedding net parameters. Supported options are |PRECISION|
scaling_factor: float
The scaling factor of normalization in calculations of attention weights.
If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5
normalize: bool
Whether to normalize the hidden vectors in attention weights calculation.
temperature: float
If not None, the scaling of attention weights is `temperature` itself.
smooth_type_embedding: bool
Whether to use smooth process in attention weights calculation.
concat_output_tebd: bool
Whether to concat type embedding at the output of the descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
Whether to strip the type embedding into a separate embedding network.
Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
The old implementation of deepspin.
Limitations
-----------
The current implementation does not support the following deprecated features
1. spin is not None
2. attn_mask == True
References
----------
.. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022.
DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation.
arXiv preprint arXiv:2208.08236.
"""
def __init__(
    self,
    rcut: float,
    rcut_smth: float,
    sel: Union[List[int], int],
    ntypes: int,
    # NOTE(review): the list defaults below are shared between calls. This
    # method never mutates them; confirm DescrptBlockSeAtten does not either.
    neuron: list = [25, 50, 100],
    axis_neuron: int = 16,
    tebd_dim: int = 8,
    tebd_input_mode: str = "concat",
    set_davg_zero: bool = True,
    attn: int = 128,
    attn_layer: int = 2,
    attn_dotr: bool = True,
    attn_mask: bool = False,
    activation_function: str = "tanh",
    precision: str = "float64",
    resnet_dt: bool = False,
    exclude_types: List[Tuple[int, int]] = [],
    env_protection: float = 0.0,
    scaling_factor: float = 1.0,
    normalize: bool = True,
    temperature: Optional[float] = None,
    concat_output_tebd: bool = True,
    trainable: bool = True,
    trainable_ln: bool = True,
    ln_eps: Optional[float] = 1e-5,
    smooth_type_embedding: bool = True,
    type_one_side: bool = False,
    stripped_type_embedding: Optional[bool] = None,
    # not implemented
    spin=None,
    type: Optional[str] = None,
    seed: Optional[int] = None,
    old_impl: bool = False,
):
    """Build the attention block and the type embedding of the descriptor.

    The meaning of the individual hyper-parameters is given in the class
    docstring. Only the handling specific to this constructor is noted here:

    * ``stripped_type_embedding`` (deprecated), when not None, overrides
      ``tebd_input_mode``: True selects "strip", False selects "concat".
    * ``ln_eps=None`` is replaced by ``1e-5``.
    * ``type`` is discarded and ``seed`` is never read by this constructor.
    * ``trainable`` is applied to every parameter of the module through
      ``requires_grad``.

    Raises
    ------
    NotImplementedError
        If ``spin`` is not None or ``attn_mask`` is true.
    """
    super().__init__()
    # Ensure compatibility with the deprecated stripped_type_embedding option:
    # an explicitly given value takes precedence over tebd_input_mode.
    if stripped_type_embedding is not None:
        tebd_input_mode = "strip" if stripped_type_embedding else "concat"
    if spin is not None:
        raise NotImplementedError("old implementation of spin is not supported.")
    if attn_mask:
        raise NotImplementedError(
            "old implementation of attn_mask is not supported."
        )
    # Keep the same default as the explicit signature default.
    if ln_eps is None:
        ln_eps = 1e-5

    # These arguments are accepted for compatibility only and are not used
    # past this point.
    del type, spin, attn_mask
    self.se_atten = DescrptBlockSeAtten(
        rcut,
        rcut_smth,
        sel,
        ntypes,
        neuron=neuron,
        axis_neuron=axis_neuron,
        tebd_dim=tebd_dim,
        tebd_input_mode=tebd_input_mode,
        set_davg_zero=set_davg_zero,
        attn=attn,
        attn_layer=attn_layer,
        attn_dotr=attn_dotr,
        # attn_mask=True was rejected above, so the block always gets False.
        attn_mask=False,
        activation_function=activation_function,
        precision=precision,
        resnet_dt=resnet_dt,
        scaling_factor=scaling_factor,
        normalize=normalize,
        temperature=temperature,
        smooth=smooth_type_embedding,
        type_one_side=type_one_side,
        exclude_types=exclude_types,
        env_protection=env_protection,
        trainable_ln=trainable_ln,
        ln_eps=ln_eps,
        old_impl=old_impl,
    )
    self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision)
    self.tebd_dim = tebd_dim
    self.concat_output_tebd = concat_output_tebd
    # Freeze or unfreeze every parameter registered so far.
    for param in self.parameters():
        param.requires_grad = trainable
[docs]
def get_rcut(self) -> float:
    """Return the cut-off radius reported by the attention block."""
    cutoff = self.se_atten.get_rcut()
    return cutoff
[docs]
def get_nsel(self) -> int:
    """Return the number of selected atoms reported by the attention block."""
    n_selected = self.se_atten.get_nsel()
    return n_selected
[docs]
def get_sel(self) -> List[int]:
    """Return the per-type selection list reported by the attention block."""
    selection = self.se_atten.get_sel()
    return selection
[docs]
def get_ntypes(self) -> int:
    """Return the number of element types reported by the attention block."""
    n_types = self.se_atten.get_ntypes()
    return n_types
[docs]
def get_dim_out(self) -> int:
    """Return the output dimension of the descriptor.

    This is the output dimension of the attention block, increased by
    ``tebd_dim`` when the type embedding is concatenated to the output.
    """
    tebd_extra = self.tebd_dim if self.concat_output_tebd else 0
    return self.se_atten.get_dim_out() + tebd_extra
[docs]
def get_dim_emb(self) -> int:
    """Return the ``dim_emb`` attribute of the attention block."""
    emb_width = self.se_atten.dim_emb
    return emb_width
[docs]
def mixed_types(self) -> bool:
    """Tell whether the descriptor works with mixed atom types.

    If true, the descriptor
    1. assumes the total number of atoms is aligned across frames;
    2. requires a neighbor list that does not distinguish different atomic types.

    If false, the descriptor
    1. assumes the total number of atoms of each atom type is aligned across frames;
    2. requires a neighbor list that distinguishes different atomic types.

    The answer is delegated to the attention block.
    """
    is_mixed = self.se_atten.mixed_types()
    return is_mixed
[docs]
def share_params(self, base_class, shared_level, resume=False):
    """Share parameters of ``base_class`` with this descriptor (multitask training).

    Supported values of ``shared_level``:

    * 0: share the type embedding and all parameters of the attention block;
    * 1: share the type embedding only.

    ``resume`` is forwarded to the attention block's ``share_params``; when it
    is False (not starting from a checkpoint), some separated parameters
    (e.g. mean and stddev) will be re-calculated across different classes.

    Raises
    ------
    NotImplementedError
        For any other ``shared_level``.
    """
    same_kind = self.__class__ == base_class.__class__
    assert same_kind, "Only descriptors of the same type can share params!"
    # Reject unsupported levels before touching any state.
    if shared_level not in (0, 1):
        raise NotImplementedError
    # Both supported levels reuse the base descriptor's type embedding module.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 0:
        # Level 0 additionally shares everything inside the attention block.
        self.se_atten.share_params(base_class.se_atten, 0, resume=resume)
@property
[docs]
def dim_out(self):
    """Output dimension of the descriptor; same value as ``get_dim_out()``."""
    out_width = self.get_dim_out()
    return out_width
@property
[docs]
def dim_emb(self):
    """Embedding dimension of the descriptor; same value as ``get_dim_emb()``."""
    emb_width = self.get_dim_emb()
    return emb_width
[docs]
def set_stat_mean_and_stddev(
    self,
    mean: torch.Tensor,
    stddev: torch.Tensor,
) -> None:
    """Store ``mean`` and ``stddev`` on the attention block.

    Parameters
    ----------
    mean
        Value assigned to the block's ``mean`` attribute.
    stddev
        Value assigned to the block's ``stddev`` attribute.
    """
    block = self.se_atten
    block.mean = mean
    block.stddev = stddev
[docs]
def serialize(self) -> dict:
    """Serialize the descriptor into a plain dictionary.

    The dictionary holds the hyper-parameters read from the attention block,
    the serialized sub-networks (embeddings, attention layers, type
    embedding, environment matrix) and the ``davg``/``dstd`` statistics as
    NumPy arrays. ``attn_mask``, ``trainable`` and ``spin`` are written as
    fixed values. In "strip" mode the stripped embedding networks are added
    under ``embeddings_strip``.

    Returns
    -------
    dict
        The serialized descriptor.
    """
    blk = self.se_atten
    # Statistics are detached and moved to the CPU before conversion to NumPy.
    statistics = {
        "davg": blk["davg"].detach().cpu().numpy(),
        "dstd": blk["dstd"].detach().cpu().numpy(),
    }
    serialized = {
        "@class": "Descriptor",
        "type": "dpa1",
        "@version": 1,
        "rcut": blk.rcut,
        "rcut_smth": blk.rcut_smth,
        "sel": blk.sel,
        "ntypes": blk.ntypes,
        "neuron": blk.neuron,
        "axis_neuron": blk.axis_neuron,
        "tebd_dim": blk.tebd_dim,
        "tebd_input_mode": blk.tebd_input_mode,
        "set_davg_zero": blk.set_davg_zero,
        "attn": blk.attn_dim,
        "attn_layer": blk.attn_layer,
        "attn_dotr": blk.attn_dotr,
        "attn_mask": False,
        "activation_function": blk.activation_function,
        "resnet_dt": blk.resnet_dt,
        "scaling_factor": blk.scaling_factor,
        "normalize": blk.normalize,
        "temperature": blk.temperature,
        "trainable_ln": blk.trainable_ln,
        "ln_eps": blk.ln_eps,
        "smooth_type_embedding": blk.smooth,
        "type_one_side": blk.type_one_side,
        "concat_output_tebd": self.concat_output_tebd,
        # make deterministic
        "precision": RESERVED_PRECISON_DICT[blk.prec],
        "embeddings": blk.filter_layers.serialize(),
        "attention_layers": blk.dpa1_attention.serialize(),
        "env_mat": DPEnvMat(blk.rcut, blk.rcut_smth).serialize(),
        "type_embedding": self.type_embedding.embedding.serialize(),
        "exclude_types": blk.exclude_types,
        "env_protection": blk.env_protection,
        "@variables": statistics,
        # Fixed values; to be updated when the options are supported.
        "trainable": True,
        "spin": None,
    }
    if blk.tebd_input_mode in ["strip"]:
        serialized["embeddings_strip"] = blk.filter_layers_strip.serialize()
    return serialized
@classmethod
[docs]
def deserialize(cls, data: dict) -> "DescrptDPA1":
    """Rebuild a descriptor from a dictionary produced by ``serialize``.

    The input dictionary is not modified; a shallow copy is consumed instead.
    Bookkeeping entries and serialized sub-networks are removed from the copy,
    the remaining entries are passed to the constructor as keyword arguments,
    and the sub-networks and statistics are then restored on the new object.

    Parameters
    ----------
    data : dict
        The serialized descriptor.

    Returns
    -------
    DescrptDPA1
        The reconstructed descriptor.
    """
    config = data.copy()
    check_version_compatibility(config.pop("@version"), 1, 1)
    for meta_key in ("@class", "type"):
        config.pop(meta_key)
    statistics = config.pop("@variables")
    embeddings = config.pop("embeddings")
    type_embedding = config.pop("type_embedding")
    attention_layers = config.pop("attention_layers")
    # Must not reach the constructor; its content is not used here.
    config.pop("env_mat")
    use_strip = config["tebd_input_mode"] in ["strip"]
    embeddings_strip = config.pop("embeddings_strip") if use_strip else None

    obj = cls(**config)
    block = obj.se_atten

    obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
        type_embedding
    )
    # Statistics are restored with the block's precision on the configured device.
    for stat_name in ("davg", "dstd"):
        block[stat_name] = torch.tensor(
            statistics[stat_name], dtype=block.prec, device=env.DEVICE
        )
    block.filter_layers = NetworkCollection.deserialize(embeddings)
    if use_strip:
        block.filter_layers_strip = NetworkCollection.deserialize(embeddings_strip)
    block.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers)
    return obj
[docs]
def forward(
    self,
    extended_coord: torch.Tensor,
    extended_atype: torch.Tensor,
    nlist: torch.Tensor,
    mapping: Optional[torch.Tensor] = None,
    comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
    """Compute the descriptor.

    Parameters
    ----------
    extended_coord
        The extended coordinates of atoms. shape: nf x (nallx3)
    extended_atype
        The extended atom types. shape: nf x nall
    nlist
        The neighbor list. shape: nf x nloc x nnei
    mapping
        The index mapping, not required by this descriptor.
    comm_dict
        The data needed for communication for parallel inference.
        It is not read by this method.

    Returns
    -------
    descriptor
        The descriptor. shape: nf x nloc x (ng x axis_neuron).
        When ``concat_output_tebd`` is set, the type embedding of the local
        atoms is concatenated along the last dimension.
    gr
        The rotationally equivariant and permutationally invariant single particle
        representation. shape: nf x nloc x ng x 3
    g2
        The rotationally invariant pair-particle representation.
        shape: nf x nloc x nnei x ng
    h2
        The rotationally equivariant pair-particle representation.
        shape: nf x nloc x nnei x 3
    sw
        The smooth switch function. shape: nf x nloc x nnei
    """
    del mapping
    # Only nloc is needed; the unpacking still requires a 3-D neighbor list.
    _, nloc, _ = nlist.shape
    g1_ext = self.type_embedding(extended_atype)
    # Type embedding of the local atoms (the first nloc atoms of each frame).
    g1_inp = g1_ext[:, :nloc, :]
    g1, g2, h2, rot_mat, sw = self.se_atten(
        nlist,
        extended_coord,
        extended_atype,
        g1_ext,
        mapping=None,
    )
    if self.concat_output_tebd:
        g1 = torch.cat([g1, g1_inp], dim=-1)
    # Note the order: the block returns rot_mat fourth, callers get it second.
    return g1, rot_mat, g2, h2, sw
@classmethod
[docs]
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    The work is done by ``UpdateSel().update_one_sel`` on a shallow copy of
    ``local_jdata``, so the caller's dictionary is not rebound.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data refer to the current class
    """
    updater = UpdateSel()
    return updater.update_one_sel(global_jdata, local_jdata.copy(), True)