# Source code for deepmd.pt.model.atomic_model.linear_atomic_model

# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

import torch

from deepmd.dpmodel import (
    FittingOutputDef,
    OutputVariableDef,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.nlist import (
    build_multiple_neighbor_list,
    get_multiple_nlist_key,
    nlist_distinguish_types,
)
from deepmd.utils.path import (
    DPPath,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

from .base_atomic_model import (
    BaseAtomicModel,
)
from .dp_atomic_model import (
    DPAtomicModel,
)
from .pairtab_atomic_model import (
    PairTabAtomicModel,
)


class LinearEnergyAtomicModel(BaseAtomicModel):
    """Linear model that makes linear combinations of several existing models.

    Parameters
    ----------
    models : list[DPAtomicModel or PairTabAtomicModel]
        A list of models to be combined. PairTabAtomicModel must be used together with a DPAtomicModel.
    type_map : list[str]
        Mapping atom type to the name (str) of the type.
        For example `type_map[1]` gives the name of the type 1.
    """

    def __init__(
        self,
        models: List[BaseAtomicModel],
        type_map: List[str],
        **kwargs,
    ):
        super().__init__(type_map, **kwargs)
        super().init_out_stat()

        self.models = torch.nn.ModuleList(models)
        sub_model_type_maps = [md.get_type_map() for md in models]
        err_msg = []
        # One index tensor per sub-model; it translates a type index of the
        # common `type_map` into the corresponding index of that sub-model's
        # own type map (see `remap_atype`).
        self.mapping_list = []
        common_type_map = set(type_map)
        self.type_map = type_map
        for tpmp in sub_model_type_maps:
            if not common_type_map.issubset(set(tpmp)):
                err_msg.append(
                    f"type_map {type_map} is not a subset of the sub-model type_map {tpmp}"
                )
                # `remap_atype` would raise a bare KeyError for the missing
                # types; skip it so every mismatch is reported together below.
                continue
            self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
        if err_msg:
            raise ValueError("\n".join(err_msg))
        self.mixed_types_list = [model.mixed_types() for model in self.models]
        self.rcuts = torch.tensor(
            self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
        )
        self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)

    def mixed_types(self) -> bool:
        """If true, the model
        1. assumes total number of atoms aligned across frames;
        2. uses a neighbor list that does not distinguish different atomic types.

        If false, the model
        1. assumes total number of atoms of each atom type aligned across frames;
        2. uses a neighbor list that distinguishes different atomic types.

        """
        return True

    def get_out_bias(self) -> torch.Tensor:
        """Return the output bias tensor stored on this model."""
        return self.out_bias

    def get_rcut(self) -> float:
        """Get the cut-off radius (the largest cut-off among the sub-models)."""
        return max(self.get_model_rcuts())

    def get_type_map(self) -> List[str]:
        """Get the type map."""
        return self.type_map

    def get_model_rcuts(self) -> List[float]:
        """Get the cut-off radius for each individual models."""
        return [model.get_rcut() for model in self.models]

    def get_sel(self) -> List[int]:
        """Return a single-element list holding the largest sub-model nsel."""
        return [max([model.get_nsel() for model in self.models])]

    def get_model_nsels(self) -> List[int]:
        """Get the processed sels for each individual models. Not distinguishing types."""
        return [model.get_nsel() for model in self.models]

    def get_model_sels(self) -> List[List[int]]:
        """Get the sels for each individual models."""
        return [model.get_sel() for model in self.models]

    def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
        """Return the (rcut, nsel) pairs of all sub-models sorted ascending.

        The pairs are ordered primarily by rcut; ties in rcut are ordered
        by nsel (a sort on nsel followed by a stable sort on rcut).
        """
        zipped = torch.stack(
            [
                self.rcuts,
                self.nsels,
            ],
            dim=0,
        ).T
        inner_sorting = torch.argsort(zipped[:, 1], dim=0)
        inner_sorted = zipped[inner_sorting]
        # stable, so the nsel order from the first pass survives rcut ties
        outer_sorting = torch.argsort(inner_sorted[:, 0], stable=True)
        outer_sorted = inner_sorted[outer_sorting]
        sorted_rcuts: List[float] = outer_sorted[:, 0].tolist()
        sorted_sels: List[int] = outer_sorted[:, 1].to(torch.int64).tolist()
        return sorted_rcuts, sorted_sels

    def forward_atomic(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        comm_dict: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Return atomic prediction.

        Parameters
        ----------
        extended_coord
            coordinates in extended region, (nframes, nall * 3)
        extended_atype
            atomic type in extended region, (nframes, nall)
        nlist
            neighbor list, (nframes, nloc, nsel).
        mapping
            maps the extended indices to local indices.
        fparam
            frame parameter. (nframes, ndf)
        aparam
            atomic parameter. (nframes, nloc, nda)
        comm_dict
            accepted for interface compatibility; not used here.

        Returns
        -------
        result_dict
            the result dict, defined by the fitting net output def.
        """
        nframes, nloc, nnei = nlist.shape
        if self.do_grad_r() or self.do_grad_c():
            extended_coord.requires_grad_(True)
        extended_coord = extended_coord.view(nframes, -1, 3)
        sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
        # build one neighbor list per distinct (rcut, nsel) pair
        nlists = build_multiple_neighbor_list(
            extended_coord,
            nlist,
            sorted_rcuts,
            sorted_sels,
        )
        raw_nlists = [
            nlists[get_multiple_nlist_key(rcut, sel)]
            for rcut, sel in zip(self.get_model_rcuts(), self.get_model_nsels())
        ]
        # NOTE(review): `extended_atype` here is in the common type map while
        # `sel` comes from the sub-model; confirm the two orderings agree.
        nlists_ = [
            nl if mt else nlist_distinguish_types(nl, extended_atype, sel)
            for mt, nl, sel in zip(
                self.mixed_types_list, raw_nlists, self.get_model_sels()
            )
        ]
        ener_list = []

        for i, model in enumerate(self.models):
            # Type-index remapping for this sub-model. Kept under its own name
            # so the `mapping` argument (extended -> local atom indices) is
            # forwarded unchanged to every sub-model.
            type_map_model = self.mapping_list[i]
            # apply bias to each individual model
            ener_list.append(
                model.forward_common_atomic(
                    extended_coord,
                    type_map_model[extended_atype],
                    nlists_[i],
                    mapping,
                    fparam,
                    aparam,
                )["energy"]
            )
        weights = self._compute_weight(extended_coord, extended_atype, nlists_)

        fit_ret = {
            "energy": torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0),
        }  # (nframes, nloc, 1)
        return fit_ret

    def apply_out_stat(
        self,
        ret: Dict[str, torch.Tensor],
        atype: torch.Tensor,
    ):
        """Apply the stat to each atomic output.
        The developer may override the method to define how the bias is applied
        to the atomic output of the model.

        This implementation returns `ret` unchanged.

        Parameters
        ----------
        ret
            The returned dict by the forward_atomic method
        atype
            The atom types. nf x nloc

        """
        return ret

    @staticmethod
    def remap_atype(ori_map: List[str], new_map: List[str]) -> torch.Tensor:
        """
        This method is used to map the atype from the common type_map to the original type_map of
        individual AtomicModels. It creates an index mapping for the conversion.

        Parameters
        ----------
        ori_map : List[str]
            The original type map of an AtomicModel.
        new_map : List[str]
            The common type map of the LinearEnergyAtomicModel, created by the `get_type_map` method,
            must be a subset of the ori_map.

        Returns
        -------
        torch.Tensor
            Tensor whose entry `k` is the index of `new_map[k]` in `ori_map`.
        """
        type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
        # this maps the atype in the new map to the original map
        mapping = torch.tensor(
            [type_2_idx[new_map[idx]] for idx in range(len(new_map))], device=env.DEVICE
        )
        return mapping

    def fitting_output_def(self) -> FittingOutputDef:
        """Declare a single reducible, differentiable per-atom `energy` output."""
        return FittingOutputDef(
            [
                OutputVariableDef(
                    name="energy",
                    shape=[1],
                    reduciable=True,
                    r_differentiable=True,
                    c_differentiable=True,
                )
            ]
        )

    def serialize(self) -> dict:
        """Serialize the model, including every sub-model, into a dict."""
        dd = super().serialize()
        dd.update(
            {
                "@class": "Model",
                "@version": 2,
                "type": "linear",
                "models": [model.serialize() for model in self.models],
                "type_map": self.type_map,
            }
        )
        return dd

    @classmethod
    def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
        """Rebuild the model from the output of `serialize`."""
        data = copy.deepcopy(data)
        # pop (not get): the remaining keys are forwarded to the constructor,
        # which does not accept "@version".
        check_version_compatibility(data.pop("@version", 2), 2, 1)
        data.pop("@class", None)
        data.pop("type", None)
        models = [
            BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
            for model in data["models"]
        ]
        data["models"] = models
        return super().deserialize(data)

    def _compute_weight(
        self, extended_coord, extended_atype, nlists_
    ) -> List[torch.Tensor]:
        """This should be a list of user defined weights that matches the number of models to be combined.

        The default gives every model an equal weight of 1 / nmodels,
        each of shape (nframes, nloc, 1).
        """
        nmodels = len(self.models)
        nframes, nloc, _ = nlists_[0].shape
        return [
            torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE)
            / nmodels
            for _ in range(nmodels)
        ]

    def get_dim_fparam(self) -> int:
        """Get the number (dimension) of frame parameters of this atomic model."""
        # tricky...
        return max([model.get_dim_fparam() for model in self.models])

    def get_dim_aparam(self) -> int:
        """Get the number (dimension) of atomic parameters of this atomic model."""
        return max([model.get_dim_aparam() for model in self.models])

    def get_sel_type(self) -> List[int]:
        """Get the selected atom types of this model.

        Only atoms with selected atom types have atomic contribution
        to the result of the model.
        If returning an empty list, all atom types are selected.
        """
        if any(model.get_sel_type() == [] for model in self.models):
            return []
        # join all the selected types
        # make torch.jit happy...
        return torch.unique(
            torch.cat(
                [
                    torch.as_tensor(model.get_sel_type(), dtype=torch.int32)
                    for model in self.models
                ]
            )
        ).tolist()

    def is_aparam_nall(self) -> bool:
        """Check whether the shape of atomic parameters is (nframes, nall, ndim).

        If False, the shape is (nframes, nloc, ndim).
        """
        return False

    def compute_or_load_out_stat(
        self,
        merged: Union[Callable[[], List[dict]], List[dict]],
        stat_file_path: Optional[DPPath] = None,
    ):
        """
        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

        Delegates to each sub-model in turn.

        Parameters
        ----------
        merged : Union[Callable[[], List[dict]], List[dict]]
            - List[dict]: A list of data samples from various data systems.
                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
                originating from the `i`-th data system.
            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
                only when needed. Since the sampling process can be slow and memory-intensive,
                the lazy function helps by only sampling once.
        stat_file_path : Optional[DPPath]
            The path to the stat file.

        """
        for md in self.models:
            md.compute_or_load_out_stat(merged, stat_file_path)

    def compute_or_load_stat(
        self,
        sampled_func,
        stat_file_path: Optional[DPPath] = None,
    ):
        """
        Compute or load the statistics parameters of the model,
        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
        When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
        and saved in the `stat_file_path`(s).
        When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
        and load the calculated statistics parameters.

        Delegates to each sub-model in turn.

        Parameters
        ----------
        sampled_func
            The lazy sampled function to get data frames from different data systems.
        stat_file_path
            The dictionary of paths to the statistics files.
        """
        for md in self.models:
            md.compute_or_load_stat(sampled_func, stat_file_path)
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
    """Model that linearly combines a DPAtomicModel with a ZBL PairTabAtomicModel.

    Parameters
    ----------
    dp_model
        The DPAtomicModel being combined.
    zbl_model
        The PairTable model being combined.
    sw_rmin
        The lower boundary of the interpolation between short-range tabulated interaction and DP.
    sw_rmax
        The upper boundary of the interpolation between short-range tabulated interaction and DP.
    type_map
        Mapping atom type to the name (str) of the type.
        For example `type_map[1]` gives the name of the type 1.
    smin_alpha
        The short-range tabulated interaction will be switched according to the distance of the nearest neighbor.
        This distance is calculated by softmin.
    """

    def __init__(
        self,
        dp_model: DPAtomicModel,
        zbl_model: PairTabAtomicModel,
        sw_rmin: float,
        sw_rmax: float,
        type_map: List[str],
        smin_alpha: Optional[float] = 0.1,
        **kwargs,
    ):
        # the order matters: `_compute_weight` returns [dp_weight, zbl_weight]
        models = [dp_model, zbl_model]
        kwargs["models"] = models
        kwargs["type_map"] = type_map
        super().__init__(**kwargs)

        self.sw_rmin = sw_rmin
        self.sw_rmax = sw_rmax
        self.smin_alpha = smin_alpha

        # this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
        self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

    def serialize(self) -> dict:
        """Serialize the model, adding the switching parameters."""
        dd = super().serialize()
        dd.update(
            {
                "@class": "Model",
                "@version": 2,
                "type": "zbl",
                "sw_rmin": self.sw_rmin,
                "sw_rmax": self.sw_rmax,
                "smin_alpha": self.smin_alpha,
            }
        )
        return dd

    @classmethod
    def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
        """Rebuild the model from the output of `serialize`."""
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 2, 1)
        models = [
            BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
            for model in data["models"]
        ]
        # The constructor takes the two sub-models positionally-by-name;
        # "models" is left in place because the parent `deserialize` expects it
        # (the constructor then overwrites it with [dp_model, zbl_model]).
        data["dp_model"], data["zbl_model"] = models[0], models[1]
        data.pop("@class", None)
        data.pop("type", None)
        return super().deserialize(data)

    def _compute_weight(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlists_: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        """ZBL weight.

        A softmin-averaged neighbor distance `sigma` is computed per atom and
        mapped to a switching coefficient: 1 below `sw_rmin`, 0 at or above
        `sw_rmax`, and a smooth quintic polynomial in between.

        Returns
        -------
        List[torch.Tensor]
            [dp_weight, zbl_weight] = [1 - coef, coef], each (nframes, nloc, 1),
            in the same order as the models.
        """
        assert (
            self.sw_rmax > self.sw_rmin
        ), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`."

        dp_nlist = nlists_[0]
        zbl_nlist = nlists_[1]

        zbl_nnei = zbl_nlist.shape[-1]
        dp_nnei = dp_nlist.shape[-1]

        # use the larger rr based on nlist
        nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist
        masked_nlist = torch.clamp(nlist_larger, 0)
        pairwise_rr = PairTabAtomicModel._get_pairwise_dist(
            extended_coord, masked_nlist
        )
        # Padded (-1) slots were clamped to atom index 0 above, so their
        # distances are spurious (non-zero for every atom but atom 0). They
        # must be excluded from BOTH the numerator and the denominator.
        valid = nlist_larger != -1
        exp_rr = torch.exp(-pairwise_rr / self.smin_alpha)
        numerator = torch.sum(
            torch.where(valid, pairwise_rr * exp_rr, torch.zeros_like(pairwise_rr)),
            dim=-1,
        )
        denominator = torch.sum(
            torch.where(valid, exp_rr, torch.zeros_like(pairwise_rr)),
            dim=-1,
        )
        # clamp avoids 0/0 for atoms with no neighbors
        sigma = numerator / torch.clamp(denominator, 1e-20)  # nframes, nloc
        u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin)
        coef = torch.zeros_like(u)
        left_mask = sigma < self.sw_rmin
        mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax)
        right_mask = sigma >= self.sw_rmax
        coef[left_mask] = 1
        # equals 1 at u=0 and 0 at u=1
        smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
        coef[mid_mask] = smooth[mid_mask]
        coef[right_mask] = 0
        self.zbl_weight = coef  # nframes, nloc
        return [1 - coef.unsqueeze(-1), coef.unsqueeze(-1)]  # to match the model order.