Source code for deepmd.tf.model.frozen

# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import tempfile
from enum import (
    Enum,
)
from typing import (
    Optional,
    Union,
)

from deepmd.entrypoints.convert_backend import (
    convert_backend,
)
from deepmd.infer.deep_pot import (
    DeepPot,
)
from deepmd.tf.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    MODEL_VERSION,
    tf,
)
from deepmd.tf.fit.fitting import (
    Fitting,
)
from deepmd.tf.infer import (
    DeepPotential,
)
from deepmd.tf.loss.loss import (
    Loss,
)
from deepmd.tf.utils.graph import (
    get_tensor_by_name_from_graph,
    load_graph_def,
)

from .model import (
    Model,
)


@Model.register("frozen")
class FrozenModel(Model):
    """Load model from a frozen model, which cannot be trained.

    Parameters
    ----------
    model_file : str
        The path to the frozen model
    """

    def __init__(self, model_file: str, **kwargs):
        super().__init__(**kwargs)
        self.model_file = model_file
        if not model_file.endswith(".pb"):
            # try to convert from other formats
            # NOTE: delete=False keeps the converted .pb on disk for the whole
            # session — DeepPotential below (and build/serialize later) re-read
            # self.model_file from disk, so the file must outlive this block.
            with tempfile.NamedTemporaryFile(
                suffix=".pb", dir=os.curdir, delete=False
            ) as f:
                convert_backend(INPUT=model_file, OUTPUT=f.name)
                self.model_file = f.name
        self.model = DeepPotential(self.model_file)
        # only energy models are supported; the model_type string is exposed
        # to the training graph via the "model_attr" scope in build()
        if isinstance(self.model, DeepPot):
            self.model_type = "ener"
        else:
            raise NotImplementedError(
                "This model type has not been implemented. "
                "Contribution is welcome!"
            )

    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        frz_model: Optional[str] = None,
        ckpt_meta: Optional[str] = None,
        suffix: str = "",
        reuse: Optional[Union[bool, Enum]] = None,
    ) -> dict:
        """Build the model.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinates of atoms
        atype_ : tf.Tensor
            The atom types of atoms
        natoms : tf.Tensor
            The number of atoms
        box : tf.Tensor
            The box vectors
        mesh : tf.Tensor
            The mesh vectors
        input_dict : dict
            The input dict; only the optional "fparam"/"aparam" entries are
            consumed here, and forwarded to the frozen graph's feed map
        frz_model : str, optional
            The path to the frozen model (unused here; the path given at
            construction time is re-imported instead)
        ckpt_meta : str, optional
            The path prefix of the checkpoint and meta files (unused here)
        suffix : str, optional
            The suffix of the scope
        reuse : bool or tf.AUTO_REUSE, optional
            Whether to reuse the variables

        Returns
        -------
        dict
            The output dict
        """
        # reset the model to import to the correct graph
        extra_feed_dict = {}
        if input_dict is not None:
            if "fparam" in input_dict:
                extra_feed_dict["fparam"] = input_dict["fparam"]
            if "aparam" in input_dict:
                extra_feed_dict["aparam"] = input_dict["aparam"]
        # map the frozen graph's placeholders onto the tensors of the graph
        # currently being built, so the imported subgraph is wired in-place
        input_map = self.get_feed_dict(
            coord_, atype_, natoms, box, mesh, **extra_feed_dict
        )
        self.model = DeepPotential(
            self.model_file,
            default_tf_graph=True,
            load_prefix="load" + suffix,
            input_map=input_map,
        )
        # re-emit the standard model/descriptor/fitting attribute constants so
        # the rebuilt graph can itself be frozen and reloaded; the variables
        # are bound only to keep the constants alive in this scope
        with tf.variable_scope("model_attr" + suffix, reuse=reuse):
            t_tmap = tf.constant(
                " ".join(self.get_type_map()), name="tmap", dtype=tf.string
            )
            t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string)
            t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
        with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse):
            t_ntypes = tf.constant(self.get_ntypes(), name="ntypes", dtype=tf.int32)
            t_rcut = tf.constant(
                self.get_rcut(), name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION
            )
        with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
            t_dfparam = tf.constant(
                self.model.get_dim_fparam(), name="dfparam", dtype=tf.int32
            )
            t_daparam = tf.constant(
                self.model.get_dim_aparam(), name="daparam", dtype=tf.int32
            )
        if self.model_type == "ener":
            # alias the imported graph's output tensors under the canonical
            # "o_*" names that downstream freezing/inference code looks up
            return {
                # must visit the backend class
                "energy": tf.identity(
                    self.model.deep_eval.output_tensors["energy_redu"],
                    name="o_energy" + suffix,
                ),
                "force": tf.identity(
                    self.model.deep_eval.output_tensors["energy_derv_r"],
                    name="o_force" + suffix,
                ),
                "virial": tf.identity(
                    self.model.deep_eval.output_tensors["energy_derv_c_redu"],
                    name="o_virial" + suffix,
                ),
                "atom_ener": tf.identity(
                    self.model.deep_eval.output_tensors["energy"],
                    name="o_atom_energy" + suffix,
                ),
                "atom_virial": tf.identity(
                    self.model.deep_eval.output_tensors["energy_derv_c"],
                    name="o_atom_virial" + suffix,
                ),
                "coord": coord_,
                "atype": atype_,
            }
        else:
            raise NotImplementedError(
                f"Model type {self.model_type} has not been implemented. "
                "Contribution is welcome!"
            )

    def get_fitting(self) -> Union[Fitting, dict]:
        """Get the fitting(s).

        A frozen model carries no trainable fitting nets, so an empty
        dict is returned.
        """
        return {}

    def get_loss(self, loss: dict, lr) -> Optional[Union[Loss, dict]]:
        """Get the loss function(s)."""
        # loss should be never used for a frozen model
        return

    def get_rcut(self):
        """Get the cut-off radius, delegated to the loaded frozen model."""
        return self.model.get_rcut()

    def get_ntypes(self) -> int:
        """Get the number of atom types, delegated to the loaded frozen model."""
        return self.model.get_ntypes()

    def data_stat(self, data):
        """Data staticis (no-op: a frozen model needs no data statistics)."""
        pass

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        model_type: str = "original_model",
        suffix: str = "",
    ) -> None:
        """Init the embedding net variables with the given frozen model.

        No-op: the frozen model's variables are already baked into its graph.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        model_type : str
            the type of the model
        suffix : str
            suffix to name scope
        """
        pass

    def enable_compression(self, suffix: str = "") -> None:
        """Enable compression.

        No-op: a frozen model cannot be compressed after the fact.

        Parameters
        ----------
        suffix : str
            suffix to name scope
        """
        pass

    def get_type_map(self) -> list:
        """Get the type map."""
        return self.model.get_type_map()

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class
        """
        # we don't know how to compress it, so no neighbor statistics here
        return local_jdata

    def serialize(self, suffix: str = "") -> dict:
        """Serialize by reconstructing the original (unfrozen) model.

        The training script stored inside the frozen graph is parsed back
        into a model definition, whose variables are then initialized from
        the frozen graph and serialized.
        """
        # try to recover the original model
        # the current graph contains a prefix "load",
        # so it cannot be used to recover the original model
        graph, graph_def = load_graph_def(self.model_file)
        t_jdata = get_tensor_by_name_from_graph(graph, "train_attr/training_script")
        jdata = json.loads(t_jdata)
        model = Model(**jdata["model"])
        # important! must be called before serialize
        model.init_variables(graph=graph, graph_def=graph_def)
        return model.serialize()

    @classmethod
    def deserialize(cls, data: dict, suffix: str = ""):
        """Unsupported: serialize() emits the recovered original model,
        which deserializes through its own class, never through FrozenModel."""
        raise RuntimeError("Should not touch here.")