# Source code for the DescrptBlockRepformers descriptor block (deepmd-kit, PyTorch backend).

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

import torch

from import (
from import (
from import (
from import (
from import (
from import (
from import (
from deepmd.utils.env_mat_stat import (
from deepmd.utils.path import (

from .repformer_layer import (

# Floating-point dtype and device applied to the layers and buffers this
# module creates; both are read from the project-wide ``env`` settings.
mydtype, mydev = env.GLOBAL_PT_FLOAT_PRECISION, env.DEVICE
def torch_linear(*args, **kwargs):
    """Construct a ``torch.nn.Linear`` pinned to the module dtype and device.

    Every positional and keyword argument is forwarded unchanged to
    ``torch.nn.Linear``. ``dtype`` and ``device`` are always supplied from the
    module-level ``mydtype`` / ``mydev``, so passing either of them explicitly
    raises ``TypeError`` (duplicate keyword argument).
    """
    return torch.nn.Linear(*args, dtype=mydtype, device=mydev, **kwargs)
# Linear-layer factory used by this module; both names are aliases of
# ``SimpleLinear``.
simple_linear = mylinear = SimpleLinear
if not hasattr(torch.ops.deepmd, "border_op"):

    def border_op(
        argument0,
        argument1,
        argument2,
        argument3,
        argument4,
        argument5,
        argument6,
        argument7,
        argument8,
    ) -> torch.Tensor:
        """Stand-in for the custom ``border_op`` when the op library is absent.

        It is installed as ``torch.ops.deepmd.border_op`` only when that
        attribute does not already exist. Presumably this lets code that
        references the op be scripted/frozen without the compiled library.
        Calling it always raises ``NotImplementedError``.
        """
        raise NotImplementedError(
            "border_op is not available since customized PyTorch OP library is not built when freezing the model."
        )

    # Note: this hack cannot actually save a model that can be run using LAMMPS.
    torch.ops.deepmd.border_op = border_op


@DescriptorBlock.register("se_repformer")
@DescriptorBlock.register("se_uni")
class DescrptBlockRepformers(DescriptorBlock):
    """Descriptor block built from a stack of ``RepformerLayer`` modules.

    Registered with ``DescriptorBlock`` under ``"se_repformer"`` and
    ``"se_uni"``. ``forward`` works in three stages:

    1. Build an environment matrix with ``prod_env_mat``.
    2. Embed it into per-neighbor features ``g2`` (nb x nloc x nnei x ng2)
       and ``h2`` (nb x nloc x nnei x 3).
    3. Update the per-atom features ``g1`` (nb x nloc x ng1) through
       ``nlayers`` repformer layers.

    A single neighbor selection is used for all types (``len(sel) == 1``).
    """

    def __init__(
        self,
        rcut,
        rcut_smth,
        sel: Union[int, List[int]],
        ntypes: int,
        nlayers: int = 3,
        g1_dim: int = 128,
        g2_dim: int = 16,
        axis_dim: int = 4,
        direct_dist: bool = False,
        do_bn_mode: str = "no",
        bn_momentum: float = 0.1,
        update_g1_has_conv: bool = True,
        update_g1_has_drrd: bool = True,
        update_g1_has_grrg: bool = True,
        update_g1_has_attn: bool = True,
        update_g2_has_g1g1: bool = True,
        update_g2_has_attn: bool = True,
        update_h2: bool = False,
        attn1_hidden: int = 64,
        attn1_nhead: int = 4,
        attn2_hidden: int = 16,
        attn2_nhead: int = 4,
        attn2_has_gate: bool = False,
        activation_function: str = "tanh",
        update_style: str = "res_avg",
        set_davg_zero: bool = True,  # TODO
        smooth: bool = True,
        add_type_ebd_to_seq: bool = False,
        exclude_types: Optional[List[Tuple[int, int]]] = None,
        env_protection: float = 0.0,
        type: Optional[str] = None,
    ):
        """Create the repformer block.

        Parameters
        ----------
        rcut
            Cut-off radius; passed to ``prod_env_mat`` and every layer.
        rcut_smth
            Smoothing radius; passed to ``prod_env_mat`` and every layer.
        sel
            Number of selected neighbors, as an int or a one-element list.
            ``nnei = sum(sel)``.
        ntypes
            Number of atom types.
        nlayers
            Number of ``RepformerLayer`` modules. The last one is built with
            ``update_chnnl_2=False``.
        g1_dim, g2_dim
            Widths of the per-atom (g1) and per-neighbor (g2) features.
        direct_dist
            If True, g2/h2 are the neighbor distance and displacement scaled by
            ``1 / rcut``. Otherwise they are split from the environment matrix.
        set_davg_zero
            If True, ``compute_input_stats`` leaves the ``mean`` buffer at zero.
        smooth
            If strictly smooth, cannot be used with update_g1_has_attn.
        add_type_ebd_to_seq
            At the presence of seq_input (optional input to forward), whether
            or not to add a type embedding to seq_input. If no seq_input is
            given, it has no effect.
            NOTE(review): the visible ``forward`` takes no ``seq_input``;
            this flag is only stored here. Confirm where it is used.
        exclude_types
            Type pairs passed to ``PairExcludeMask``. ``None`` means no
            exclusions.
        env_protection
            Passed as ``protection`` to ``prod_env_mat``.
        type
            Accepted for interface compatibility and ignored.

        The remaining arguments are forwarded unchanged to each
        ``RepformerLayer``.
        """
        super().__init__()
        del type
        self.epsilon = 1e-4  # protection of 1./nnei
        self.rcut = rcut
        self.rcut_smth = rcut_smth
        self.ntypes = ntypes
        self.nlayers = nlayers
        sel = [sel] if isinstance(sel, int) else sel
        self.nnei = sum(sel)
        self.ndescrpt = self.nnei * 4  # use full descriptor.
        assert len(sel) == 1
        self.sel = sel
        self.sec = self.sel
        self.split_sel = self.sel
        self.axis_dim = axis_dim
        self.set_davg_zero = set_davg_zero
        self.g1_dim = g1_dim
        self.g2_dim = g2_dim
        self.act = ActivationFn(activation_function)
        self.direct_dist = direct_dist
        self.add_type_ebd_to_seq = add_type_ebd_to_seq
        # order matters, placed after the assignment of self.ntypes
        self.reinit_exclude(exclude_types)
        self.env_protection = env_protection
        # Embeds the 1-channel radial input into g2_dim channels.
        self.g2_embd = mylinear(1, self.g2_dim)
        layers = []
        for ii in range(nlayers):
            layers.append(
                RepformerLayer(
                    rcut,
                    rcut_smth,
                    sel,
                    ntypes,
                    self.g1_dim,
                    self.g2_dim,
                    axis_dim=self.axis_dim,
                    # The last layer does not update the second channel.
                    update_chnnl_2=(ii != nlayers - 1),
                    do_bn_mode=do_bn_mode,
                    bn_momentum=bn_momentum,
                    update_g1_has_conv=update_g1_has_conv,
                    update_g1_has_drrd=update_g1_has_drrd,
                    update_g1_has_grrg=update_g1_has_grrg,
                    update_g1_has_attn=update_g1_has_attn,
                    update_g2_has_g1g1=update_g2_has_g1g1,
                    update_g2_has_attn=update_g2_has_attn,
                    update_h2=update_h2,
                    attn1_hidden=attn1_hidden,
                    attn1_nhead=attn1_nhead,
                    attn2_has_gate=attn2_has_gate,
                    attn2_hidden=attn2_hidden,
                    attn2_nhead=attn2_nhead,
                    activation_function=activation_function,
                    update_style=update_style,
                    smooth=smooth,
                )
            )
        self.layers = torch.nn.ModuleList(layers)
        # Normalization statistics consumed by prod_env_mat; the defaults
        # (mean 0, stddev 1) are overwritten by compute_input_stats.
        sshape = (self.ntypes, self.nnei, 4)
        mean = torch.zeros(sshape, dtype=mydtype, device=mydev)
        stddev = torch.ones(sshape, dtype=mydtype, device=mydev)
        self.register_buffer("mean", mean)
        self.register_buffer("stddev", stddev)
        self.stats = None

    def get_rcut(self) -> float:
        """Returns the cut-off radius."""
        return self.rcut

    def get_nsel(self) -> int:
        """Returns the number of selected atoms in the cut-off radius."""
        return sum(self.sel)

    def get_sel(self) -> List[int]:
        """Returns the number of selected atoms for each type."""
        return self.sel

    def get_ntypes(self) -> int:
        """Returns the number of element types."""
        return self.ntypes

    def get_dim_out(self) -> int:
        """Returns the output dimension."""
        return self.dim_out

    def get_dim_in(self) -> int:
        """Returns the input dimension."""
        return self.dim_in

    def get_dim_emb(self) -> int:
        """Returns the embedding dimension g2."""
        return self.g2_dim

    def mixed_types(self) -> bool:
        """If true, the descriptor
        1. assumes total number of atoms aligned across frames;
        2. requires a neighbor list that does not distinguish different atomic types.

        If false, the descriptor
        1. assumes total number of atoms of each atom type aligned across frames;
        2. requires a neighbor list that distinguishes different atomic types.

        """
        return True

    # dim_out / dim_in / dim_emb are read as attributes (get_dim_out,
    # get_dim_in and forward's ``.view(..., self.dim_emb, 3)``), so they must
    # be properties rather than plain methods.
    @property
    def dim_out(self):
        """Returns the output dimension of this descriptor."""
        return self.g1_dim

    @property
    def dim_in(self):
        """Returns the atomic input dimension of this descriptor."""
        return self.g1_dim

    @property
    def dim_emb(self):
        """Returns the embedding dimension g2."""
        return self.get_dim_emb()

    def reinit_exclude(
        self,
        exclude_types: Optional[List[Tuple[int, int]]] = None,
    ):
        """(Re)build the pair-exclusion mask from ``exclude_types``.

        ``None`` (the default) means no excluded pairs. A fresh list is used
        so the default is never shared between instances.

        NOTE(review): ``self.emask`` is not applied in the visible ``forward``.
        Confirm whether exclusion is enforced elsewhere.
        """
        if exclude_types is None:
            exclude_types = []
        self.exclude_types = exclude_types
        self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

    def forward(
        self,
        nlist: torch.Tensor,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        extended_atype_embd: Optional[torch.Tensor] = None,
        mapping: Optional[torch.Tensor] = None,
        comm_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Run the repformer layers over a neighbor list.

        Parameters
        ----------
        nlist
            Neighbor list, nframes x nloc x nnei; ``-1`` marks padding.
            It is not modified in place.
        extended_coord
            Coordinates of the extended (local + ghost) atoms. Flattened per
            frame; ``nall = numel_per_frame // 3``.
        extended_atype
            Types of the extended atoms; the first ``nloc`` columns are the
            local types.
        extended_atype_embd
            Type embedding. Required when ``comm_dict`` is None; it is then
            sliced to the local atoms and must be nframes x nloc x g1_dim.
            Used as-is when ``comm_dict`` is given.
        mapping
            Index of the local atom for every extended atom (nframes x nall).
            Used to gather g1 onto the extended atoms. Required when
            ``comm_dict`` is None.
        comm_dict
            Tensors for ``torch.ops.deepmd.border_op`` (send_list, send_proc,
            recv_proc, send_num, recv_num, communicator). When given, ghost
            features are exchanged through that op instead of ``mapping``.

        Returns
        -------
        g1, g2, h2
            Features returned by the last layer.
        rot_mat
            Viewed as (-1, nloc, dim_emb, 3).
        sw
            Switch values, nb x nloc x nnei, zeroed at padded neighbors.
        """
        if comm_dict is None:
            assert mapping is not None
            assert extended_atype_embd is not None
        nframes, nloc, nnei = nlist.shape
        nall = extended_coord.view(nframes, -1).shape[1] // 3
        atype = extended_atype[:, :nloc]
        # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
        dmatrix, diff, sw = prod_env_mat(
            extended_coord,
            nlist,
            atype,
            self.mean,
            self.stddev,
            self.rcut,
            self.rcut_smth,
            protection=self.env_protection,
        )
        nlist_mask = nlist != -1
        sw = torch.squeeze(sw, -1)
        # beyond the cutoff sw should be 0.0
        sw = sw.masked_fill(~nlist_mask, 0.0)

        # [nframes, nloc, tebd_dim]
        if comm_dict is None:
            assert isinstance(extended_atype_embd, torch.Tensor)  # for jit
            atype_embd = extended_atype_embd[:, :nloc, :]
            assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
        else:
            atype_embd = extended_atype_embd
        assert isinstance(atype_embd, torch.Tensor)  # for jit
        g1 = self.act(atype_embd)
        # nb x nloc x nnei x 1,  nb x nloc x nnei x 3
        if not self.direct_dist:
            g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
        else:
            # Raw distance / displacement, scaled by the cut-off radius.
            g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff
            g2 = g2 / self.rcut
            h2 = h2 / self.rcut
        # nb x nloc x nnei x ng2
        g2 = self.act(self.g2_embd(g2))

        # set all padding positions to index of 0
        # if the a neighbor is real or not is indicated by nlist_mask.
        # Done out-of-place so the caller's nlist keeps its -1 padding.
        nlist = nlist.masked_fill(~nlist_mask, 0)
        # nb x nall x ng1
        if comm_dict is None:
            assert mapping is not None
            mapping = (
                mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
            )
        for ll in self.layers:
            # g1:     nb x nloc x ng1
            # g1_ext: nb x nall x ng1
            if comm_dict is None:
                assert mapping is not None
                g1_ext = torch.gather(g1, 1, mapping)
            else:
                # NOTE(review): squeeze(0)/unsqueeze(0) assume a single frame
                # (nframes == 1) on this path; confirm.
                n_padding = nall - nloc
                g1 = torch.nn.functional.pad(
                    g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
                )
                assert "send_list" in comm_dict
                assert "send_proc" in comm_dict
                assert "recv_proc" in comm_dict
                assert "send_num" in comm_dict
                assert "recv_num" in comm_dict
                assert "communicator" in comm_dict
                ret = torch.ops.deepmd.border_op(
                    comm_dict["send_list"],
                    comm_dict["send_proc"],
                    comm_dict["recv_proc"],
                    comm_dict["send_num"],
                    comm_dict["recv_num"],
                    g1,
                    comm_dict["communicator"],
                    torch.tensor(nloc),
                    torch.tensor(nall - nloc),
                )
                g1_ext = ret[0].unsqueeze(0)
            g1, g2, h2 = ll.forward(
                g1_ext,
                g2,
                h2,
                nlist,
                nlist_mask,
                sw,
            )

        # uses the last layer.
        # NOTE(review): ``ll`` is only bound when nlayers >= 1.
        # nb x nloc x 3 x ng2
        h2g2 = ll._cal_h2g2(g2, h2, nlist_mask, sw)
        # nb x nloc x ng2 x 3
        rot_mat = torch.permute(h2g2, (0, 1, 3, 2))
        return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw

    def compute_input_stats(
        self,
        merged: Union[Callable[[], List[dict]], List[dict]],
        path: Optional[DPPath] = None,
    ):
        """
        Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

        Parameters
        ----------
        merged : Union[Callable[[], List[dict]], List[dict]]
            - List[dict]: A list of data samples from various data systems.
                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
                originating from the `i`-th data system.
            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
                only when needed. Since the sampling process can be slow and memory-intensive,
                the lazy function helps by only sampling once.
        path : Optional[DPPath]
            The path to the stat file.

        """
        env_mat_stat = EnvMatStatSe(self)
        if path is not None:
            path = path / env_mat_stat.get_hash()
        if path is None or not path.is_dir():
            if callable(merged):
                # only get data for once
                sampled = merged()
            else:
                sampled = merged
        else:
            # Stats will be loaded from ``path``; no samples needed.
            sampled = []
        env_mat_stat.load_or_compute_stats(sampled, path)
        self.stats = env_mat_stat.stats
        mean, stddev = env_mat_stat()
        # With set_davg_zero the mean buffer stays zero; stddev is always set.
        if not self.set_davg_zero:
            self.mean.copy_(torch.tensor(mean, device=env.DEVICE))
        self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))

    def get_stats(self) -> Dict[str, StatItem]:
        """Get the statistics of the descriptor."""
        if self.stats is None:
            raise RuntimeError(
                "The statistics of the descriptor has not been computed."
            )
        return self.stats