Source code for deepmd.dpmodel.descriptor.make_base_descriptor

# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
    ABC,
    abstractmethod,
)
from typing import (
    Callable,
    List,
    Optional,
    Union,
)

from deepmd.common import (
    j_get_type,
)
from deepmd.utils.path import (
    DPPath,
)
from deepmd.utils.plugin import (
    PluginVariant,
    make_plugin_registry,
)


def make_base_descriptor(
    t_tensor,
    fwd_method_name: str = "forward",
):
    """Make the base class for the descriptor.

    Parameters
    ----------
    t_tensor
        The type of the tensor; used in the type hint.
    fwd_method_name
        Name of the forward method. For dpmodels, it should be "call".
        For torch models, it should be "forward".

    """

    class BD(ABC, PluginVariant, make_plugin_registry("descriptor")):
        """Base descriptor provides the interfaces of descriptor."""

        def __new__(cls, *args, **kwargs):
            if cls is BD:
                # dispatch to the registered subclass named by the "type" key
                cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
            return super().__new__(cls)

        @abstractmethod
        def get_rcut(self) -> float:
            """Returns the cut-off radius."""
            pass

        @abstractmethod
        def get_sel(self) -> List[int]:
            """Returns the number of selected neighboring atoms for each type."""
            pass

        def get_nsel(self) -> int:
            """Returns the total number of selected neighboring atoms in the cut-off radius."""
            return sum(self.get_sel())

        def get_nnei(self) -> int:
            """Returns the total number of selected neighboring atoms in the cut-off radius."""
            return self.get_nsel()

        @abstractmethod
        def get_ntypes(self) -> int:
            """Returns the number of element types."""
            pass

        @abstractmethod
        def get_dim_out(self) -> int:
            """Returns the output descriptor dimension."""
            pass

        @abstractmethod
        def get_dim_emb(self) -> int:
            """Returns the embedding dimension of g2."""
            pass

        @abstractmethod
        def mixed_types(self) -> bool:
            """Returns whether the descriptor requires a neighbor list that
            distinguishes different atomic types.
            """
            pass

        @abstractmethod
        def share_params(self, base_class, shared_level, resume=False):
            """Share the parameters of self to the base_class with shared_level
            during multitask training. If not starting from a checkpoint
            (resume is False), some separated parameters (e.g. mean and stddev)
            will be re-calculated across different classes.
            """
            pass

        def compute_input_stats(
            self,
            merged: Union[Callable[[], List[dict]], List[dict]],
            path: Optional[DPPath] = None,
        ):
            """Update mean and stddev for descriptor elements."""
            raise NotImplementedError

        @abstractmethod
        def fwd(
            self,
            extended_coord,
            extended_atype,
            nlist,
            mapping: Optional[t_tensor] = None,
        ):
            """Calculate the descriptor."""
            pass

        @abstractmethod
        def serialize(self) -> dict:
            """Serialize the object to a dict."""
            pass

        @classmethod
        def deserialize(cls, data: dict) -> "BD":
            """Deserialize the model.

            Parameters
            ----------
            data : dict
                The serialized data

            Returns
            -------
            BD
                The deserialized descriptor
            """
            if cls is BD:
                return BD.get_class_by_type(data["type"]).deserialize(data)
            raise NotImplementedError(f"Not implemented in class {cls.__name__}")

        @classmethod
        @abstractmethod
        def update_sel(cls, global_jdata: dict, local_jdata: dict):
            """Update the selection and perform neighbor statistics.

            Parameters
            ----------
            global_jdata : dict
                The global data, containing the training section
            local_jdata : dict
                The local data referring to the current class
            """
            # dispatch to the registered subclass
            cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__))
            return cls.update_sel(global_jdata, local_jdata)

    # Expose the abstract forward method under the framework-specific name
    # ("call" for dpmodel, "forward" for torch) and drop the generic alias.
    setattr(BD, fwd_method_name, BD.fwd)
    delattr(BD, "fwd")
    return BD
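
For context, a minimal sketch of how this factory can be consumed follows. Everything in it is illustrative rather than part of the module above: ToyDescriptor, the "toy" registry key, and the placeholder dimensions and return values are hypothetical, and the sketch assumes deepmd-kit is installed and that the registry class returned by make_plugin_registry exposes the register decorator and the get_class_by_type lookup used in the source above.

    from typing import List

    import numpy as np

    from deepmd.dpmodel.descriptor.make_base_descriptor import (
        make_base_descriptor,
    )

    # For a NumPy-backed model the tensor type is np.ndarray and the
    # forward method is exposed under the name "call".
    BaseDescriptor = make_base_descriptor(np.ndarray, fwd_method_name="call")


    @BaseDescriptor.register("toy")  # key looked up by __new__ and deserialize
    class ToyDescriptor(BaseDescriptor):
        """A do-nothing descriptor; all values here are placeholders."""

        def __init__(self, rcut: float = 6.0, sel=(46, 92), **kwargs) -> None:
            self.rcut = rcut
            self.sel = list(sel)

        def get_rcut(self) -> float:
            return self.rcut

        def get_sel(self) -> List[int]:
            return self.sel

        def get_ntypes(self) -> int:
            return len(self.sel)

        def get_dim_out(self) -> int:
            return 8

        def get_dim_emb(self) -> int:
            return 8

        def mixed_types(self) -> bool:
            return False

        def share_params(self, base_class, shared_level, resume=False):
            raise NotImplementedError

        def call(self, extended_coord, extended_atype, nlist, mapping=None):
            # Placeholder: a real descriptor computes an embedding here.
            nf, nloc, _ = nlist.shape
            return np.zeros((nf, nloc, self.get_dim_out()))

        def serialize(self) -> dict:
            return {"type": "toy", "rcut": self.rcut, "sel": self.sel}

        @classmethod
        def deserialize(cls, data: dict) -> "ToyDescriptor":
            data = dict(data)
            data.pop("type", None)
            return cls(**data)

        @classmethod
        def update_sel(cls, global_jdata: dict, local_jdata: dict):
            return local_jdata


    # Dispatch by the "type" key: constructing the base class builds the
    # registered subclass, and deserialize round-trips through it.
    d = BaseDescriptor(type="toy", rcut=5.0)
    assert isinstance(d, ToyDescriptor)
    d2 = BaseDescriptor.deserialize(d.serialize())
    assert d2.get_rcut() == 5.0

This factory pattern lets the NumPy and torch backends share one abstract interface while exposing the forward pass under each framework's conventional method name, which is why fwd is renamed and removed at the end of make_base_descriptor.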