Source code for deepmd.pt.model.task.ener

# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (
    Dict,
    List,
    Optional,
    Tuple,
)

import numpy as np
import torch

from deepmd.dpmodel import (
    FittingOutputDef,
    OutputVariableDef,
    fitting_check_output,
)
from deepmd.pt.model.network.network import (
    ResidualDeep,
)
from deepmd.pt.model.task.fitting import (
    Fitting,
    GeneralFitting,
)
from deepmd.pt.model.task.invar_fitting import (
    InvarFitting,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.env import (
    DEFAULT_PRECISION,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

# Module-level shortcuts for the project-wide torch float precision and device.
dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

# Logger named after this module so records can be filtered per module.
log = logging.getLogger(__name__)
@Fitting.register("ener")
class EnergyFittingNet(InvarFitting):
    """Invariant fitting net registered under the ``"ener"`` type.

    This is a thin specialization of ``InvarFitting``: the constructor pins
    the first positional argument to ``"energy"`` and the fourth to ``1``,
    and forwards every other option to the parent unchanged.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: Optional[List[int]] = None,
        bias_atom_e: Optional[torch.Tensor] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = True,
        **kwargs,
    ):
        """Build the energy fitting net.

        Parameters
        ----------
        ntypes
            Passed positionally to ``InvarFitting`` right after ``"energy"``.
        dim_descrpt
            Passed positionally to ``InvarFitting`` after ``ntypes``.
        neuron
            Hidden-layer sizes forwarded as ``neuron=``. ``None`` (the
            default) is replaced by a fresh ``[128, 128, 128]`` list.
        bias_atom_e, resnet_dt, numb_fparam, numb_aparam
            Forwarded unchanged to ``InvarFitting`` under the same keyword.
        activation_function, precision, mixed_types
            Forwarded unchanged to ``InvarFitting`` under the same keyword.
        **kwargs
            Any extra keyword arguments, forwarded to ``InvarFitting``.
        """
        # A list literal in the signature is evaluated once and would be
        # shared by every instance relying on the default; build a new list
        # per call so no instance can mutate another one's layer sizes.
        if neuron is None:
            neuron = [128, 128, 128]
        super().__init__(
            "energy",
            ntypes,
            dim_descrpt,
            1,
            neuron=neuron,
            bias_atom_e=bias_atom_e,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            **kwargs,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        """Rebuild a fitting net from a serialized dict.

        Parameters
        ----------
        data
            The serialized dict. It is deep-copied first, so the caller's
            object is never modified.

        Returns
        -------
        GeneralFitting
            Whatever the parent class' ``deserialize`` returns for the
            trimmed dict.

        Raises
        ------
        KeyError
            If ``data`` has no ``"var_name"`` or ``"dim_out"`` entry.
        """
        data = copy.deepcopy(data)
        # A missing "@version" is treated as version 1.
        # NOTE(review): the two trailing 1s look like the accepted version
        # bounds - confirm against check_version_compatibility.
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        # __init__ hard-codes these two values ("energy" and 1) instead of
        # accepting them as arguments, so they are stripped before delegating.
        data.pop("var_name")
        data.pop("dim_out")
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict, with ``"type"`` set to ``"ener"``."""
        return {
            **super().serialize(),
            "type": "ener",
        }

    # make jit happy with torch 2.0.0
    exclude_types: List[int]
@Fitting.register("direct_force")
@Fitting.register("direct_force_ener")
@fitting_check_output
class EnergyFittingNetDirect(Fitting):
    """Fitting net that outputs a per-atom vector and, optionally, an energy.

    Registered under both ``"direct_force"`` and ``"direct_force_ener"``.
    ``forward`` returns a dict with the keys ``"energy"`` and ``"dforce"``,
    matching the variables declared by ``output_def``.
    """

    def __init__(
        self,
        ntypes,
        dim_descrpt,
        neuron,
        bias_atom_e=None,
        out_dim=1,
        resnet_dt=True,
        use_tebd=True,
        return_energy=False,
        **kwargs,
    ):
        """Construct a fitting net for energy.

        Args:
        - ntypes: Element count.
        - dim_descrpt: Embedding width per atom.
        - neuron: Number of neurons in each hidden layers of the fitting net.
        - bias_atom_e: Average energy per atom for each element. Defaults to
          zeros of length ``ntypes``.
        - out_dim: Output width of each layer that feeds ``"dforce"``.
        - resnet_dt: Using time-step in the ResNet construction.
        - use_tebd: If True, ``forward`` evaluates only the first layer of
          each list for all atoms; otherwise one layer per type is evaluated
          and masked by atom type.
        - return_energy: If True, energy layers are built and ``forward``
          fills ``"energy"``; otherwise ``"energy"`` stays all zeros.
        - **kwargs: Only ``"seed"`` is read; other entries are ignored.
        """
        super().__init__()
        self.ntypes = ntypes
        self.dim_descrpt = dim_descrpt
        self.use_tebd = use_tebd
        self.out_dim = out_dim
        if bias_atom_e is None:
            bias_atom_e = np.zeros([self.ntypes])
        if not use_tebd:
            # The per-type energy path indexes bias_atom_e by type id.
            assert self.ntypes == len(bias_atom_e), "Element count mismatches!"
        bias_atom_e = torch.tensor(bias_atom_e, device=env.DEVICE)
        self.register_buffer("bias_atom_e", bias_atom_e)

        # One layer per type is always built, even though the use_tebd path
        # of forward() only evaluates index 0.
        self.filter_layers_dipole = torch.nn.ModuleList(
            [
                ResidualDeep(
                    type_i,
                    dim_descrpt,
                    neuron,
                    0.0,
                    out_dim=out_dim,
                    resnet_dt=resnet_dt,
                )
                for type_i in range(self.ntypes)
            ]
        )

        self.return_energy = return_energy
        # Stays empty when no energy is requested, so the attribute always
        # exists as a ModuleList.
        energy_layers = []
        if self.return_energy:
            # With use_tebd the bias is added in forward() from the
            # bias_atom_e buffer, so the layer itself gets a zero bias.
            energy_layers = [
                ResidualDeep(
                    type_i,
                    dim_descrpt,
                    neuron,
                    0.0 if self.use_tebd else bias_atom_e[type_i],
                    resnet_dt=resnet_dt,
                )
                for type_i in range(self.ntypes)
            ]
        self.filter_layers = torch.nn.ModuleList(energy_layers)

        # NOTE(review): the global seed is set only after every layer has
        # been constructed, so it presumably cannot influence the initial
        # weights of the layers above - confirm whether it should be applied
        # before they are built.
        if "seed" in kwargs:
            torch.manual_seed(kwargs["seed"])

    def output_def(self):
        """Declare the two output variables produced by ``forward``.

        Returns a ``FittingOutputDef`` holding ``"energy"`` (shape ``[1]``,
        reducible) and ``"dforce"`` (shape ``[3]``, not reducible); neither
        is flagged as differentiable.
        """
        return FittingOutputDef(
            [
                OutputVariableDef(
                    "energy",
                    [1],
                    reduciable=True,
                    r_differentiable=False,
                    c_differentiable=False,
                ),
                OutputVariableDef(
                    "dforce",
                    [3],
                    reduciable=False,
                    r_differentiable=False,
                    c_differentiable=False,
                ),
            ]
        )

    def serialize(self) -> dict:
        """Not supported for this fitting net; always raises."""
        raise NotImplementedError

    @classmethod
    def deserialize(cls, data: Optional[dict] = None) -> "EnergyFittingNetDirect":
        """Not supported for this fitting net; always raises.

        ``data`` is optional only so that both the zero-argument and the
        one-argument call shapes keep ending in ``NotImplementedError``.
        """
        raise NotImplementedError

    def forward(
        self,
        inputs: torch.Tensor,
        atype: torch.Tensor,
        gr: Optional[torch.Tensor] = None,
        g2: Optional[torch.Tensor] = None,
        h2: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Based on embedding net output, calculate the fitted outputs.

        Args:
        - inputs: Embedding matrix. Its shape is [nframes, nloc, self.dim_descrpt].
        - atype: Atom types; compared against each type index and, with
          ``use_tebd``, used to index ``bias_atom_e``. Presumably shaped
          [nframes, nloc] - TODO confirm.
        - gr: Required when ``use_tebd`` is True; viewed as
          [nframes * nloc, self.out_dim, 3].
        - g2, h2, fparam, aparam: Accepted but not used by this method.

        Returns
        -------
        - `Dict[str, torch.Tensor]`: ``"energy"``, cast to
          ``env.GLOBAL_PT_FLOAT_PRECISION`` (all zeros unless
          ``return_energy`` is set), and ``"dforce"``, which has shape
          [nframes, nloc, 3] when ``use_tebd`` is True.
        """
        nframes, nloc, _ = inputs.size()
        if self.use_tebd:
            vec_out = self.filter_layers_dipole[0](inputs)
            assert list(vec_out.size()) == [nframes, nloc, self.out_dim]
            # (nf x nloc) x 1 x od
            vec_out = vec_out.view(-1, 1, self.out_dim)
            assert gr is not None
            # (nf x nloc) x od x 3
            gr = gr.view(-1, self.out_dim, 3)
            # Contract the od axis against gr, then restore the frame/atom
            # axes: nframes x nloc x 3.
            vec_out = torch.bmm(vec_out, gr).squeeze(-2).view(nframes, nloc, 3)
        else:
            vec_out = torch.zeros_like(atype).unsqueeze(-1)  # jit assertion
            for type_i, filter_layer in enumerate(self.filter_layers_dipole):
                # Every layer sees all atoms; the mask keeps only the
                # contributions of atoms of this layer's type.
                mask = atype == type_i
                vec_out_type = filter_layer(inputs)
                vec_out_type = vec_out_type * mask.unsqueeze(-1)
                vec_out = vec_out + vec_out_type

        outs = torch.zeros_like(atype).unsqueeze(-1)  # jit assertion
        if self.return_energy:
            if self.use_tebd:
                # A single shared layer plus the per-type bias looked up by
                # atom type.
                atom_energy = self.filter_layers[0](inputs) + self.bias_atom_e[
                    atype
                ].unsqueeze(-1)
                outs = outs + atom_energy
            else:
                for type_i, filter_layer in enumerate(self.filter_layers):
                    mask = atype == type_i
                    atom_energy = filter_layer(inputs)
                    # The buffer bias is added here only when the energy
                    # bias is not trainable.
                    if not env.ENERGY_BIAS_TRAINABLE:
                        atom_energy = atom_energy + self.bias_atom_e[type_i]
                    atom_energy = atom_energy * mask.unsqueeze(-1)
                    outs = outs + atom_energy
        return {
            "energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION),
            "dforce": vec_out,
        }