Source code for deepmd.tf.fit.fitting

# SPDX-License-Identifier: LGPL-3.0-or-later
import re
from abc import (
    abstractmethod,
)
from typing import (
    List,
    Optional,
)

from deepmd.common import (
    j_get_type,
)
from deepmd.dpmodel.utils.network import (
    FittingNet,
    NetworkCollection,
)
from deepmd.tf.env import (
    FITTING_NET_PATTERN,
    tf,
)
from deepmd.tf.loss.loss import (
    Loss,
)
from deepmd.tf.utils import (
    PluginVariant,
)
from deepmd.utils.plugin import (
    make_plugin_registry,
)


class Fitting(PluginVariant, make_plugin_registry("fitting")):
    """Base class of the fitting nets, registered under the "fitting" plugin registry.

    Instantiating ``Fitting`` itself dispatches to the concrete subclass
    resolved from the keyword arguments (see ``__new__``). Subclasses must
    implement :meth:`get_loss` and may override :meth:`init_variables`,
    :meth:`serialize` and :meth:`deserialize`.
    """

    def __new__(cls, *args, **kwargs):
        """Create an instance, resolving the concrete class when called on ``Fitting``."""
        if cls is Fitting:
            # The concrete subclass is looked up in the plugin registry from
            # the type that `j_get_type` extracts from the keyword arguments.
            cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
        return super().__new__(cls)

    @property
    def precision(self) -> tf.DType:
        """Precision of fitting network."""
        # NOTE: `fitting_precision` is not assigned in this base class;
        # presumably every concrete fitting sets it during initialization.
        return self.fitting_precision

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """Init the fitting net variables with the given dict.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str
            suffix to name scope

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.

        Notes
        -----
        This method is called by others when the fitting supported initialization from the given variables.
        """
        raise NotImplementedError(
            f"Fitting {type(self).__name__} doesn't support initialization from the given variables!"
        )

    @abstractmethod
    def get_loss(self, loss: dict, lr) -> Loss:
        """Get the loss function.

        Parameters
        ----------
        loss : dict
            the loss dict
        lr : LearningRateExp
            the learning rate

        Returns
        -------
        Loss
            the loss function
        """

    @classmethod
    def deserialize(cls, data: dict, suffix: str = "") -> "Fitting":
        """Deserialize the fitting.

        There is no suffix in a native DP model, but it is important
        for the TF backend.

        Parameters
        ----------
        data : dict
            The serialized data
        suffix : str, optional
            Name suffix to identify this fitting

        Returns
        -------
        Fitting
            The deserialized fitting

        Raises
        ------
        NotImplementedError
            If called on a subclass that does not override this method.
        """
        if cls is Fitting:
            # Called on the base class: forward to the registered subclass
            # resolved from the serialized data.
            return Fitting.get_class_by_type(
                j_get_type(data, cls.__name__)
            ).deserialize(data, suffix=suffix)
        raise NotImplementedError(f"Not implemented in class {cls.__name__}")

    def serialize(self, suffix: str = "") -> dict:
        """Serialize the fitting.

        There is no suffix in a native DP model, but it is important
        for the TF backend.

        Parameters
        ----------
        suffix : str, optional
            Name suffix to identify this fitting

        Returns
        -------
        dict
            The serialized data

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        # Instances have no `__name__` attribute; the class name must be read
        # from the type, otherwise an AttributeError masks the intended error.
        raise NotImplementedError(
            f"Not implemented in class {type(self).__name__}"
        )

    def serialize_network(
        self,
        ntypes: int,
        ndim: int,
        in_dim: int,
        neuron: List[int],
        activation_function: str,
        resnet_dt: bool,
        variables: dict,
        out_dim: Optional[int] = 1,
        suffix: str = "",
    ) -> dict:
        """Serialize network.

        Parameters
        ----------
        ntypes : int
            The number of types
        ndim : int
            The dimension of elements
        in_dim : int
            The input dimension
        neuron : List[int]
            The neuron list
        activation_function : str
            The activation function
        resnet_dt : bool
            Whether to use resnet
        variables : dict
            The input variables
        out_dim : int, optional
            The output dimension
        suffix : str, optional
            The suffix of the scope

        Returns
        -------
        dict
            The converted network data

        Raises
        ------
        ValueError
            If ``ndim`` is neither 0 nor 1, or if a variable name does not
            match the fitting net pattern.
        """
        fittings = NetworkCollection(
            ntypes=ntypes,
            ndim=ndim,
            network_type="fitting_network",
        )
        if suffix != "":
            # Insert the suffix right before each weight-name group so that
            # only variables of this scope are matched.
            fitting_net_pattern = (
                FITTING_NET_PATTERN.replace("/(idt)", suffix + "/(idt)")
                .replace("/(bias)", suffix + "/(bias)")
                .replace("/(matrix)", suffix + "/(matrix)")
            )
        else:
            fitting_net_pattern = FITTING_NET_PATTERN
        # Compile once; the pattern does not change inside the loop.
        compiled_pattern = re.compile(fitting_net_pattern)

        for key, value in variables.items():
            matched = compiled_pattern.search(key)
            if matched is None:
                raise ValueError(
                    f"Variable name {key!r} does not match the fitting net "
                    f"pattern {fitting_net_pattern!r}"
                )
            # Keep only the groups that took part in the match.
            m = [mm for mm in matched.groups() if mm is not None]
            # The first group is the layer index, or "final" for the output
            # layer, which is stored right after the hidden layers.
            layer_idx = int(m[0]) if m[0] != "final" else len(neuron)
            # The last group is the weight name.
            weight_name = m[-1]
            if ndim == 0:
                network_idx = ()
            elif ndim == 1:
                # The second group is used as the network index.
                network_idx = (int(m[1]),)
            else:
                raise ValueError(f"Invalid ndim: {ndim}")
            if fittings[network_idx] is None:
                # initialize the network if it is not initialized
                fittings[network_idx] = FittingNet(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    neuron=neuron,
                    activation_function=activation_function,
                    resnet_dt=resnet_dt,
                    precision=self.precision.name,
                    bias_out=True,
                )
            assert fittings[network_idx] is not None
            if weight_name == "idt":
                # idt is flattened before being stored in the layer.
                value = value.ravel()
            fittings[network_idx][layer_idx][weight_name] = value
        return fittings.serialize()

    @classmethod
    def deserialize_network(cls, data: dict, suffix: str = "") -> dict:
        """Deserialize network.

        Parameters
        ----------
        data : dict
            The input network data
        suffix : str, optional
            The suffix of the scope

        Returns
        -------
        variables : dict
            The input variables

        Raises
        ------
        ValueError
            If the deserialized collection has ``ndim`` other than 0 or 1.
        """
        fitting_net_variables = {}
        fittings = NetworkCollection.deserialize(data)
        for ii in range(fittings.ntypes**fittings.ndim):
            # Decompose the flat index into one index per dimension
            # (base-`ntypes` digits, least significant first).
            net_idx = []
            rest_ii = ii
            for _ in range(fittings.ndim):
                net_idx.append(rest_ii % fittings.ntypes)
                rest_ii //= fittings.ntypes
            net_idx = tuple(net_idx)
            if fittings.ndim == 0:
                key = ""
            elif fittings.ndim == 1:
                key = "_type_" + str(net_idx[0])
            else:
                raise ValueError(f"Invalid ndim: {fittings.ndim}")
            network = fittings[net_idx]
            assert network is not None
            last_layer_idx = len(network.layers) - 1
            for layer_idx, layer in enumerate(network.layers):
                if layer_idx == last_layer_idx:
                    layer_name = "final_layer"
                else:
                    layer_name = f"layer_{layer_idx}"
                prefix = f"{layer_name}{key}{suffix}"
                fitting_net_variables[f"{prefix}/matrix"] = layer.w
                fitting_net_variables[f"{prefix}/bias"] = layer.b
                if layer.idt is not None:
                    fitting_net_variables[f"{prefix}/idt"] = layer.idt.reshape(
                        1, -1
                    )
                else:
                    # prevent keyError
                    fitting_net_variables[f"{prefix}/idt"] = 0.0
        return fitting_net_variables