# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (
    Dict,
    List,
    Optional,
)
import numpy as np
import torch
from deepmd.dpmodel import (
FittingOutputDef,
OutputVariableDef,
fitting_check_output,
)
from deepmd.pt.model.network.network import (
ResidualDeep,
)
from deepmd.pt.model.task.fitting import (
Fitting,
GeneralFitting,
)
from deepmd.pt.model.task.invar_fitting import (
InvarFitting,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)
from deepmd.utils.version import (
check_version_compatibility,
)
dtype = env.GLOBAL_PT_FLOAT_PRECISION
log = logging.getLogger(__name__)
@Fitting.register("ener")
class EnergyFittingNet(InvarFitting):
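    """Fitting net for the scalar atomic energy.

    A thin wrapper around ``InvarFitting`` with ``var_name="energy"`` and
    ``dim_out=1``; the per-frame energy is obtained by reducing the atomic
    contributions over atoms.
    """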
def __init__(
self,
ntypes: int,
dim_descrpt: int,
neuron: List[int] = [128, 128, 128],
bias_atom_e: Optional[torch.Tensor] = None,
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = True,
**kwargs,
):
super().__init__(
"energy",
ntypes,
dim_descrpt,
1,
neuron=neuron,
bias_atom_e=bias_atom_e,
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
activation_function=activation_function,
precision=precision,
mixed_types=mixed_types,
**kwargs,
)
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
**super().serialize(),
"type": "ener",
}
# make jit happy with torch 2.0.0
exclude_types: List[int]
@Fitting.register("direct_force")
@Fitting.register("direct_force_ener")
@fitting_check_output
class EnergyFittingNetDirect(Fitting):
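    """Fitting net that predicts an atomic energy together with a "direct",
    non-conservative force per atom.

    The force output ("dforce") comes from a separate dipole-style branch
    contracted with the equivariant descriptor output ``gr``, rather than from
    the gradient of the energy; both outputs are therefore declared
    non-differentiable in ``output_def``.
    """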
def __init__(
self,
ntypes,
dim_descrpt,
neuron,
bias_atom_e=None,
out_dim=1,
resnet_dt=True,
use_tebd=True,
return_energy=False,
**kwargs,
):
"""Construct a fitting net for energy.
Args:
- ntypes: Element count.
- embedding_width: Embedding width per atom.
- neuron: Number of neurons in each hidden layers of the fitting net.
- bias_atom_e: Average enery per atom for each element.
- resnet_dt: Using time-step in the ResNet construction.
"""
super().__init__()
self.ntypes = ntypes
self.dim_descrpt = dim_descrpt
self.use_tebd = use_tebd
self.out_dim = out_dim
if bias_atom_e is None:
bias_atom_e = np.zeros([self.ntypes])
if not use_tebd:
assert self.ntypes == len(bias_atom_e), "Element count mismatches!"
bias_atom_e = torch.tensor(bias_atom_e, device=env.DEVICE)
self.register_buffer("bias_atom_e", bias_atom_e)
filter_layers_dipole = []
for type_i in range(self.ntypes):
one = ResidualDeep(
type_i,
dim_descrpt,
neuron,
0.0,
out_dim=out_dim,
resnet_dt=resnet_dt,
)
filter_layers_dipole.append(one)
self.filter_layers_dipole = torch.nn.ModuleList(filter_layers_dipole)
self.return_energy = return_energy
filter_layers = []
if self.return_energy:
for type_i in range(self.ntypes):
bias_type = 0.0 if self.use_tebd else bias_atom_e[type_i]
one = ResidualDeep(
type_i, dim_descrpt, neuron, bias_type, resnet_dt=resnet_dt
)
filter_layers.append(one)
self.filter_layers = torch.nn.ModuleList(filter_layers)
if "seed" in kwargs:
torch.manual_seed(kwargs["seed"])
def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
"energy",
[1],
reduciable=True,
r_differentiable=False,
c_differentiable=False,
),
OutputVariableDef(
"dforce",
[3],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
),
]
)
def serialize(self) -> dict:
raise NotImplementedError
def deserialize(cls) -> "EnergyFittingNetDirect":
raise NotImplementedError
def forward(
self,
inputs: torch.Tensor,
atype: torch.Tensor,
gr: Optional[torch.Tensor] = None,
g2: Optional[torch.Tensor] = None,
h2: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, None]:
"""Based on embedding net output, alculate total energy.
Args:
- inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt].
- natoms: Tell atom count and element count. Its shape is [2+self.ntypes].
Returns
-------
- `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
"""
nframes, nloc, _ = inputs.size()
if self.use_tebd:
# if atype_tebd is not None:
# inputs = torch.concat([inputs, atype_tebd], dim=-1)
            vec_out = self.filter_layers_dipole[0](
                inputs
            )  # Shape is [nframes, nloc, out_dim]
assert list(vec_out.size()) == [nframes, nloc, self.out_dim]
# (nf x nloc) x 1 x od
vec_out = vec_out.view(-1, 1, self.out_dim)
assert gr is not None
# (nf x nloc) x od x 3
gr = gr.view(-1, self.out_dim, 3)
vec_out = (
torch.bmm(vec_out, gr).squeeze(-2).view(nframes, nloc, 3)
) # Shape is [nframes, nloc, 3]
else:
vec_out = torch.zeros_like(atype).unsqueeze(-1) # jit assertion
for type_i, filter_layer in enumerate(self.filter_layers_dipole):
mask = atype == type_i
                vec_out_type = filter_layer(inputs)  # Shape is [nframes, nloc, out_dim]
                vec_out_type = vec_out_type * mask.unsqueeze(-1)
                vec_out = vec_out + vec_out_type  # Shape is [nframes, nloc, out_dim]
outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion
if self.return_energy:
if self.use_tebd:
atom_energy = self.filter_layers[0](inputs) + self.bias_atom_e[
atype
].unsqueeze(-1)
                outs = outs + atom_energy  # Shape is [nframes, nloc, 1]
else:
for type_i, filter_layer in enumerate(self.filter_layers):
mask = atype == type_i
atom_energy = filter_layer(inputs)
if not env.ENERGY_BIAS_TRAINABLE:
atom_energy = atom_energy + self.bias_atom_e[type_i]
atom_energy = atom_energy * mask.unsqueeze(-1)
                    outs = outs + atom_energy  # Shape is [nframes, nloc, 1]
return {
"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION),
"dforce": vec_out,
}
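if __name__ == "__main__":
    # Minimal smoke-test sketch for EnergyFittingNetDirect (not part of the
    # library API). Shapes follow the forward() docstring above; the sizes
    # below are arbitrary illustrative values.
    nf, nloc, dim_descrpt, out_dim = 2, 5, 32, 4
    net = EnergyFittingNetDirect(
        ntypes=2,
        dim_descrpt=dim_descrpt,
        neuron=[24, 24],
        out_dim=out_dim,
        return_energy=True,
    ).to(env.DEVICE)
    inputs = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=env.DEVICE)
    atype = torch.zeros(nf, nloc, dtype=torch.long, device=env.DEVICE)
    # gr plays the role of the rotationally equivariant descriptor output.
    gr = torch.randn(nf, nloc, out_dim, 3, dtype=dtype, device=env.DEVICE)
    ret = net(inputs, atype, gr=gr)
    log.info("energy shape: %s", tuple(ret["energy"].shape))  # (nf, nloc, 1)
    log.info("dforce shape: %s", tuple(ret["dforce"].shape))  # (nf, nloc, 3)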