# Source code for deepmd.model.tensor

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
    Optional,
    Union,
)

from deepmd.env import (
    MODEL_VERSION,
    tf,
)
from deepmd.utils.type_embed import (
    TypeEmbedNet,
)

from .model import (
    StandardModel,
)
from .model_stat import (
    make_stat_input,
    merge_sys_stat,
)


class TensorModel(StandardModel):
    """Model that fits a tensorial quantity.

    Parameters
    ----------
    tensor_name
        Name of the tensor; stored as ``self.model_type`` and used to name
        the output nodes of the graph.
    descriptor
        Descriptor
    fitting_net
        Fitting net
    type_embedding
        Type embedding net
    type_map
        Mapping atom type to the name (str) of the type.
        For example `type_map[1]` gives the name of the type 1.
    data_stat_nbatch
        Number of frames used for data statistic
    data_stat_protect
        Protect parameter for atomic energy regression
    """

    def __init__(
        self,
        tensor_name: str,
        descriptor: dict,
        fitting_net: dict,
        type_embedding: Optional[Union[dict, TypeEmbedNet]] = None,
        type_map: Optional[List[str]] = None,
        data_stat_nbatch: int = 10,
        data_stat_protect: float = 1e-2,
        **kwargs,
    ) -> None:
        """Forward the shared settings to the base model and record the tensor name."""
        super().__init__(
            descriptor=descriptor,
            fitting_net=fitting_net,
            type_embedding=type_embedding,
            type_map=type_map,
            data_stat_nbatch=data_stat_nbatch,
            data_stat_protect=data_stat_protect,
            **kwargs,
        )
        # The tensor name doubles as the model type written into the graph.
        self.model_type = tensor_name

    def get_rcut(self):
        """Return ``self.rcut``."""
        return self.rcut

    def get_ntypes(self):
        """Return ``self.ntypes``."""
        return self.ntypes

    def get_type_map(self):
        """Return ``self.type_map``."""
        return self.type_map

    def get_sel_type(self):
        """Return the selected types as reported by the fitting net."""
        return self.fitting.get_sel_type()

    def get_out_size(self):
        """Return the output size as reported by the fitting net."""
        return self.fitting.get_out_size()

    def data_stat(self, data):
        """Collect statistics from ``data`` and feed them to descriptor and fitting.

        The per-system statistics are used for the output statistics, while
        their merged form is used for the input statistics.
        """
        per_system_stat = make_stat_input(
            data, self.data_stat_nbatch, merge_sys=False
        )
        merged_stat = merge_sys_stat(per_system_stat)
        self._compute_input_stat(merged_stat, protection=self.data_stat_protect)
        self._compute_output_stat(per_system_stat)

    def _compute_input_stat(self, all_stat, protection=1e-2):
        """Pass input statistics to the descriptor and, if supported, the fitting net."""
        descrpt_keys = ("coord", "box", "type", "natoms_vec", "default_mesh")
        self.descrpt.compute_input_stats(
            *(all_stat[key] for key in descrpt_keys),
            all_stat,
        )
        # Not every fitting net implements input statistics.
        if hasattr(self.fitting, "compute_input_stats"):
            self.fitting.compute_input_stats(all_stat, protection=protection)

    def _compute_output_stat(self, all_stat):
        """Pass output statistics to the fitting net when it supports them."""
        if hasattr(self.fitting, "compute_output_stats"):
            self.fitting.compute_output_stats(all_stat)

    def build(
        self,
        coord_,
        atype_,
        natoms,
        box,
        mesh,
        input_dict,
        frz_model=None,
        ckpt_meta: Optional[str] = None,
        suffix="",
        reuse=None,
    ):
        """Build the computational graph of the tensor model.

        Parameters
        ----------
        coord_
            Coordinates; reshaped here to ``[-1, natoms[1] * 3]``.
        atype_
            Atom types; reshaped here to ``[-1, natoms[1]]``.
        natoms
            Atom-count vector. ``natoms[1]`` is used as the per-frame atom
            count and ``natoms[2 + i]`` as the count of type ``i``
            (NOTE(review): layout inferred from the indexing - confirm).
        box
            Forwarded unchanged to ``build_descrpt``.
        mesh
            Forwarded unchanged to ``build_descrpt``.
        input_dict
            Extra inputs; ``None`` is replaced by an empty dict. The keys
            ``"nframes"``, ``"atype"`` and, when a type embedding exists,
            ``"type_embedding"`` are written into it.
        frz_model
            Forwarded to ``build_type_embedding`` and ``build_descrpt``.
        ckpt_meta
            Forwarded to ``build_type_embedding`` and ``build_descrpt``.
        suffix
            Suffix appended to the scope and node names.
        reuse
            Variable reuse flag forwarded to the scopes and sub-builders.

        Returns
        -------
        dict
            Maps ``self.model_type`` to the fitted tensor. For non-global
            models it additionally holds ``"global_" + self.model_type``,
            ``"force"``, ``"virial"`` and ``"atom_virial"``.
        """
        if input_dict is None:
            input_dict = {}

        # The constants are created only for their graph nodes, so the
        # returned tensors are deliberately not bound to names.
        with tf.variable_scope("model_attr" + suffix, reuse=reuse):
            tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string)
            tf.constant(self.get_sel_type(), name="sel_type", dtype=tf.int32)
            tf.constant(self.model_type, name="model_type", dtype=tf.string)
            tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
            tf.constant(self.get_out_size(), name="output_dim", dtype=tf.int32)

        # Number of atoms belonging to the selected types.
        natomsel = sum(natoms[2 + type_i] for type_i in self.get_sel_type())
        nout = self.get_out_size()

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        atype = tf.reshape(atype_, [-1, natoms[1]])
        input_dict["nframes"] = tf.shape(coord)[0]

        # type embedding if any
        if self.typeebd is not None:
            input_dict["type_embedding"] = self.build_type_embedding(
                self.ntypes,
                reuse=reuse,
                suffix=suffix,
                ckpt_meta=ckpt_meta,
                frz_model=frz_model,
            )
        input_dict["atype"] = atype_

        dout = self.build_descrpt(
            coord,
            atype,
            natoms,
            box,
            mesh,
            input_dict,
            frz_model=frz_model,
            ckpt_meta=ckpt_meta,
            suffix=suffix,
            reuse=reuse,
        )

        rot_mat = tf.identity(self.descrpt.get_rot_mat(), name="o_rot_mat" + suffix)

        output = self.fitting.build(
            dout, rot_mat, natoms, input_dict, reuse=reuse, suffix=suffix
        )

        is_global = "global" in self.model_type
        if is_global:
            framesize = nout
        else:
            framesize = natomsel * nout
        output = tf.reshape(
            output, [-1, framesize], name="o_" + self.model_type + suffix
        )

        model_dict = {self.model_type: output}
        if is_global:
            # Global models expose only the fitted tensor itself.
            return model_dict

        gname = "global_" + self.model_type
        atom_out = tf.reshape(output, [-1, natomsel, nout])
        global_out = tf.reshape(
            tf.reduce_sum(atom_out, axis=1), [-1, nout], name="o_" + gname + suffix
        )

        # Differentiate every output component separately.
        force_parts = []
        virial_parts = []
        atom_virial_parts = []
        for component in tf.split(atom_out, nout, axis=-1):
            force_c, virial_c, atom_virial_c = self.descrpt.prod_force_virial(
                component, natoms
            )
            force_parts.append(tf.reshape(force_c, [-1, 3 * natoms[1]]))
            virial_parts.append(tf.reshape(virial_c, [-1, 9]))
            atom_virial_parts.append(
                tf.reshape(atom_virial_c, [-1, 9 * natoms[1]])
            )

        # [nframe x nout x (natom x 3)]
        force = tf.concat(force_parts, axis=1, name="o_force" + suffix)
        # [nframe x nout x 9]
        virial = tf.concat(virial_parts, axis=1, name="o_virial" + suffix)
        # [nframe x nout x (natom x 9)]
        atom_virial = tf.concat(
            atom_virial_parts, axis=1, name="o_atom_virial" + suffix
        )

        model_dict.update(
            {
                gname: global_out,
                "force": force,
                "virial": virial,
                "atom_virial": atom_virial,
            }
        )
        return model_dict

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        model_type: str = "original_model",
        suffix: str = "",
    ) -> None:
        """Initialise the network variables from a frozen model.

        For ``"original_model"`` both the descriptor and the fitting net are
        initialised; for ``"compressed_model"`` only the fitting net is.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        model_type : str
            the type of the model
        suffix : str
            suffix to name scope

        Raises
        ------
        RuntimeError
            If ``model_type`` is neither ``"original_model"`` nor
            ``"compressed_model"``.
        """
        if model_type not in ("original_model", "compressed_model"):
            raise RuntimeError("Unknown model type %s" % model_type)

        if model_type == "original_model":
            self.descrpt.init_variables(graph, graph_def, suffix=suffix)
        self.fitting.init_variables(graph, graph_def, suffix=suffix)
        # Record which kind of frozen model the variables came from.
        tf.constant(model_type, name="model_type", dtype=tf.string)
class WFCModel(TensorModel):
    """Tensor model whose tensor name is fixed to ``"wfc"``."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to ``TensorModel`` with ``"wfc"`` as tensor name."""
        super().__init__("wfc", *args, **kwargs)
class DipoleModel(TensorModel):
    """Tensor model whose tensor name is fixed to ``"dipole"``."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to ``TensorModel`` with ``"dipole"`` as tensor name."""
        super().__init__("dipole", *args, **kwargs)
class PolarModel(TensorModel):
    """Tensor model whose tensor name is fixed to ``"polar"``."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to ``TensorModel`` with ``"polar"`` as tensor name."""
        super().__init__("polar", *args, **kwargs)
class GlobalPolarModel(TensorModel):
    """Tensor model whose tensor name is fixed to ``"global_polar"``."""

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to ``TensorModel`` with ``"global_polar"`` as tensor name."""
        super().__init__("global_polar", *args, **kwargs)