# Source code for deepmd.tf.descriptor.se

# SPDX-License-Identifier: LGPL-3.0-or-later
import re
from typing import (
    List,
    Set,
    Tuple,
)

from deepmd.dpmodel.utils.network import (
    EmbeddingNet,
    NetworkCollection,
)
from deepmd.tf.env import (
    EMBEDDING_NET_PATTERN,
    tf,
)
from deepmd.tf.utils.graph import (
    get_embedding_net_variables_from_graph_def,
    get_tensor_by_name_from_graph,
)
from deepmd.tf.utils.update_sel import (
    UpdateSel,
)

from .descriptor import (
    Descriptor,
)


class DescrptSe(Descriptor):
    """A base class for smooth version of descriptors.

    Notes
    -----
    All of these descriptors have an environmental matrix and an
    embedding network (:meth:`deepmd.tf.utils.network.embedding_net`), so
    they can share some similar methods without defining them twice.

    Attributes
    ----------
    embedding_net_variables : dict
        initial embedding network variables
    descrpt_reshape : tf.Tensor
        the reshaped descriptor
    descrpt_deriv : tf.Tensor
        the descriptor derivative
    rij : tf.Tensor
        distances between two atoms
    nlist : tf.Tensor
        the neighbor list
    """
[docs] def _identity_tensors(self, suffix: str = "") -> None: """Identify tensors which are expected to be stored and restored. Notes ----- These tensors will be indentitied: self.descrpt_reshape : o_rmat self.descrpt_deriv : o_rmat_deriv self.rij : o_rij self.nlist : o_nlist Thus, this method should be called during building the descriptor and after these tensors are initialized. Parameters ---------- suffix : str The suffix of the scope """ self.descrpt_reshape = tf.identity(self.descrpt_reshape, name="o_rmat" + suffix) self.descrpt_deriv = tf.identity( self.descrpt_deriv, name="o_rmat_deriv" + suffix ) self.rij = tf.identity(self.rij, name="o_rij" + suffix) self.nlist = tf.identity(self.nlist, name="o_nlist" + suffix)
[docs] def get_tensor_names(self, suffix: str = "") -> Tuple[str]: """Get names of tensors. Parameters ---------- suffix : str The suffix of the scope Returns ------- Tuple[str] Names of tensors """ return ( f"o_rmat{suffix}:0", f"o_rmat_deriv{suffix}:0", f"o_rij{suffix}:0", f"o_nlist{suffix}:0", )
[docs] def pass_tensors_from_frz_model( self, descrpt_reshape: tf.Tensor, descrpt_deriv: tf.Tensor, rij: tf.Tensor, nlist: tf.Tensor, ): """Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def. Parameters ---------- descrpt_reshape The passed descrpt_reshape tensor descrpt_deriv The passed descrpt_deriv tensor rij The passed rij tensor nlist The passed nlist tensor """ self.rij = rij self.nlist = nlist self.descrpt_deriv = descrpt_deriv self.descrpt_reshape = descrpt_reshape
[docs] def init_variables( self, graph: tf.Graph, graph_def: tf.GraphDef, suffix: str = "", ) -> None: """Init the embedding net variables with the given dict. Parameters ---------- graph : tf.Graph The input frozen model graph graph_def : tf.GraphDef The input frozen model graph_def suffix : str, optional The suffix of the scope """ self.embedding_net_variables = get_embedding_net_variables_from_graph_def( graph_def, suffix=suffix ) self.davg = get_tensor_by_name_from_graph(graph, f"descrpt_attr{suffix}/t_avg") self.dstd = get_tensor_by_name_from_graph(graph, f"descrpt_attr{suffix}/t_std")
    @property
    def precision(self) -> tf.DType:
        """Precision of filter network.

        Returns
        -------
        tf.DType
            The precision used by the filter (embedding) networks, as
            stored in ``self.filter_precision``.
        """
        return self.filter_precision
@classmethod
[docs] def update_sel(cls, global_jdata: dict, local_jdata: dict): """Update the selection and perform neighbor statistics. Parameters ---------- global_jdata : dict The global data, containing the training section local_jdata : dict The local data refer to the current class """ # default behavior is to update sel which is a list local_jdata_cpy = local_jdata.copy() return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
[docs] def serialize_network( self, ntypes: int, ndim: int, in_dim: int, neuron: List[int], activation_function: str, resnet_dt: bool, variables: dict, excluded_types: Set[Tuple[int, int]] = set(), suffix: str = "", ) -> dict: """Serialize network. Parameters ---------- ntypes : int The number of types ndim : int The dimension of elements in_dim : int The input dimension neuron : List[int] The neuron list activation_function : str The activation function resnet_dt : bool Whether to use resnet variables : dict The input variables excluded_types : Set[Tuple[int, int]], optional The excluded types suffix : str, optional The suffix of the scope Returns ------- dict The converted network data """ embeddings = NetworkCollection( ntypes=ntypes, ndim=ndim, network_type="embedding_network", ) if ndim == 2: for type_i, type_j in excluded_types: # initialize an empty network for the excluded types embeddings[(type_i, type_j)] = EmbeddingNet( in_dim=in_dim, neuron=neuron, activation_function=activation_function, resnet_dt=resnet_dt, precision=self.precision.name, ) embeddings[(type_j, type_i)] = EmbeddingNet( in_dim=in_dim, neuron=neuron, activation_function=activation_function, resnet_dt=resnet_dt, precision=self.precision.name, ) embeddings[(type_i, type_j)].clear() embeddings[(type_j, type_i)].clear() if suffix != "": embedding_net_pattern = ( EMBEDDING_NET_PATTERN.replace("/(idt)", suffix + "/(idt)") .replace("/(bias)", suffix + "/(bias)") .replace("/(matrix)", suffix + "/(matrix)") ) else: embedding_net_pattern = EMBEDDING_NET_PATTERN for key, value in variables.items(): m = re.search(embedding_net_pattern, key) m = [mm for mm in m.groups() if mm is not None] typei = m[0] typej = "_".join(m[3:]) if len(m[3:]) else "all" layer_idx = int(m[2]) - 1 weight_name = m[1] if ndim == 0: network_idx = () elif ndim == 1: network_idx = (int(typej),) elif ndim == 2: network_idx = (int(typei), int(typej)) else: raise ValueError(f"Invalid ndim: {ndim}") if embeddings[network_idx] is 
None: # initialize the network if it is not initialized embeddings[network_idx] = EmbeddingNet( in_dim=in_dim, neuron=neuron, activation_function=activation_function, resnet_dt=resnet_dt, precision=self.precision.name, ) assert embeddings[network_idx] is not None if weight_name == "idt": value = value.ravel() embeddings[network_idx][layer_idx][weight_name] = value return embeddings.serialize()
@classmethod
[docs] def deserialize_network(cls, data: dict, suffix: str = "") -> dict: """Deserialize network. Parameters ---------- data : dict The input network data suffix : str, optional The suffix of the scope Returns ------- variables : dict The input variables """ embedding_net_variables = {} embeddings = NetworkCollection.deserialize(data) for ii in range(embeddings.ntypes**embeddings.ndim): net_idx = [] rest_ii = ii for _ in range(embeddings.ndim): net_idx.append(rest_ii % embeddings.ntypes) rest_ii //= embeddings.ntypes net_idx = tuple(net_idx) if embeddings.ndim == 0: key0 = "all" key1 = "" elif embeddings.ndim == 1: key0 = "all" key1 = f"_{ii}" elif embeddings.ndim == 2: key0 = f"{net_idx[0]}" key1 = f"_{net_idx[1]}" else: raise ValueError(f"Invalid ndim: {embeddings.ndim}") network = embeddings[net_idx] assert network is not None for layer_idx, layer in enumerate(network.layers): embedding_net_variables[ f"filter_type_{key0}{suffix}/matrix_{layer_idx + 1}{key1}" ] = layer.w embedding_net_variables[ f"filter_type_{key0}{suffix}/bias_{layer_idx + 1}{key1}" ] = layer.b if layer.idt is not None: embedding_net_variables[ f"filter_type_{key0}{suffix}/idt_{layer_idx + 1}{key1}" ] = layer.idt.reshape(1, -1) else: # prevent keyError embedding_net_variables[ f"filter_type_{key0}{suffix}/idt_{layer_idx + 1}{key1}" ] = 0.0 return embedding_net_variables