# SPDX-License-Identifier: LGPL-3.0-or-later
import warnings
from typing import (
Any,
)
import torch
from deepmd.dpmodel.common import (
cast_precision,
)
from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP
from deepmd.dpmodel.descriptor.dpa2 import (
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.dpmodel.utils.env_mat_stat import (
merge_env_stat,
)
from deepmd.pt_expt.common import (
torch_module,
)
from deepmd.pt_expt.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt_expt.utils.update_sel import (
UpdateSel,
)
@BaseDescriptor.register("dpa2")
@torch_module
[docs]
class DescrptDPA2(DescrptDPA2DP):
[docs]
_update_sel_cls = UpdateSel
[docs]
def share_params(
self,
base_class: "DescrptDPA2",
shared_level: int,
model_prob: float = 1.0,
resume: bool = False,
) -> None:
"""Share parameters with base_class for multi-task training.
Level 0: share type_embedding, repinit, repinit_three_body,
g1_shape_tranform, and repformers.
Level 1: share type_embedding only.
"""
assert self.__class__ == base_class.__class__, (
"Only descriptors of the same type can share params!"
)
if shared_level == 0:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
if not resume:
merge_env_stat(base_class.repinit, self.repinit, model_prob)
if self.use_three_body and "repinit_three_body" in base_class._modules:
merge_env_stat(
base_class.repinit_three_body,
self.repinit_three_body,
model_prob,
)
merge_env_stat(base_class.repformers, self.repformers, model_prob)
self._modules["repinit"] = base_class._modules["repinit"]
if self.use_three_body and "repinit_three_body" in base_class._modules:
self._modules["repinit_three_body"] = base_class._modules[
"repinit_three_body"
]
self._modules["g1_shape_tranform"] = base_class._modules[
"g1_shape_tranform"
]
self._modules["repformers"] = base_class._modules["repformers"]
if "tebd_transform" in base_class._modules:
self._modules["tebd_transform"] = base_class._modules["tebd_transform"]
elif shared_level == 1:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
else:
raise NotImplementedError
[docs]
def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Enable compression for the DPA2 descriptor.
Compression applies to the repinit block (DescrptBlockSeAtten).
When attn_layer == 0, the geometric embedding is tabulated.
Type embedding outputs are always precomputed.
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
from deepmd.pt_expt.utils.tabulate import (
DPTabulate,
)
if self.compress:
raise ValueError("Compression is already enabled.")
if self.repinit.resnet_dt:
raise RuntimeError(
"Model compression error: repinit resnet_dt must be false!"
)
for tt in self.repinit.exclude_types:
if (tt[0] not in range(self.repinit.ntypes)) or (
tt[1] not in range(self.repinit.ntypes)
):
raise RuntimeError(
"Repinit exclude types"
+ str(tt)
+ " must within the number of atomic types "
+ str(self.repinit.ntypes)
+ "!"
)
if (
self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types)
== 0
):
raise RuntimeError(
"Repinit empty embedding-nets are not supported in model compression!"
)
if self.repinit.tebd_input_mode != "strip":
raise RuntimeError(
"Cannot compress model when repinit tebd_input_mode != 'strip'"
)
# Precompute type embedding data for repinit
self._store_type_embd_data()
if self.repinit.attn_layer == 0:
# Build geometric embedding table
repinit_data = self.repinit.serialize()
self.table = DPTabulate(
self.repinit,
repinit_data["neuron"],
repinit_data.get("type_one_side", False),
repinit_data.get("exclude_types", []),
repinit_data["activation_function"],
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)
self._store_compress_data()
self.geo_compress = True
else:
warnings.warn(
"Attention layer is not 0, only type embedding is compressed. "
"Geometric part is not compressed.",
UserWarning,
stacklevel=2,
)
self.geo_compress = False
self.compress = True
[docs]
def _store_compress_data(self) -> None:
"""Store tabulated data as buffers for the compressed geometric embedding."""
table_data = self.table.data
table_config = self.table_config
lower = self.lower
upper = self.upper
prec = self.repinit.mean.dtype
net_key = "filter_net"
info = torch.as_tensor(
[
lower[net_key],
upper[net_key],
upper[net_key] * table_config[0],
table_config[1],
table_config[2],
table_config[3],
],
dtype=prec,
device="cpu",
)
tensor_data = table_data[net_key].to(dtype=prec)
self.compress_data = torch.nn.ParameterList(
[torch.nn.Parameter(tensor_data, requires_grad=False)]
)
self.compress_info = torch.nn.ParameterList(
[torch.nn.Parameter(info, requires_grad=False)]
)
[docs]
def _store_type_embd_data(self) -> None:
"""Precompute type embedding outputs for repinit and store as a buffer."""
with torch.no_grad():
# type_embedding.call() returns (ntypes+1) x tebd_dim (with padding)
full_embd = self.type_embedding.call()
nt, t_dim = full_embd.shape
if self.repinit.type_one_side:
# One-side: only neighbor types
# (ntypes+1) x tebd_dim -> (ntypes+1) x ng
embd_tensor = self.repinit.embeddings_strip[0].call(full_embd).detach()
else:
# Two-side: all (ntypes+1)^2 type pair combinations
# Build [neighbor, center] combinations
embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(
-1, t_dim * 2
)
# ((ntypes+1)^2) x ng
embd_tensor = (
self.repinit.embeddings_strip[0].call(two_side_embd).detach()
)
torch.nn.Module.register_buffer(self, "type_embd_data", embd_tensor)
@cast_precision
[docs]
def call(
self,
coord_ext: torch.Tensor,
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
fparam: torch.Tensor | None = None,
comm_dict: dict | None = None,
charge_spin: torch.Tensor | None = None,
) -> Any:
if not self.compress:
return DescrptDPA2DP.call.__wrapped__(
self,
coord_ext,
atype_ext,
nlist,
mapping,
fparam,
comm_dict=comm_dict,
)
# Compressed path is local-only (no message passing during compress).
return self._call_compressed(coord_ext, atype_ext, nlist, mapping)
[docs]
def _call_compressed(
self,
coord_ext: torch.Tensor,
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
) -> Any:
"""Compressed forward for DPA2 descriptor.
The repinit forward is done inline with compressed ops,
then the rest (g1_shape_transform, repformers, etc.) proceeds normally.
"""
use_three_body = self.use_three_body
nframes, nloc, _nnei = nlist.shape
nall = coord_ext.view(nframes, -1).shape[1] // 3
# Build multiple neighbor lists
nlist_dict = build_multiple_neighbor_list(
coord_ext,
nlist,
self.rcut_list,
self.nsel_list,
)
# Type embedding
type_embedding = self.type_embedding.call()
g1_ext = type_embedding[atype_ext.reshape(-1).to(torch.long)].reshape(
nframes, nall, self.tebd_dim
)
g1_inp = g1_ext[:, :nloc, :]
# Compressed repinit forward
nlist_repinit = nlist_dict[
get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel())
]
g1 = self._compressed_repinit_forward(
coord_ext, atype_ext, nlist_repinit, nframes, nloc, nall, type_embedding
)
# Three-body (not compressed, call normally)
if use_three_body:
assert self.repinit_three_body is not None
g1_three_body, __, __, __, __ = self.repinit_three_body(
nlist_dict[
get_multiple_nlist_key(
self.repinit_three_body.get_rcut(),
self.repinit_three_body.get_nsel(),
)
],
coord_ext,
atype_ext,
g1_ext,
mapping,
type_embedding=type_embedding,
)
g1 = torch.cat([g1, g1_three_body], dim=-1)
# Linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
assert self.tebd_transform is not None
g1 = g1 + self.tebd_transform(g1_inp)
# Mapping g1 to extended region for repformers
assert mapping is not None
mapping_ext = mapping.view(nframes, nall, 1).expand(-1, -1, g1.shape[-1])
g1_ext = torch.gather(g1, 1, mapping_ext)
# Repformers (not compressed)
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
get_multiple_nlist_key(
self.repformers.get_rcut(), self.repformers.get_nsel()
)
],
coord_ext,
atype_ext,
g1_ext,
mapping,
)
# Concat type embedding at output if needed
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, rot_mat, g2, h2, sw
[docs]
def _compressed_repinit_forward(
self,
coord_ext: torch.Tensor,
atype_ext: torch.Tensor,
nlist: torch.Tensor,
nframes: int,
nloc: int,
nall: int,
type_embedding: torch.Tensor,
) -> torch.Tensor:
"""Compressed forward for the repinit block.
Same logic as DPA1's _call_compressed but only produces the g1 output
(nf x nloc x (ng x axis_neuron)), without rot_mat/sw returns.
Parameters
----------
coord_ext
Extended coordinates. shape: nf x (nall x 3)
atype_ext
Extended atom types. shape: nf x nall
nlist
Neighbor list for repinit. shape: nf x nloc x nnei_repinit
nframes
Number of frames.
nloc
Number of local atoms.
nall
Number of all atoms (local + ghost).
type_embedding
Full type embedding. shape: (ntypes+1) x tebd_dim
Returns
-------
torch.Tensor
Repinit output. shape: nf x nloc x (ng x axis_neuron)
"""
# env_mat: nf x nloc x nnei x 4
rr, _diff, sw = self.repinit.env_mat.call(
coord_ext,
atype_ext,
nlist,
self.repinit.mean[...],
self.repinit.stddev[...],
)
nf, nloc_r, nnei, _ = rr.shape
ng = self.repinit.neuron[-1]
nfnl = nf * nloc_r
# Exclude mask and nlist processing
exclude_mask = self.repinit.emask.build_type_exclude_mask(nlist, atype_ext)
exclude_mask = exclude_mask.view(nfnl, nnei)
nlist = nlist.view(nfnl, nnei)
exclude_mask = exclude_mask.to(torch.bool)
nlist = torch.where(exclude_mask, nlist, torch.full_like(nlist, -1))
nlist_mask = nlist != -1
nlist_masked = torch.where(nlist_mask, nlist, torch.zeros_like(nlist))
# nfnl x nnei x 1
sw = torch.where(
nlist_mask[:, :, None],
sw.view(nfnl, nnei, 1),
torch.zeros(nfnl, nnei, 1, dtype=sw.dtype, device=sw.device),
)
# nfnl x nnei x 4
rr = rr.view(nfnl, nnei, 4)
rr = rr * exclude_mask[:, :, None].to(rr.dtype)
# nfnl x nnei x 1
ss = rr[:, :, :1]
# Type embedding lookup from precomputed buffer
ntypes_with_padding = type_embedding.shape[0]
# nf x (nloc x nnei)
nlist_index = nlist_masked.view(nf, nloc_r * nnei)
# nf x (nloc x nnei)
nei_type = torch.gather(atype_ext, dim=1, index=nlist_index)
if self.repinit.type_one_side:
# (nf*nl*nnei,) -> (nf*nl*nnei, ng)
gg_t = self.type_embd_data[nei_type.view(-1).to(torch.long)]
else:
atype = atype_ext[:, :nloc_r]
idx_i = torch.tile(
atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei]
).view(-1)
idx_j = nei_type.view(-1)
idx = (idx_i + idx_j).to(torch.long)
# (nf x nl x nnei) x ng
gg_t = self.type_embd_data[idx]
# (nf x nl) x nnei x ng
gg_t = gg_t.view(nfnl, nnei, ng)
if self.repinit.smooth:
gg_t = gg_t * sw.view(nfnl, self.repinit.nnei, 1)
if self.geo_compress:
# Flatten for tabulate op
ss_flat = ss.reshape(-1, 1)
gg_t_flat = gg_t.reshape(-1, gg_t.size(-1))
is_sorted = len(self.repinit.exclude_types) == 0
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
self.compress_data[0].contiguous(),
self.compress_info[0].cpu().contiguous(),
ss_flat.contiguous(),
rr.contiguous(),
gg_t_flat.contiguous(),
self.repinit.neuron[-1],
is_sorted,
)[0]
else:
# No geometric compression, run embedding net + attention
# nfnl x nnei x ng
gg_s = self.repinit.embeddings[0].call(ss)
# nfnl x nnei x ng
gg = gg_s * gg_t + gg_s
input_r = torch.nn.functional.normalize(
rr.view(-1, self.repinit.nnei, 4)[:, :, 1:4], dim=-1
)
gg = self.repinit.dpa1_attention(gg, nlist_mask, input_r=input_r, sw=sw)
# nfnl x 4 x ng
xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg)
xyz_scatter = xyz_scatter / self.repinit.nnei
# nfnl x ng x 4
xyz_scatter_1 = xyz_scatter.permute(0, 2, 1)
# nfnl x 4 x axis_neuron
xyz_scatter_2 = xyz_scatter[:, :, 0 : self.repinit.axis_neuron]
# nfnl x ng x axis_neuron
result = torch.matmul(xyz_scatter_1, xyz_scatter_2)
# nf x nloc x (ng x axis_neuron)
result = result.view(nf, nloc_r, ng * self.repinit.axis_neuron)
return result