Source code for deepmd.jax.jax2tf.tfmodel

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
)

import jax.experimental.jax2tf as jax2tf
import tensorflow as tf

from deepmd.dpmodel.output_def import (
    FittingOutputDef,
    ModelOutputDef,
    OutputVariableDef,
)
from deepmd.jax.env import (
    jnp,
)
from deepmd.utils.data_system import (
    DeepmdDataSystem,
)

# Output variables a wrapped model may report, keyed by output name.
# Both entries share shape [1]; a single flag drives ``reducible``,
# ``r_differentiable`` and ``c_differentiable`` together: all True for
# "energy", all False for "mask".
OUTPUT_DEFS = {
    var_name: OutputVariableDef(
        var_name,
        shape=[1],
        reducible=flag,
        r_differentiable=flag,
        c_differentiable=flag,
    )
    for var_name, flag in (
        ("energy", True),
        ("mask", False),
    )
}
def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]:
    """Turn a list of byte strings into a list of text strings.

    Parameters
    ----------
    list_of_bytes : list[bytes]
        Raw byte strings, e.g. the result of ``tensor.numpy().tolist()`` on a
        string tensor.

    Returns
    -------
    list[str]
        Each item decoded as UTF-8 with strict error handling, in input order.
    """
    return [raw.decode(encoding="utf-8", errors="strict") for raw in list_of_bytes]
class TFModelWrapper(tf.Module):
    """Expose a TensorFlow SavedModel as a JAX-callable model.

    The SavedModel's ``call`` / ``call_lower`` functions and their
    atomic-virial variants are bridged into JAX with ``jax2tf.call_tf``.
    Metadata exposed by the SavedModel (type map, cut-off, parameter
    dimensions, ...) is read once at construction and cached as plain
    Python values.

    Parameters
    ----------
    model : str
        Path passed to ``tf.saved_model.load``.
    """

    def __init__(
        self,
        model: str,
    ) -> None:
        # tf.Module.__init__ sets the module name and name scope; without it,
        # attributes such as ``name`` and ``name_scope`` are undefined.
        super().__init__()
        self.model = tf.saved_model.load(model)
        # Bridge the SavedModel entry points back into JAX.
        self._call_lower = jax2tf.call_tf(self.model.call_lower)
        self._call_lower_atomic_virial = jax2tf.call_tf(
            self.model.call_lower_atomic_virial
        )
        self._call = jax2tf.call_tf(self.model.call)
        self._call_atomic_virial = jax2tf.call_tf(self.model.call_atomic_virial)
        # Cache metadata as plain Python values so the getters need no TF calls.
        self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist())
        self.rcut = self.model.get_rcut().numpy().item()
        self.dim_fparam = self.model.get_dim_fparam().numpy().item()
        self.dim_aparam = self.model.get_dim_aparam().numpy().item()
        self.sel_type = self.model.get_sel_type().numpy().tolist()
        self._is_aparam_nall = self.model.is_aparam_nall().numpy().item()
        self._model_output_type = decode_list_of_bytes(
            self.model.model_output_type().numpy().tolist()
        )
        self._mixed_types = self.model.mixed_types().numpy().item()
        # This entry point may be absent from the SavedModel.
        if hasattr(self.model, "get_min_nbor_dist"):
            self.min_nbor_dist = self.model.get_min_nbor_dist().numpy().item()
        else:
            self.min_nbor_dist = None
        self.sel = self.model.get_sel().numpy().tolist()
        self.model_def_script = self.model.get_model_def_script().numpy().decode()
        if hasattr(self.model, "has_default_fparam"):
            # No attrs before v3.1.2
            self._has_default_fparam = self.model.has_default_fparam().numpy().item()
        else:
            self._has_default_fparam = False
        if hasattr(self.model, "get_default_fparam"):
            self.default_fparam = self.model.get_default_fparam().numpy().tolist()
        else:
            self.default_fparam = None

    def __call__(
        self,
        coord: jnp.ndarray,
        atype: jnp.ndarray,
        box: jnp.ndarray | None = None,
        fparam: jnp.ndarray | None = None,
        aparam: jnp.ndarray | None = None,
        do_atomic_virial: bool = False,
    ) -> Any:
        """Return model prediction.

        Parameters
        ----------
        coord
            The coordinates of the atoms.
            shape: nf x (nloc x 3)
        atype
            The type of atoms. shape: nf x nloc
        box
            The simulation box. shape: nf x 9
        fparam
            frame parameter. nf x ndf
        aparam
            atomic parameter. nf x nloc x nda
        do_atomic_virial
            If calculate the atomic virial.

        Returns
        -------
        ret_dict
            The result dict of type dict[str,jnp.ndarray].
            The keys are defined by the `ModelOutputDef`.

        """
        return self.call(coord, atype, box, fparam, aparam, do_atomic_virial)

    def call(
        self,
        coord: jnp.ndarray,
        atype: jnp.ndarray,
        box: jnp.ndarray | None = None,
        fparam: jnp.ndarray | None = None,
        aparam: jnp.ndarray | None = None,
        do_atomic_virial: bool = False,
    ) -> dict[str, jnp.ndarray]:
        """Return model prediction.

        Parameters
        ----------
        coord
            The coordinates of the atoms.
            shape: nf x (nloc x 3)
        atype
            The type of atoms. shape: nf x nloc
        box
            The simulation box. shape: nf x 9
        fparam
            frame parameter. nf x ndf
        aparam
            atomic parameter. nf x nloc x nda
        do_atomic_virial
            If calculate the atomic virial.

        Returns
        -------
        ret_dict
            The result dict of type dict[str,jnp.ndarray].
            The keys are defined by the `ModelOutputDef`.

        """
        if do_atomic_virial:
            call = self._call_atomic_virial
        else:
            call = self._call
        # Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
        # Optional inputs are therefore replaced by empty float64 placeholders.
        if box is None:
            box = jnp.empty((coord.shape[0], 0, 0), dtype=jnp.float64)
        if fparam is None:
            fparam = jnp.empty(
                (coord.shape[0], self.get_dim_fparam()), dtype=jnp.float64
            )
        if aparam is None:
            # NOTE(review): coord.shape[1] is used as nloc here, which assumes
            # coord is shaped (nf, nloc, 3). A flat nf x (nloc x 3) coord, as
            # the docstring describes, would give nloc * 3. Confirm with callers.
            aparam = jnp.empty(
                (coord.shape[0], coord.shape[1], self.get_dim_aparam()),
                dtype=jnp.float64,
            )
        return call(
            coord,
            atype,
            box,
            fparam,
            aparam,
        )

    def model_output_def(self) -> ModelOutputDef:
        """Build the output definition from the model's reported output types.

        Raises
        ------
        KeyError
            If an output type reported by the model has no entry in
            ``OUTPUT_DEFS``.
        """
        return ModelOutputDef(
            FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()])
        )

    def call_lower(
        self,
        extended_coord: jnp.ndarray,
        extended_atype: jnp.ndarray,
        nlist: jnp.ndarray,
        mapping: jnp.ndarray | None = None,
        fparam: jnp.ndarray | None = None,
        aparam: jnp.ndarray | None = None,
        do_atomic_virial: bool = False,
        charge_spin: jnp.ndarray | None = None,
    ) -> dict[str, jnp.ndarray]:
        """Return the lower-level model prediction on an extended system.

        Parameters
        ----------
        extended_coord
            Coordinates of the extended system; ``extended_coord.shape[0]``
            is used as the number of frames.
        extended_atype
            Atom types of the extended system.
        nlist
            Neighbor list; ``nlist.shape[1]`` is used as the number of local
            atoms when building the ``aparam`` placeholder.
        mapping
            Forwarded to the SavedModel as-is, including when it is ``None``.
        fparam
            Frame parameters; an empty float64 placeholder is used if ``None``.
        aparam
            Atomic parameters; an empty float64 placeholder is used if ``None``.
        do_atomic_virial
            If True, use the atomic-virial variant of ``call_lower``.
        charge_spin
            Accepted for interface compatibility but not used: it is not
            forwarded to the SavedModel.

        Returns
        -------
        dict[str, jnp.ndarray]
            Whatever the bridged SavedModel ``call_lower`` returns.
        """
        if do_atomic_virial:
            call_lower = self._call_lower_atomic_virial
        else:
            call_lower = self._call_lower
        # Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
        # NOTE(review): ``mapping`` is not given a placeholder, so passing None
        # may hit the conversion error above. Confirm callers always pass it.
        if fparam is None:
            fparam = jnp.empty(
                (extended_coord.shape[0], self.get_dim_fparam()), dtype=jnp.float64
            )
        if aparam is None:
            # NOTE(review): the placeholder is always nloc-sized, even when
            # is_aparam_nall() reports (nframes, nall, ndim). Verify whether
            # nall-sized aparam is needed here.
            aparam = jnp.empty(
                (extended_coord.shape[0], nlist.shape[1], self.get_dim_aparam()),
                dtype=jnp.float64,
            )
        return call_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam,
            aparam,
        )

    def get_type_map(self) -> list[str]:
        """Get the type map."""
        return self.type_map

    def get_rcut(self) -> float:
        """Get the cut-off radius."""
        return self.rcut

    def get_dim_fparam(self) -> int:
        """Get the number (dimension) of frame parameters of this atomic model."""
        return self.dim_fparam

    def get_dim_aparam(self) -> int:
        """Get the number (dimension) of atomic parameters of this atomic model."""
        return self.dim_aparam

    def get_sel_type(self) -> list[int]:
        """Get the selected atom types of this model.

        Only atoms with selected atom types have atomic contribution
        to the result of the model.
        If returning an empty list, all atom types are selected.
        """
        return self.sel_type

    def is_aparam_nall(self) -> bool:
        """Check whether the shape of atomic parameters is (nframes, nall, ndim).

        If False, the shape is (nframes, nloc, ndim).
        """
        return self._is_aparam_nall

    def model_output_type(self) -> list[str]:
        """Get the output type for the model."""
        return self._model_output_type

    def serialize(self) -> dict:
        """Serialize the model.

        Not supported for a wrapped SavedModel.

        Returns
        -------
        dict
            The serialized data

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Not implemented")

    @classmethod
    def deserialize(cls, data: dict) -> "TFModelWrapper":
        """Deserialize the model.

        Not supported for a wrapped SavedModel.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        TFModelWrapper
            The deserialized model

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Not implemented")

    def get_model_def_script(self) -> str:
        """Get the model definition script."""
        return self.model_def_script

    def get_min_nbor_dist(self) -> float | None:
        """Get the minimum distance between two atoms.

        Returns None if the SavedModel does not provide ``get_min_nbor_dist``.
        """
        return self.min_nbor_dist

    def get_nnei(self) -> int:
        """Returns the total number of selected neighboring atoms in the cut-off radius."""
        return self.get_nsel()

    def get_sel(self) -> list[int]:
        """Get the per-type neighbor selection reported by the model."""
        return self.sel

    def get_nsel(self) -> int:
        """Returns the total number of selected neighboring atoms in the cut-off radius."""
        return sum(self.sel)

    def mixed_types(self) -> bool:
        """Return the ``mixed_types`` flag reported by the model."""
        return self._mixed_types

    @classmethod
    def update_sel(
        cls,
        train_data: DeepmdDataSystem,
        type_map: list[str] | None,
        local_jdata: dict,
    ) -> tuple[dict, float | None]:
        """Update the selection and perform neighbor statistics.

        Not supported for a wrapped SavedModel.

        Parameters
        ----------
        train_data : DeepmdDataSystem
            data used to do neighbor statistics
        type_map : list[str], optional
            The name of each type of atoms
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            The updated local data
        float
            The minimum distance between two atoms

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Not implemented")

    @classmethod
    def get_model(cls, model_params: dict) -> "TFModelWrapper":
        """Get the model by the parameters.

        Not supported for a wrapped SavedModel.

        Parameters
        ----------
        model_params : dict
            The model parameters

        Returns
        -------
        TFModelWrapper
            The model

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Not implemented")

    def has_default_fparam(self) -> bool:
        """Check whether the model has default frame parameters."""
        return self._has_default_fparam

    def get_default_fparam(self) -> list[float] | None:
        """Get the default frame parameters.

        Returns None if the SavedModel does not provide ``get_default_fparam``.
        """
        return self.default_fparam