Source code for deepmd.pt.model.descriptor.repformers
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Union,
)
import torch
from deepmd.pt.model.descriptor.descriptor import (
DescriptorBlock,
)
from deepmd.pt.model.descriptor.env_mat import (
prod_env_mat,
)
from deepmd.pt.model.network.network import (
SimpleLinear,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from .repformer_layer import (
RepformerLayer,
)
# module-level dtype/device defaults used by the helpers below
mydtype = env.GLOBAL_PT_FLOAT_PRECISION
mydev = env.DEVICE


def torch_linear(*args, **kwargs):
    return torch.nn.Linear(*args, **kwargs, dtype=mydtype, device=mydev)


simple_linear = SimpleLinear
mylinear = simple_linear

if not hasattr(torch.ops.deepmd, "border_op"):
def border_op(
argument0,
argument1,
argument2,
argument3,
argument4,
argument5,
argument6,
argument7,
argument8,
) -> torch.Tensor:
raise NotImplementedError(
"border_op is not available since customized PyTorch OP library is not built when freezing the model."
)
    # Note: this hack cannot actually save a model that can be run using LAMMPS.
torch.ops.deepmd.border_op = border_op
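
# A minimal illustration of what the stub above buys us (the model and path
# names are hypothetical): scripting and saving still succeed without the
# custom OP library, but actually invoking the op raises NotImplementedError.
#
#     scripted = torch.jit.script(model)       # hypothetical nn.Module
#     scripted.save("frozen_model.pth")        # freezing succeeds
#     torch.ops.deepmd.border_op(*args9)       # raises NotImplementedError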
@DescriptorBlock.register("se_repformer")
@DescriptorBlock.register("se_uni")
class DescrptBlockRepformers(DescriptorBlock):
def __init__(
self,
rcut,
rcut_smth,
sel: int,
ntypes: int,
nlayers: int = 3,
g1_dim=128,
g2_dim=16,
axis_dim: int = 4,
direct_dist: bool = False,
do_bn_mode: str = "no",
bn_momentum: float = 0.1,
update_g1_has_conv: bool = True,
update_g1_has_drrd: bool = True,
update_g1_has_grrg: bool = True,
update_g1_has_attn: bool = True,
update_g2_has_g1g1: bool = True,
update_g2_has_attn: bool = True,
update_h2: bool = False,
attn1_hidden: int = 64,
attn1_nhead: int = 4,
attn2_hidden: int = 16,
attn2_nhead: int = 4,
attn2_has_gate: bool = False,
activation_function: str = "tanh",
update_style: str = "res_avg",
set_davg_zero: bool = True, # TODO
smooth: bool = True,
add_type_ebd_to_seq: bool = False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
type: Optional[str] = None,
):
"""
smooth:
If strictly smooth, cannot be used with update_g1_has_attn
add_type_ebd_to_seq:
At the presence of seq_input (optional input to forward),
whether or not add an type embedding to seq_input.
If no seq_input is given, it has no effect.
"""
super().__init__()
del type
self.epsilon = 1e-4 # protection of 1./nnei
self.rcut = rcut
self.rcut_smth = rcut_smth
self.ntypes = ntypes
self.nlayers = nlayers
sel = [sel] if isinstance(sel, int) else sel
self.nnei = sum(sel)
self.ndescrpt = self.nnei * 4 # use full descriptor.
assert len(sel) == 1
self.sel = sel
self.sec = self.sel
self.split_sel = self.sel
self.axis_dim = axis_dim
self.set_davg_zero = set_davg_zero
self.g1_dim = g1_dim
self.g2_dim = g2_dim
self.act = ActivationFn(activation_function)
self.direct_dist = direct_dist
self.add_type_ebd_to_seq = add_type_ebd_to_seq
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
self.g2_embd = mylinear(1, self.g2_dim)
layers = []
for ii in range(nlayers):
layers.append(
RepformerLayer(
rcut,
rcut_smth,
sel,
ntypes,
self.g1_dim,
self.g2_dim,
axis_dim=self.axis_dim,
update_chnnl_2=(ii != nlayers - 1),
do_bn_mode=do_bn_mode,
bn_momentum=bn_momentum,
update_g1_has_conv=update_g1_has_conv,
update_g1_has_drrd=update_g1_has_drrd,
update_g1_has_grrg=update_g1_has_grrg,
update_g1_has_attn=update_g1_has_attn,
update_g2_has_g1g1=update_g2_has_g1g1,
update_g2_has_attn=update_g2_has_attn,
update_h2=update_h2,
attn1_hidden=attn1_hidden,
attn1_nhead=attn1_nhead,
attn2_has_gate=attn2_has_gate,
attn2_hidden=attn2_hidden,
attn2_nhead=attn2_nhead,
activation_function=activation_function,
update_style=update_style,
smooth=smooth,
)
)
self.layers = torch.nn.ModuleList(layers)
sshape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(sshape, dtype=mydtype, device=mydev)
stddev = torch.ones(sshape, dtype=mydtype, device=mydev)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.stats = None
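
    # A minimal construction sketch (the hyperparameters below are
    # illustrative, not recommended defaults):
    #
    #     block = DescrptBlockRepformers(
    #         rcut=6.0, rcut_smth=0.5, sel=120, ntypes=2,
    #         nlayers=3, g1_dim=128, g2_dim=16, axis_dim=4,
    #     )
    #     # the "mean"/"stddev" buffers start as zeros/ones of shape
    #     # (ntypes, nnei, 4) until compute_input_stats() is called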
def get_nsel(self) -> int:
"""Returns the number of selected atoms in the cut-off radius."""
return sum(self.sel)
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
return self.sel
def mixed_types(self) -> bool:
"""If true, the discriptor
1. assumes total number of atoms aligned across frames;
2. requires a neighbor list that does not distinguish different atomic types.
If false, the discriptor
1. assumes total number of atoms of each atom type aligned across frames;
2. requires a neighbor list that distinguishes different atomic types.
"""
return True
    @property
    def dim_out(self):
        """Returns the output dimension of this descriptor."""
        return self.g1_dim

    @property
    def dim_in(self):
        """Returns the atomic input dimension of this descriptor."""
        return self.g1_dim

    @property
    def dim_emb(self):
        """Returns the embedding dimension g2."""
        return self.g2_dim

def reinit_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
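
    # Illustrative use: exclude_types=[(0, 1)] masks neighbor pairs formed
    # between atom types 0 and 1 (PairExcludeMask treats the pair as
    # unordered, so (1, 0) is excluded as well):
    #
    #     block.reinit_exclude(exclude_types=[(0, 1)])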
def forward(
self,
nlist: torch.Tensor,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
if comm_dict is None:
assert mapping is not None
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
# nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
dmatrix, diff, sw = prod_env_mat(
extended_coord,
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
protection=self.env_protection,
)
nlist_mask = nlist != -1
sw = torch.squeeze(sw, -1)
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)
# [nframes, nloc, tebd_dim]
if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
else:
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
g1 = self.act(atype_embd)
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
if not self.direct_dist:
g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
else:
g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff
g2 = g2 / self.rcut
h2 = h2 / self.rcut
# nb x nloc x nnei x ng2
g2 = self.act(self.g2_embd(g2))
        # set all padding positions to index 0
        # whether a neighbor is real is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x ng1
if comm_dict is None:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
)
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
if comm_dict is None:
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
assert "send_list" in comm_dict
assert "send_proc" in comm_dict
assert "recv_proc" in comm_dict
assert "send_num" in comm_dict
assert "recv_num" in comm_dict
assert "communicator" in comm_dict
ret = torch.ops.deepmd.border_op(
comm_dict["send_list"],
comm_dict["send_proc"],
comm_dict["recv_proc"],
comm_dict["send_num"],
comm_dict["recv_num"],
g1,
comm_dict["communicator"],
torch.tensor(nloc),
torch.tensor(nall - nloc),
)
g1_ext = ret[0].unsqueeze(0)
g1, g2, h2 = ll.forward(
g1_ext,
g2,
h2,
nlist,
nlist_mask,
sw,
)
# uses the last layer.
# nb x nloc x 3 x ng2
h2g2 = ll._cal_h2g2(g2, h2, nlist_mask, sw)
# (nb x nloc) x ng2 x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))
return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw
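
    # Shape sketch for a single-process call (comm_dict is None); the tensors
    # below are placeholders with no physical meaning:
    #
    #     nf, nloc, nall = 1, 8, 20
    #     nlist = torch.randint(0, nall, (nf, nloc, block.nnei))
    #     coord = torch.rand(nf, nall, 3, dtype=mydtype)
    #     atype = torch.zeros(nf, nall, dtype=torch.long)
    #     tebd = torch.rand(nf, nall, block.g1_dim, dtype=mydtype)
    #     mapping = torch.zeros(nf, nall, dtype=torch.long)
    #     g1, g2, h2, rot_mat, sw = block(nlist, coord, atype, tebd, mapping)
    #     # g1: nf x nloc x g1_dim;  rot_mat: nf x nloc x g2_dim x 3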
def compute_input_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
path: Optional[DPPath] = None,
):
"""
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
path : Optional[DPPath]
The path to the stat file.
"""
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
if path is None or not path.is_dir():
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
else:
sampled = []
env_mat_stat.load_or_compute_stats(sampled, path)
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
self.mean.copy_(torch.tensor(mean, device=env.DEVICE))
self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
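
    # Both accepted forms of `merged` (the sampling function is hypothetical):
    #
    #     block.compute_input_stats(sampled_list)         # eager: List[dict]
    #     block.compute_input_stats(lambda: sample_fn())  # lazy: called once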
def get_stats(self) -> Dict[str, StatItem]:
"""Get the statistics of the descriptor."""
if self.stats is None:
raise RuntimeError(
"The statistics of the descriptor has not been computed."
)
return self.stats