# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from abc import (
abstractmethod,
)
from typing import (
List,
Optional,
Union,
)
import numpy as np
import torch
from deepmd.pt.model.network.mlp import (
FittingNet,
NetworkCollection,
)
from deepmd.pt.model.network.network import (
ResidualDeep,
)
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
PRECISION_DICT,
)
from deepmd.pt.utils.exclude_mask import (
AtomExcludeMask,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)
dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
log = logging.getLogger(__name__)
class Fitting(torch.nn.Module, BaseFitting):
# plugin moved to BaseFitting
def __new__(cls, *args, **kwargs):
if cls is Fitting:
return BaseFitting.__new__(BaseFitting, *args, **kwargs)
return super().__new__(cls)
def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not starting from a checkpoint (resume is False),
some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.
"""
assert (
self.__class__ == base_class.__class__
), "Only fitting nets of the same type can share params!"
if shared_level == 0:
# link buffers
if hasattr(self, "bias_atom_e"):
self.bias_atom_e = base_class.bias_atom_e
# the following will successfully link all the params except buffers, which need to be linked manually.
for item in self._modules:
self._modules[item] = base_class._modules[item]
elif shared_level == 1:
# share everything except bias_atom_e
# the following will successfully link all the params except buffers, which need to be linked manually.
for item in self._modules:
self._modules[item] = base_class._modules[item]
else:
raise NotImplementedError
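# Usage sketch (illustrative, not part of the original module): in multitask
# training, the fitting net of one model branch can be linked to the fitting
# net of a chosen base branch, e.g.
#
#     other_fitting.share_params(base_fitting, shared_level=0, resume=False)
#
# where `base_fitting` and `other_fitting` are hypothetical fitting instances
# of the same class. With shared_level=0 the networks and the bias_atom_e
# buffer are shared; with shared_level=1 only the networks are shared.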
class GeneralFitting(Fitting):
"""Construct a general fitting net.
Parameters
----------
var_name : str
The atomic property to fit, 'energy', 'dipole', and 'polar'.
ntypes : int
Element count.
dim_descrpt : int
Embedding width per atom.
dim_out : int
The output dimension of the fitting net.
neuron : List[int]
Number of neurons in each hidden layer of the fitting net.
bias_atom_e : torch.Tensor, optional
Average energy per atom for each element.
resnet_dt : bool
Using time-step in the ResNet construction.
numb_fparam : int
Number of frame parameters.
numb_aparam : int
Number of atomic parameters.
activation_function : str
Activation function.
precision : str
Numerical precision.
mixed_types : bool
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
rcond : float, optional
The condition number for the regression of atomic energy.
seed : int, optional
Random seed.
exclude_types: List[int]
Atomic contributions of the excluded atom types are set to zero.
trainable : Union[List[bool], bool]
Whether the parameters in the fitting net are trainable.
Currently this only supports setting all the parameters in the fitting net to the same state.
If given as a List[bool], trainable will be True only if all the boolean values are True.
remove_vaccum_contribution: List[bool], optional
Remove the vacuum contribution before the bias is added. The list is assigned per atom
type. For `mixed_types` provide `[True]`; otherwise it should be a list of the same
length as `ntypes`, signaling whether to remove the vacuum contribution for each atom type.
"""
def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
neuron: List[int] = [128, 128, 128],
bias_atom_e: Optional[torch.Tensor] = None,
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = True,
rcond: Optional[float] = None,
seed: Optional[int] = None,
exclude_types: List[int] = [],
trainable: Union[bool, List[bool]] = True,
remove_vaccum_contribution: Optional[List[bool]] = None,
**kwargs,
):
super().__init__()
self.var_name = var_name
self.ntypes = ntypes
self.dim_descrpt = dim_descrpt
self.neuron = neuron
self.mixed_types = mixed_types
self.resnet_dt = resnet_dt
self.numb_fparam = numb_fparam
self.numb_aparam = numb_aparam
self.activation_function = activation_function
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.rcond = rcond
# order matters, should be placed after the assignment of ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
# TODO: support per-layer trainable settings
self.trainable = (
all(self.trainable) if isinstance(self.trainable, list) else self.trainable
)
self.remove_vaccum_contribution = remove_vaccum_contribution
net_dim_out = self._net_out_dim()
# init constants
if bias_atom_e is None:
bias_atom_e = np.zeros([self.ntypes, net_dim_out], dtype=np.float64)
bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device)
bias_atom_e = bias_atom_e.view([self.ntypes, net_dim_out])
if not self.mixed_types:
assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!"
self.register_buffer("bias_atom_e", bias_atom_e)
if self.numb_fparam > 0:
self.register_buffer(
"fparam_avg",
torch.zeros(self.numb_fparam, dtype=self.prec, device=device),
)
self.register_buffer(
"fparam_inv_std",
torch.ones(self.numb_fparam, dtype=self.prec, device=device),
)
else:
self.fparam_avg, self.fparam_inv_std = None, None
if self.numb_aparam > 0:
self.register_buffer(
"aparam_avg",
torch.zeros(self.numb_aparam, dtype=self.prec, device=device),
)
self.register_buffer(
"aparam_inv_std",
torch.ones(self.numb_aparam, dtype=self.prec, device=device),
)
else:
self.aparam_avg, self.aparam_inv_std = None, None
in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
self.old_impl = kwargs.get("old_impl", False)
if self.old_impl:
filter_layers = []
for type_i in range(self.ntypes if not self.mixed_types else 1):
bias_type = 0.0
one = ResidualDeep(
type_i,
self.dim_descrpt,
self.neuron,
bias_type,
resnet_dt=self.resnet_dt,
)
filter_layers.append(one)
self.filter_layers_old = torch.nn.ModuleList(filter_layers)
self.filter_layers = None
else:
self.filter_layers = NetworkCollection(
1 if not self.mixed_types else 0,
self.ntypes,
network_type="fitting_network",
networks=[
FittingNet(
in_dim,
net_dim_out,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
bias_out=True,
)
for ii in range(self.ntypes if not self.mixed_types else 1)
],
)
self.filter_layers_old = None
if seed is not None:
torch.manual_seed(seed)
# set trainable
for param in self.parameters():
param.requires_grad = self.trainable
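# Construction sketch (illustrative only): GeneralFitting is abstract; a
# minimal concrete subclass has to provide at least the output dimension,
# while other abstract members inherited from BaseFitting (e.g. the forward
# interface) are omitted here for brevity.
#
#     class ScalarFitting(GeneralFitting):  # hypothetical example class
#         def _net_out_dim(self):
#             return 1  # one scalar property per atom
#
#     fit = ScalarFitting(
#         var_name="energy",
#         ntypes=2,
#         dim_descrpt=64,
#         neuron=[64, 64, 64],
#         mixed_types=True,
#     )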
def reinit_exclude(
self,
exclude_types: List[int] = [],
):
self.exclude_types = exclude_types
self.emask = AtomExcludeMask(self.ntypes, self.exclude_types)
def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
"@class": "Fitting",
"@version": 1,
"var_name": self.var_name,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"numb_fparam": self.numb_fparam,
"numb_aparam": self.numb_aparam,
"activation_function": self.activation_function,
"precision": self.precision,
"mixed_types": self.mixed_types,
"nets": self.filter_layers.serialize(),
"rcond": self.rcond,
"exclude_types": self.exclude_types,
"@variables": {
"bias_atom_e": to_numpy_array(self.bias_atom_e),
"fparam_avg": to_numpy_array(self.fparam_avg),
"fparam_inv_std": to_numpy_array(self.fparam_inv_std),
"aparam_avg": to_numpy_array(self.aparam_avg),
"aparam_inv_std": to_numpy_array(self.aparam_inv_std),
},
# "tot_ener_zero": self.tot_ener_zero ,
# "trainable": self.trainable ,
# "atom_ener": self.atom_ener ,
# "layer_name": self.layer_name ,
# "use_aparam_as_mask": self.use_aparam_as_mask ,
# "spin": self.spin ,
## NOTICE: not supported so far
"tot_ener_zero": False,
"trainable": [self.trainable] * (len(self.neuron) + 1),
"layer_name": None,
"use_aparam_as_mask": False,
"spin": None,
}
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
variables = data.pop("@variables")
nets = data.pop("nets")
obj = cls(**data)
for kk in variables.keys():
obj[kk] = to_torch_tensor(variables[kk])
obj.filter_layers = NetworkCollection.deserialize(nets)
return obj
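# Round-trip sketch (illustrative): serialize() returns a plain dict whose
# "@variables" entry holds numpy arrays, so a fitting net can be rebuilt
# from it, e.g.
#
#     data = fit.serialize()
#     restored = type(fit).deserialize(data)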
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.numb_fparam
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.numb_aparam
# make jit happy
exclude_types: List[int]
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.
Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""
# make jit happy
sel_type: List[int] = []
for ii in range(self.ntypes):
if ii not in self.exclude_types:
sel_type.append(ii)
return sel_type
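# Worked example (illustrative): with ntypes=3 and exclude_types=[1],
# get_sel_type() returns [0, 2]; atoms of type 1 contribute zero to the output.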
def __setitem__(self, key, value):
if key in ["bias_atom_e"]:
value = value.view([self.ntypes, self._net_out_dim()])
self.bias_atom_e = value
elif key in ["fparam_avg"]:
self.fparam_avg = value
elif key in ["fparam_inv_std"]:
self.fparam_inv_std = value
elif key in ["aparam_avg"]:
self.aparam_avg = value
elif key in ["aparam_inv_std"]:
self.aparam_inv_std = value
elif key in ["scale"]:
self.scale = value
else:
raise KeyError(key)
def __getitem__(self, key):
if key in ["bias_atom_e"]:
return self.bias_atom_e
elif key in ["fparam_avg"]:
return self.fparam_avg
elif key in ["fparam_inv_std"]:
return self.fparam_inv_std
elif key in ["aparam_avg"]:
return self.aparam_avg
elif key in ["aparam_inv_std"]:
return self.aparam_inv_std
elif key in ["scale"]:
return self.scale
else:
raise KeyError(key)
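# Keyed access sketch (illustrative): deserialize() relies on this item access
# to restore the statistics buffers, e.g.
#
#     fit["fparam_avg"] = to_torch_tensor(avg)  # `fit` and `avg` are hypothetical
#     bias = fit["bias_atom_e"]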
@abstractmethod
def _net_out_dim(self):
"""Set the FittingNet output dim."""
pass
def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor:
return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1])
def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor:
return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1])
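# The two helpers above broadcast the stored statistics so that frame and
# atomic parameters can be standardized elementwise before being appended to
# the descriptor; schematically:
#
#     fparam = (fparam - fparam_avg) * fparam_inv_std    # (nf, numb_fparam)
#     aparam = (aparam - aparam_avg) * aparam_inv_std    # (nf, nloc, numb_aparam)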
def _forward_common(
self,
descriptor: torch.Tensor,
atype: torch.Tensor,
gr: Optional[torch.Tensor] = None,
g2: Optional[torch.Tensor] = None,
h2: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
xx = descriptor
if self.remove_vaccum_contribution is not None:
# TODO: compute the input for vacuum when remove_vaccum_contribution is set
# Ideally, the input for the vacuum state should be computed;
# we consider it as always zero for convenience.
# Needs a compute_input_stats for the vacuum state passed from the
# descriptor.
xx_zeros = torch.zeros_like(xx)
else:
xx_zeros = None
nf, nloc, nd = xx.shape
net_dim_out = self._net_out_dim()
if nd != self.dim_descrpt:
raise ValueError(
"get an input descriptor of dim {nd},"
"which is not consistent with {self.dim_descrpt}."
)
# check fparam dim, concatenate to input descriptor
if self.numb_fparam > 0:
assert fparam is not None, "fparam should not be None"
assert self.fparam_avg is not None
assert self.fparam_inv_std is not None
if fparam.shape[-1] != self.numb_fparam:
raise ValueError(
"get an input fparam of dim {fparam.shape[-1]}, ",
"which is not consistent with {self.numb_fparam}.",
)
fparam = fparam.view([nf, self.numb_fparam])
nb, _ = fparam.shape
t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb)
t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb)
fparam = (fparam - t_fparam_avg) * t_fparam_inv_std
fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1])
xx = torch.cat(
[xx, fparam],
dim=-1,
)
if xx_zeros is not None:
xx_zeros = torch.cat(
[xx_zeros, fparam],
dim=-1,
)
# check aparam dim, concatenate to input descriptor
if self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
assert self.aparam_avg is not None
assert self.aparam_inv_std is not None
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
f"get an input aparam of dim {aparam.shape[-1]}, ",
f"which is not consistent with {self.numb_aparam}.",
)
aparam = aparam.view([nf, -1, self.numb_aparam])
nb, nloc, _ = aparam.shape
t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc)
t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc)
aparam = (aparam - t_aparam_avg) * t_aparam_inv_std
xx = torch.cat(
[xx, aparam],
dim=-1,
)
if xx_zeros is not None:
xx_zeros = torch.cat(
[xx_zeros, aparam],
dim=-1,
)
outs = torch.zeros(
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=descriptor.device,
) # jit assertion
if self.old_impl:
assert self.filter_layers_old is not None
assert xx_zeros is None
if self.mixed_types:
atom_property = self.filter_layers_old[0](xx) + self.bias_atom_e[atype]
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
for type_i, filter_layer in enumerate(self.filter_layers_old):
mask = atype == type_i
atom_property = filter_layer(xx)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask.unsqueeze(-1)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
if self.mixed_types:
atom_property = (
self.filter_layers.networks[0](xx) + self.bias_atom_e[atype]
)
if xx_zeros is not None:
atom_property -= self.filter_layers.networks[0](xx_zeros)
outs = (
outs + atom_property
) # Shape is [nframes, natoms[0], net_dim_out]
else:
for type_i, ll in enumerate(self.filter_layers.networks):
mask = (atype == type_i).unsqueeze(-1)
mask = torch.tile(mask, (1, 1, net_dim_out))
atom_property = ll(xx)
if xx_zeros is not None:
# must assert, otherwise jit is not happy
assert self.remove_vaccum_contribution is not None
if not (
len(self.remove_vaccum_contribution) > type_i
and not self.remove_vaccum_contribution[type_i]
):
atom_property -= ll(xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
outs = (
outs + atom_property
) # Shape is [nframes, natoms[0], net_dim_out]
# nf x nloc
mask = self.emask(atype)
# nf x nloc x nod
outs = outs * mask[:, :, None]
return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
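# Forward sketch (illustrative, `fit` is a hypothetical concrete instance): a
# subclass forward typically delegates to _forward_common with the descriptor
# output, e.g.
#
#     nf, nloc = 4, 192
#     descriptor = torch.randn(nf, nloc, fit.dim_descrpt, dtype=dtype, device=device)
#     atype = torch.zeros(nf, nloc, dtype=torch.long, device=device)
#     out = fit._forward_common(descriptor, atype)[fit.var_name]  # (nf, nloc, dim_out)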