Source code for deepmd.pt.model.atomic_model.polar_atomic_model
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)
import torch
from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)
from .dp_atomic_model import (
DPAtomicModel,
)
[docs]
class DPPolarAtomicModel(DPAtomicModel):
    """Atomic model whose fitting network must be a polarizability fitting.

    Behaves like :class:`DPAtomicModel`, with two differences:

    * the constructor rejects any fitting that is not a
      :class:`PolarFittingNet`;
    * :meth:`apply_out_stat` applies the output bias only as an isotropic
      (diagonal) shift, and only when ``fitting_net.shift_diag`` is set.

    Parameters
    ----------
    descriptor
        Descriptor, forwarded unchanged to :class:`DPAtomicModel`.
    fitting
        Fitting network; must be an instance of :class:`PolarFittingNet`.
    type_map
        Type map, forwarded unchanged to :class:`DPAtomicModel`.
    **kwargs
        Extra keyword arguments forwarded to :class:`DPAtomicModel`.

    Raises
    ------
    TypeError
        If ``fitting`` is not a :class:`PolarFittingNet`.
    """

    def __init__(
        self, descriptor: Any, fitting: Any, type_map: Any, **kwargs: Any
    ) -> None:
        # Validate before delegating so a wrong fitting fails fast with a clear message.
        if not isinstance(fitting, PolarFittingNet):
            raise TypeError(
                "fitting must be an instance of PolarFittingNet for DPPolarAtomicModel"
            )
        super().__init__(descriptor, fitting, type_map, **kwargs)

    def apply_out_stat(
        self,
        ret: dict[str, torch.Tensor],
        atype: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Apply the stat to each atomic output.

        When ``self.fitting_net.shift_diag`` is true, each bias key's output
        bias is handled in four steps:

        1. Reshape it to ``ntypes x 3 x 3``, i.e. it holds 9 values per type.
        2. Reduce it to the mean of its diagonal.
        3. Multiply by the per-type ``fitting_net.scale``.
        4. Add it to ``ret[key]`` as an isotropic ``c * I`` term per atom.

        When ``shift_diag`` is false, ``ret`` is returned unchanged. The output
        standard deviation is fetched but never applied here.

        Parameters
        ----------
        ret
            The returned dict by the forward_atomic method. Entries for the bias
            keys must broadcast against ``nf x nloc x 3 x 3`` (presumably they
            are exactly that shape -- TODO confirm against PolarFittingNet).
        atype
            The atom types. nf x nloc

        Returns
        -------
        dict[str, torch.Tensor]
            The same ``ret`` dict. When ``shift_diag`` is set, each bias key's
            entry is replaced by the shifted tensor.
        """
        out_bias, _ = self._fetch_out_stat(self.bias_keys)
        if self.fitting_net.shift_diag:
            device = out_bias[self.bias_keys[0]].device
            dtype = out_bias[self.bias_keys[0]].dtype
            # Loop-invariant 3x3 identity. Broadcasting it avoids materializing
            # an nf x nloc x 3 x 3 copy via repeat().
            eye = torch.eye(3, dtype=dtype, device=device)
            # The scale multiplies a tensor living on the bias device, so put it there.
            # Assumes scale is ntypes x 1 -- TODO confirm.
            scale = self.fitting_net.scale.to(device)
            for kk in self.bias_keys:
                ntypes = out_bias[kk].shape[0]
                # (ntypes,): mean of the diagonal of each type's 3x3 bias
                diag_mean = torch.mean(
                    torch.diagonal(
                        out_bias[kk].reshape(ntypes, 3, 3), dim1=-2, dim2=-1
                    ),
                    dim=-1,
                )
                # (nframes, nloc, 1)
                modified_bias = diag_mean[atype].unsqueeze(-1) * scale[atype]
                # (nframes, nloc, 1, 1) * (3, 3) -> (nframes, nloc, 3, 3)
                ret[kk] = ret[kk] + modified_bias.unsqueeze(-1) * eye
        return ret