# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
)
from deepmd.dpmodel.common import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
)
if TYPE_CHECKING:
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@InvarFitting.register("ener")
class EnergyFittingNet(InvarFitting):
    """Invariant fitting network with a single ``"energy"`` output.

    The class is registered with ``InvarFitting`` under the key ``"ener"``.
    It is a thin wrapper around ``InvarFitting``: the constructor pins
    ``var_name="energy"`` and ``dim_out=1`` and forwards the remaining
    arguments (except ``seed``) to ``InvarFitting.__init__`` unchanged.

    Parameters
    ----------
    ntypes : int
        Forwarded to ``InvarFitting`` as ``ntypes``.
    dim_descrpt : int
        Forwarded to ``InvarFitting`` as ``dim_descrpt``.
    neuron : list of int, optional
        Forwarded to ``InvarFitting`` as ``neuron``. ``None`` (the default)
        is replaced by a freshly created ``[120, 120, 120]``.
    resnet_dt : bool
        Forwarded to ``InvarFitting`` as ``resnet_dt``.
    numb_fparam : int
        Forwarded to ``InvarFitting`` as ``numb_fparam``.
    numb_aparam : int
        Forwarded to ``InvarFitting`` as ``numb_aparam``.
    rcond : float, optional
        Forwarded to ``InvarFitting`` as ``rcond``.
    tot_ener_zero : bool
        Forwarded to ``InvarFitting`` as ``tot_ener_zero``.
    trainable : list of bool, optional
        Forwarded to ``InvarFitting`` as ``trainable``.
    atom_ener : list of float, optional
        Forwarded to ``InvarFitting`` as ``atom_ener``.
    activation_function : str
        Forwarded to ``InvarFitting`` as ``activation_function``.
    precision : str
        Forwarded to ``InvarFitting`` as ``precision``.
    layer_name : list of (str or None), optional
        Forwarded to ``InvarFitting`` as ``layer_name``.
    use_aparam_as_mask : bool
        Forwarded to ``InvarFitting`` as ``use_aparam_as_mask``.
    spin : Any
        Forwarded to ``InvarFitting`` as ``spin``.
    mixed_types : bool
        Forwarded to ``InvarFitting`` as ``mixed_types``.
    exclude_types : list of int, optional
        Forwarded to ``InvarFitting`` as ``exclude_types``. ``None`` (the
        default) is replaced by a freshly created empty list.
    seed : int, optional
        Accepted but not used: it is not passed on to ``InvarFitting``.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: Optional[List[int]] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        rcond: Optional[float] = None,
        tot_ener_zero: bool = False,
        trainable: Optional[List[bool]] = None,
        atom_ener: Optional[List[float]] = None,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        layer_name: Optional[List[Optional[str]]] = None,
        use_aparam_as_mask: bool = False,
        spin: Any = None,
        mixed_types: bool = False,
        exclude_types: Optional[List[int]] = None,
        # not used
        seed: Optional[int] = None,
    ):
        # Resolve the list defaults per call. A list literal in the signature
        # is created once at definition time and would be the same object for
        # every instance constructed with the defaults.
        if neuron is None:
            neuron = [120, 120, 120]
        if exclude_types is None:
            exclude_types = []
        super().__init__(
            var_name="energy",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            dim_out=1,
            neuron=neuron,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            rcond=rcond,
            tot_ener_zero=tot_ener_zero,
            trainable=trainable,
            atom_ener=atom_ener,
            activation_function=activation_function,
            precision=precision,
            layer_name=layer_name,
            use_aparam_as_mask=use_aparam_as_mask,
            spin=spin,
            mixed_types=mixed_types,
            exclude_types=exclude_types,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        """Rebuild a fitting object from a serialized dict.

        Works on a deep copy, so the caller's ``data`` is left untouched.

        Parameters
        ----------
        data : dict
            Serialized fitting. Must contain the keys ``"var_name"`` and
            ``"dim_out"``; ``"@version"`` is optional and treated as ``1``
            when absent.

        Returns
        -------
        GeneralFitting
            Whatever the parent class's ``deserialize`` returns for the
            remaining entries.

        Raises
        ------
        KeyError
            If ``"var_name"`` or ``"dim_out"`` is missing from ``data``.
        """
        data = copy.deepcopy(data)
        # NOTE(review): the two trailing arguments look like the supported
        # version bounds -- confirm against check_version_compatibility.
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        # __init__ has no var_name/dim_out parameters (it fixes both itself),
        # so these entries are dropped before delegating to the parent.
        data.pop("var_name")
        data.pop("dim_out")
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict.

        Returns
        -------
        dict
            The parent class's serialized dict with the ``"type"`` entry set
            to ``"ener"`` (replacing any ``"type"`` the parent emitted).
        """
        return {
            **super().serialize(),
            "type": "ener",
        }