# Source code for deepmd.pt.model.descriptor.dpa1

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

import torch

from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.pt.model.network.mlp import (
    NetworkCollection,
)
from deepmd.pt.model.network.network import (
    TypeEmbedNet,
    TypeEmbedNetConsistent,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.env import (
    RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.update_sel import (
    UpdateSel,
)
from deepmd.utils.path import (
    DPPath,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

from .base_descriptor import (
    BaseDescriptor,
)
from .se_atten import (
    DescrptBlockSeAtten,
    NeighborGatedAttention,
)


@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
class DescrptDPA1(BaseDescriptor, torch.nn.Module):
    r"""Attention-based descriptor which is proposed in the pretrainable DPA-1 [1]_ model.

    This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by

    .. math::
        \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<,

    where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix :math:`\mathcal{G}^i`
    after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the
    full case in the se_e2_a descriptor.
    Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default
    in this descriptor.

    To perform the self-attention mechanism, the queries
    :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
    keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
    and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained:

    .. math::
        \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),

    .. math::
        \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),

    .. math::
        \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),

    where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear
    transformations that output the queries and keys of dimension :math:`d_k` and values
    of dimension :math:`d_v`, and :math:`l` is the index of the attention layer.
    The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`,
    is chosen as the two-body embedding matrix.

    Then the scaled dot-product attention method is adopted:

    .. math::
        A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l},

    where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}`
    is attention weights.
    In the original attention method,
    one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`,
    with :math:`\sqrt{d_{k}}` being the normalization temperature.
    This is slightly modified to incorporate the angular information:

    .. math::
        \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T},

    where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized
    relative coordinates,
    :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \rVert}`
    and :math:`\odot` means element-wise multiplication.

    Then layer normalization is added in a residual way to finally obtain the
    self-attention local embedding matrix
    :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention
    layers [1]_:

    .. math::
        \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})).

    Parameters
    ----------
    rcut: float
        The cut-off radius :math:`r_c`
    rcut_smth: float
        From where the environment matrix should be smoothed :math:`r_s`
    sel : list[int], int
        list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius
        int: the total maximum number of atoms in the cut-off radius
    ntypes : int
        Number of element types
    neuron : list[int]
        Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
    axis_neuron: int
        Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of
        the embedding matrix)
    tebd_dim: int
        Dimension of the type embedding
    tebd_input_mode: str
        The input mode of the type embedding. Supported modes are ["concat", "strip"].

        - "concat": Concatenate the type embedding with the smoothed radial information
          as the union input for the embedding network.
        - "strip": Use a separated embedding network for the type embedding and combine
          the output with the radial embedding network output.
    resnet_dt: bool
        Time-step `dt` in the resnet construction:
        y = x + dt * \phi (Wx + b)
    trainable: bool
        Whether the weights of this descriptor are trainable.
    trainable_ln: bool
        Whether to use trainable shift and scale weights in layer normalization.
    ln_eps: float, Optional
        The epsilon value for layer normalization.
    type_one_side: bool
        If 'False', type embeddings of both neighbor and central atoms are considered.
        If 'True', only type embeddings of neighbor atoms are considered.
        Default is 'False'.
    attn: int
        Hidden dimension of the attention vectors
    attn_layer: int
        Number of attention layers
    attn_dotr: bool
        If dot the angular gate to the attention weights
    attn_mask: bool
        (Only support False to keep consistent with other backend references.)
        (Not used in this version. True option is not implemented.)
        If mask the diagonal of attention weights
    exclude_types : List[List[int]]
        The excluded pairs of types which have no interaction with each other.
        For example, `[[0, 1]]` means no interaction between type 0 and type 1.
    env_protection: float
        Protection parameter to prevent division by zero errors during environment
        matrix calculations.
    set_davg_zero: bool
        Set the shift of embedding net input to zero.
    activation_function: str
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision: str
        The precision of the embedding net parameters. Supported options are |PRECISION|
    scaling_factor: float
        The scaling factor of normalization in calculations of attention weights.
        If `temperature` is None, the scaling of attention weights is
        (N_dim * scaling_factor)**0.5
    normalize: bool
        Whether to normalize the hidden vectors in attention weights calculation.
    temperature: float
        If not None, the scaling of attention weights is `temperature` itself.
    smooth_type_embedding: bool
        Whether to use smooth process in attention weights calculation.
    concat_output_tebd: bool
        Whether to concat type embedding at the output of the descriptor.
    stripped_type_embedding: bool, Optional
        (Deprecated, kept only for compatibility.)
        Whether to strip the type embedding into a separate embedding network.
        Setting this parameter to `True` is equivalent to setting `tebd_input_mode`
        to 'strip'. Setting it to `False` is equivalent to setting `tebd_input_mode`
        to 'concat'. The default value is `None`, which means the `tebd_input_mode`
        setting will be used instead.
    spin
        (Only support None to keep consistent with other backend references.)
        (Not used in this version. Not-none option is not implemented.)
        The old implementation of deepspin.
    type: str, Optional
        Accepted and discarded by this constructor.
    seed: int, Optional
        Accepted but not used by this constructor.
    old_impl: bool
        Forwarded unchanged to the underlying `DescrptBlockSeAtten`.

    Limitations
    -----------
    The current implementation does not support the following deprecated features

    1. spin is not None
    2. attn_mask == True

    References
    ----------
    .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang,
       and Han Wang. 2022. DPA-1: Pretraining of Attention-based Deep Potential
       Model for Molecular Simulation. arXiv preprint arXiv:2208.08236.
    """

    def __init__(
        self,
        rcut: float,
        rcut_smth: float,
        sel: Union[List[int], int],
        ntypes: int,
        neuron: list = [25, 50, 100],
        axis_neuron: int = 16,
        tebd_dim: int = 8,
        tebd_input_mode: str = "concat",
        set_davg_zero: bool = True,
        attn: int = 128,
        attn_layer: int = 2,
        attn_dotr: bool = True,
        attn_mask: bool = False,
        activation_function: str = "tanh",
        precision: str = "float64",
        resnet_dt: bool = False,
        exclude_types: List[Tuple[int, int]] = [],
        env_protection: float = 0.0,
        scaling_factor: float = 1.0,
        normalize=True,
        temperature=None,
        concat_output_tebd: bool = True,
        trainable: bool = True,
        trainable_ln: bool = True,
        ln_eps: Optional[float] = 1e-5,
        smooth_type_embedding: bool = True,
        type_one_side: bool = False,
        stripped_type_embedding: Optional[bool] = None,
        # not implemented
        spin=None,
        type: Optional[str] = None,
        seed: Optional[int] = None,
        old_impl: bool = False,
    ):
        super().__init__()
        # Ensure compatibility with the deprecated stripped_type_embedding option.
        if stripped_type_embedding is not None:
            # Use the user-set stripped_type_embedding parameter first
            tebd_input_mode = "strip" if stripped_type_embedding else "concat"
        if spin is not None:
            raise NotImplementedError("old implementation of spin is not supported.")
        if attn_mask:
            raise NotImplementedError(
                "old implementation of attn_mask is not supported."
            )
        # to keep consistent with default value in this backends
        if ln_eps is None:
            ln_eps = 1e-5

        del type, spin, attn_mask
        self.se_atten = DescrptBlockSeAtten(
            rcut,
            rcut_smth,
            sel,
            ntypes,
            # Copy the list arguments: the signature defaults are shared mutable
            # objects and must never be handed out by reference.
            neuron=list(neuron),
            axis_neuron=axis_neuron,
            tebd_dim=tebd_dim,
            tebd_input_mode=tebd_input_mode,
            set_davg_zero=set_davg_zero,
            attn=attn,
            attn_layer=attn_layer,
            attn_dotr=attn_dotr,
            attn_mask=False,
            activation_function=activation_function,
            precision=precision,
            resnet_dt=resnet_dt,
            scaling_factor=scaling_factor,
            normalize=normalize,
            temperature=temperature,
            smooth=smooth_type_embedding,
            type_one_side=type_one_side,
            exclude_types=list(exclude_types),
            env_protection=env_protection,
            trainable_ln=trainable_ln,
            ln_eps=ln_eps,
            old_impl=old_impl,
        )
        self.type_embedding = TypeEmbedNet(ntypes, tebd_dim, precision=precision)
        self.tebd_dim = tebd_dim
        self.concat_output_tebd = concat_output_tebd
        # Remember the flag so that serialize() can report it faithfully.
        self.trainable = trainable
        # set trainable
        for param in self.parameters():
            param.requires_grad = trainable

    def get_rcut(self) -> float:
        """Returns the cut-off radius."""
        return self.se_atten.get_rcut()

    def get_nsel(self) -> int:
        """Returns the number of selected atoms in the cut-off radius."""
        return self.se_atten.get_nsel()

    def get_sel(self) -> List[int]:
        """Returns the number of selected atoms for each type."""
        return self.se_atten.get_sel()

    def get_ntypes(self) -> int:
        """Returns the number of element types."""
        return self.se_atten.get_ntypes()

    def get_dim_out(self) -> int:
        """Returns the output dimension.

        This is the output dimension of the attention block, plus `tebd_dim`
        when the type embedding is concatenated to the output.
        """
        ret = self.se_atten.get_dim_out()
        if self.concat_output_tebd:
            ret += self.tebd_dim
        return ret

    def get_dim_emb(self) -> int:
        """Returns the `dim_emb` attribute of the attention block."""
        return self.se_atten.dim_emb

    def mixed_types(self) -> bool:
        """If true, the descriptor
        1. assumes total number of atoms aligned across frames;
        2. requires a neighbor list that does not distinguish different atomic types.

        If false, the descriptor
        1. assumes total number of atoms of each atom type aligned across frames;
        2. requires a neighbor list that distinguishes different atomic types.

        """
        return self.se_atten.mixed_types()

    def share_params(self, base_class, shared_level, resume=False):
        """
        Share the parameters of self to the base_class with shared_level during multitask training.
        If not start from checkpoint (resume is False),
        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.

        Parameters
        ----------
        base_class
            The descriptor whose sub-modules are adopted; must be of the same class as self.
        shared_level : int
            0 shares both `type_embedding` and `se_atten`; 1 shares only `type_embedding`.
        resume : bool
            Forwarded to `se_atten.share_params` when `shared_level` is 0.

        Raises
        ------
        NotImplementedError
            If `shared_level` is neither 0 nor 1.
        """
        assert (
            self.__class__ == base_class.__class__
        ), "Only descriptors of the same type can share params!"
        # For DPA1 descriptors, the user-defined share-level
        # shared_level: 0
        # share all parameters in both type_embedding and se_atten
        if shared_level == 0:
            self._modules["type_embedding"] = base_class._modules["type_embedding"]
            self.se_atten.share_params(base_class.se_atten, 0, resume=resume)
        # shared_level: 1
        # share all parameters in type_embedding
        elif shared_level == 1:
            self._modules["type_embedding"] = base_class._modules["type_embedding"]
        # Other shared levels
        else:
            raise NotImplementedError(
                f"shared_level {shared_level} is not supported by the DPA1 descriptor."
            )

    @property
    def dim_out(self):
        """The output dimension, see `get_dim_out`."""
        return self.get_dim_out()

    @property
    def dim_emb(self):
        """The embedding dimension, see `get_dim_emb`."""
        return self.get_dim_emb()

    def compute_input_stats(
        self,
        merged: Union[Callable[[], List[dict]], List[dict]],
        path: Optional[DPPath] = None,
    ):
        """
        Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

        Parameters
        ----------
        merged : Union[Callable[[], List[dict]], List[dict]]
            - List[dict]: A list of data samples from various data systems.
                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
                originating from the `i`-th data system.
            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
                only when needed. Since the sampling process can be slow and memory-intensive,
                the lazy function helps by only sampling once.
        path : Optional[DPPath]
            The path to the stat file.

        """
        return self.se_atten.compute_input_stats(merged, path)

    def set_stat_mean_and_stddev(
        self,
        mean: torch.Tensor,
        stddev: torch.Tensor,
    ) -> None:
        """Assign `mean` and `stddev` to the attention block's statistics."""
        self.se_atten.mean = mean
        self.se_atten.stddev = stddev

    def serialize(self) -> dict:
        """Serialize the descriptor to a dict that `deserialize` can consume.

        The dict holds the constructor arguments, the serialized sub-networks
        and the `davg`/`dstd` statistics as numpy arrays.
        """
        obj = self.se_atten
        data = {
            "@class": "Descriptor",
            "type": "dpa1",
            "@version": 1,
            "rcut": obj.rcut,
            "rcut_smth": obj.rcut_smth,
            "sel": obj.sel,
            "ntypes": obj.ntypes,
            "neuron": obj.neuron,
            "axis_neuron": obj.axis_neuron,
            "tebd_dim": obj.tebd_dim,
            "tebd_input_mode": obj.tebd_input_mode,
            "set_davg_zero": obj.set_davg_zero,
            "attn": obj.attn_dim,
            "attn_layer": obj.attn_layer,
            "attn_dotr": obj.attn_dotr,
            "attn_mask": False,
            "activation_function": obj.activation_function,
            "resnet_dt": obj.resnet_dt,
            "scaling_factor": obj.scaling_factor,
            "normalize": obj.normalize,
            "temperature": obj.temperature,
            "trainable_ln": obj.trainable_ln,
            "ln_eps": obj.ln_eps,
            "smooth_type_embedding": obj.smooth,
            "type_one_side": obj.type_one_side,
            "concat_output_tebd": self.concat_output_tebd,
            # make deterministic
            "precision": RESERVED_PRECISON_DICT[obj.prec],
            "embeddings": obj.filter_layers.serialize(),
            "attention_layers": obj.dpa1_attention.serialize(),
            "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
            "type_embedding": self.type_embedding.embedding.serialize(),
            "exclude_types": obj.exclude_types,
            "env_protection": obj.env_protection,
            "@variables": {
                "davg": obj["davg"].detach().cpu().numpy(),
                "dstd": obj["dstd"].detach().cpu().numpy(),
            },
            # Report the flag given at construction, so a frozen descriptor
            # stays frozen after a serialize/deserialize round trip.
            "trainable": self.trainable,
            ## to be updated when the options are supported.
            "spin": None,
        }
        if obj.tebd_input_mode in ["strip"]:
            data.update({"embeddings_strip": obj.filter_layers_strip.serialize()})
        return data

    @classmethod
    def deserialize(cls, data: dict) -> "DescrptDPA1":
        """Rebuild a descriptor from the dict produced by `serialize`.

        Parameters
        ----------
        data : dict
            The serialized descriptor. It is copied, not modified.

        Returns
        -------
        DescrptDPA1
            A descriptor carrying the stored networks and statistics.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        data.pop("type")
        variables = data.pop("@variables")
        embeddings = data.pop("embeddings")
        type_embedding = data.pop("type_embedding")
        attention_layers = data.pop("attention_layers")
        # Not a constructor argument; drop it so that cls(**data) accepts the rest.
        data.pop("env_mat")
        tebd_input_mode = data["tebd_input_mode"]
        if tebd_input_mode in ["strip"]:
            embeddings_strip = data.pop("embeddings_strip")
        else:
            embeddings_strip = None
        obj = cls(**data)

        def t_cvt(xx):
            return torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE)

        obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
            type_embedding
        )
        obj.se_atten["davg"] = t_cvt(variables["davg"])
        obj.se_atten["dstd"] = t_cvt(variables["dstd"])
        obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
        if tebd_input_mode in ["strip"]:
            obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
                embeddings_strip
            )
        obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
            attention_layers
        )
        # The networks assigned above replace the ones frozen in __init__,
        # so the frozen state has to be applied again.
        if not obj.trainable:
            for param in obj.parameters():
                param.requires_grad = False
        return obj

    def forward(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: Optional[torch.Tensor] = None,
        comm_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Compute the descriptor.

        Parameters
        ----------
        extended_coord
            The extended coordinates of atoms. shape: nf x (nallx3)
        extended_atype
            The extended atom types. shape: nf x nall
        nlist
            The neighbor list. shape: nf x nloc x nnei
        mapping
            The index mapping, not required by this descriptor.
        comm_dict
            The data needed for communication for parallel inference.

        Returns
        -------
        descriptor
            The descriptor. shape: nf x nloc x (ng x axis_neuron)
        gr
            The rotationally equivariant and permutationally invariant single particle
            representation. shape: nf x nloc x ng x 3
        g2
            The rotationally invariant pair-particle representation.
            shape: nf x nloc x nnei x ng
        h2
            The rotationally equivariant pair-particle representation.
            shape: nf x nloc x nnei x 3
        sw
            The smooth switch function. shape: nf x nloc x nnei

        """
        del mapping
        # Only the number of local atoms is needed here.
        nloc = nlist.shape[1]
        g1_ext = self.type_embedding(extended_atype)
        # Type embedding of the local atoms, appended to the output if requested.
        g1_inp = g1_ext[:, :nloc, :]
        g1, g2, h2, rot_mat, sw = self.se_atten(
            nlist,
            extended_coord,
            extended_atype,
            g1_ext,
            mapping=None,
        )
        if self.concat_output_tebd:
            g1 = torch.cat([g1, g1_inp], dim=-1)

        return g1, rot_mat, g2, h2, sw

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class
        """
        # Work on a shallow copy instead of the caller's dict.
        local_jdata_cpy = local_jdata.copy()
        return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True)