# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (
List,
Optional,
Union,
)
import torch
from deepmd.dpmodel import (
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.model.task.fitting import (
GeneralFitting,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.version import (
check_version_compatibility,
)
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
@GeneralFitting.register("polar")
class PolarFittingNet(GeneralFitting):
    """Construct a polar fitting net.

    The raw network output is scaled per atom type and then contracted with
    the rotation matrix passed to :meth:`forward`, yielding one 3x3
    polarizability matrix per atom.

    Parameters
    ----------
    ntypes : int
        Element count.
    dim_descrpt : int
        Embedding width per atom.
    embedding_width : int
        The dimension of rotation matrix, m1.
    neuron : List[int], optional
        Number of neurons in each hidden layers of the fitting net.
        Defaults to ``[128, 128, 128]``.
    resnet_dt : bool
        Using time-step in the ResNet construction.
    numb_fparam : int
        Number of frame parameters.
    numb_aparam : int
        Number of atomic parameters.
    activation_function : str
        Activation function.
    precision : str
        Numerical precision.
    mixed_types : bool
        If true, use a uniform fitting net for all atom types, otherwise use
        different fitting nets for different atom types.
    rcond : float, optional
        The condition number for the regression of atomic energy.
    seed : int, optional
        Random seed.
    exclude_types : List[int], optional
        Forwarded to the base fitting as ``exclude_types``. Defaults to an
        empty list.
    fit_diag : bool
        Fit the diagonal part of the rotational invariant polarizability
        matrix, which will be converted to normal polarizability matrix by
        contracting with the rotation matrix.
    scale : List[float] or float, optional
        The output of the fitting net (polarizability matrix) for type i atom
        will be scaled by scale[i]. A single number is applied to every type;
        ``None`` means a scale of 1.0 for every type.
    shift_diag : bool
        Whether to shift the diagonal part of the polarizability matrix. The
        shift operation is carried out after scale. This class only stores and
        serializes the flag; its own ``forward`` does not apply the shift.
    **kwargs
        Forwarded unchanged to the base-class initializer.

    Raises
    ------
    AssertionError
        If ``scale`` is a list whose length differs from ``ntypes``.
    ValueError
        If ``scale`` is neither ``None``, a list, nor a real number.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        embedding_width: int,
        neuron: Optional[List[int]] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = True,
        rcond: Optional[float] = None,
        seed: Optional[int] = None,
        exclude_types: Optional[List[int]] = None,
        fit_diag: bool = True,
        scale: Optional[Union[List[float], float]] = None,
        shift_diag: bool = True,
        **kwargs,
    ):
        # ``None`` sentinels replace the former mutable list defaults so that
        # instances never share (and cannot accidentally mutate) one list.
        if neuron is None:
            neuron = [128, 128, 128]
        if exclude_types is None:
            exclude_types = []

        # NOTE(review): these attributes are assigned before
        # ``super().__init__`` — presumably because the base initializer
        # consults ``_net_out_dim()``, which reads them. TODO confirm.
        self.embedding_width = embedding_width
        self.fit_diag = fit_diag

        # Normalize ``scale`` to one value per atom type.
        if scale is None:
            scale_per_type = [1.0] * ntypes
        elif isinstance(scale, list):
            assert (
                len(scale) == ntypes
            ), "Scale should be a list of length ntypes."
            scale_per_type = scale
        elif isinstance(scale, (int, float)) and not isinstance(scale, bool):
            # Integers (e.g. ``1`` read from a JSON input) are valid scalars.
            scale_per_type = [float(scale)] * ntypes
        else:
            raise ValueError(
                "Scale must be a list of float of length ntypes or a float."
            )
        # Stored as an (ntypes, 1) column so that indexing with ``atype``
        # broadcasts over the network-output axis in ``forward``.
        self.scale = torch.tensor(
            scale_per_type, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
        ).view(ntypes, 1)

        self.shift_diag = shift_diag
        # One diagonal-shift constant per atom type, initialized to zero.
        self.constant_matrix = torch.zeros(
            ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
        )

        super().__init__(
            var_name="polar",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            neuron=neuron,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            rcond=rcond,
            seed=seed,
            exclude_types=exclude_types,
            **kwargs,
        )
        self.old_impl = False  # this only supports the new implementation.

    def _net_out_dim(self):
        """Return the output width of the fitting network.

        ``embedding_width`` when only the diagonal is fitted, otherwise
        ``embedding_width * embedding_width`` for the full matrix.
        """
        return (
            self.embedding_width
            if self.fit_diag
            else self.embedding_width * self.embedding_width
        )

    def __setitem__(self, key, value):
        """Set ``constant_matrix`` here; delegate any other key to the base class."""
        if key in ["constant_matrix"]:
            self.constant_matrix = value
        else:
            super().__setitem__(key, value)

    def __getitem__(self, key):
        """Return ``constant_matrix`` here; delegate any other key to the base class."""
        if key in ["constant_matrix"]:
            return self.constant_matrix
        else:
            return super().__getitem__(key)

    def serialize(self) -> dict:
        """Serialize the fitting net to a dict.

        Extends the base-class payload with the ``"polar"`` type tag, format
        version 2, the polar-specific options, and the ``scale`` and
        ``constant_matrix`` tensors (as numpy arrays under ``"@variables"``).
        """
        data = super().serialize()
        data["type"] = "polar"
        data["@version"] = 2
        data["embedding_width"] = self.embedding_width
        data["old_impl"] = self.old_impl
        data["fit_diag"] = self.fit_diag
        data["shift_diag"] = self.shift_diag
        data["@variables"]["scale"] = to_numpy_array(self.scale)
        data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix)
        return data

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        """Rebuild a fitting net from a dict produced by :meth:`serialize`.

        Works on a deep copy, so the caller's dict is left untouched. The
        ``"@version"`` entry (taken as 1 when absent) is passed to
        ``check_version_compatibility`` with arguments ``2, 1``, and any
        ``"var_name"`` entry is dropped before delegating to the base class.
        """
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 2, 1)
        data.pop("var_name", None)
        return super().deserialize(data)

    def output_def(self) -> FittingOutputDef:
        """Declare the single output: a 3x3 ``"polarizability"`` variable."""
        return FittingOutputDef(
            [
                OutputVariableDef(
                    "polarizability",
                    [3, 3],
                    reduciable=True,
                    r_differentiable=False,
                    c_differentiable=False,
                ),
            ]
        )

    def forward(
        self,
        descriptor: torch.Tensor,
        atype: torch.Tensor,
        gr: Optional[torch.Tensor] = None,
        g2: Optional[torch.Tensor] = None,
        h2: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
    ):
        """Compute the per-atom polarizability.

        Parameters
        ----------
        descriptor : torch.Tensor
            Rank-3 tensor whose first two axes are (nframes, nloc).
        atype : torch.Tensor
            Atom types used to index the per-type scale; presumably of shape
            (nframes, nloc) — TODO confirm against the caller.
        gr : torch.Tensor
            Rotation matrix, required here. Must be viewable as
            (nframes * nloc, m1, 3).
        g2, h2, fparam, aparam : torch.Tensor, optional
            Passed through unchanged to ``_forward_common``.

        Returns
        -------
        dict
            ``{"polarizability": tensor}`` with shape (nframes, nloc, 3, 3),
            cast to ``env.GLOBAL_PT_FLOAT_PRECISION``.
        """
        nframes, nloc, _ = descriptor.shape
        assert (
            gr is not None
        ), "Must provide the rotation matrix for polarizability fitting."
        # (nframes, nloc, _net_out_dim)
        out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
            self.var_name
        ]
        # Per-type scale; the trailing size-1 axis broadcasts over the output.
        out = out * self.scale[atype]
        gr = gr.view(nframes * nloc, -1, 3)  # (nframes * nloc, m1, 3)

        if self.fit_diag:
            out = out.reshape(-1, self.embedding_width)
            # Row-wise scaling of gr, i.e. diag(out) @ gr: (nframes * nloc, m1, 3)
            out = torch.einsum("ij,ijk->ijk", out, gr)
        else:
            out = out.reshape(-1, self.embedding_width, self.embedding_width)
            # Symmetrize the fitted m1 x m1 matrix.
            out = (out + out.transpose(1, 2)) / 2
            out = torch.einsum("bim,bmj->bij", out, gr)  # (nframes * nloc, m1, 3)
        # Left-multiply by gr^T to finish the contraction.
        out = torch.einsum(
            "bim,bmj->bij", gr.transpose(1, 2), out
        )  # (nframes * nloc, 3, 3)
        out = out.view(nframes, nloc, 3, 3)
        return {"polarizability": out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

    # make jit happy with torch 2.0.0
    exclude_types: List[int]