# Source code for deepmd.tf.loss.loss

# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
    ABCMeta,
    abstractmethod,
)

import numpy as np

from deepmd.tf.env import (
    tf,
)
from deepmd.utils.data import (
    DataRequirementItem,
)


class Loss(metaclass=ABCMeta):
    """The abstract class for the loss function."""

    @abstractmethod
    def build(
        self,
        learning_rate: tf.Tensor,
        natoms: tf.Tensor,
        model_dict: dict[str, tf.Tensor],
        label_dict: dict[str, tf.Tensor],
        suffix: str,
    ) -> tuple[tf.Tensor, dict[str, tf.Tensor]]:
        """Build the loss function graph.

        Parameters
        ----------
        learning_rate : tf.Tensor
            learning rate
        natoms : tf.Tensor
            number of atoms
        model_dict : dict[str, tf.Tensor]
            A dictionary that maps model keys to tensors
        label_dict : dict[str, tf.Tensor]
            A dictionary that maps label keys to tensors
        suffix : str
            suffix

        Returns
        -------
        tf.Tensor
            the total squared loss
        dict[str, tf.Tensor]
            A dictionary that maps loss keys to more loss tensors
        """

    @abstractmethod
    def eval(
        self,
        sess: tf.Session,
        feed_dict: dict[tf.placeholder, tf.Tensor],
        natoms: tf.Tensor,
    ) -> dict:
        """Eval the loss function.

        Parameters
        ----------
        sess : tf.Session
            TensorFlow session
        feed_dict : dict[tf.placeholder, tf.Tensor]
            A dictionary that maps graph elements to values
        natoms : tf.Tensor
            number of atoms

        Returns
        -------
        dict
            A dictionary that maps keys to values. It
            should contain key `natoms`
        """

    @staticmethod
    def display_if_exist(loss: tf.Tensor, find_property: float) -> tf.Tensor:
        """Display NaN if labeled property is not found.

        Parameters
        ----------
        loss : tf.Tensor
            the loss tensor
        find_property : float
            whether the property is found; it is cast to ``tf.bool``, so
            any nonzero value counts as "found"

        Returns
        -------
        tf.Tensor
            ``loss`` itself when the property is found, otherwise NaN cast
            to ``loss.dtype``
        """
        return tf.cond(
            tf.cast(find_property, tf.bool),
            lambda: loss,
            # NaN (instead of e.g. zero) makes a missing label visible in
            # the displayed value rather than looking like a perfect fit.
            lambda: tf.cast(np.nan, dtype=loss.dtype),
        )

    @property
    @abstractmethod
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""

    def serialize(self, suffix: str = "") -> dict:
        """Serialize the loss module.

        Parameters
        ----------
        suffix : str
            The suffix of the loss module

        Returns
        -------
        dict
            The serialized loss module

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses may override it.
        """
        raise NotImplementedError

    @classmethod
    def deserialize(cls, data: dict, suffix: str = "") -> "Loss":
        """Deserialize the loss module.

        Parameters
        ----------
        data : dict
            The serialized loss module
        suffix : str
            The suffix of the loss module

        Returns
        -------
        Loss
            The deserialized loss module

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses may override it.
        """
        raise NotImplementedError

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """No actual effect.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str, optional
            The suffix of the scope
        """