Source code for deepmd.common

# SPDX-License-Identifier: LGPL-3.0-or-later
"""Collection of functions and classes used throughout the whole package."""

import warnings
from functools import (
    wraps,
)
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Union,
)

import tensorflow
from tensorflow.python.framework import (
    tensor_util,
)

from deepmd.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    op_module,
    tf,
)
from deepmd_utils.common import (
    add_data_requirement,
    data_requirement,
    expand_sys_str,
    get_np_precision,
    j_loader,
    j_must_have,
    make_default_mesh,
    select_idx_map,
)

if TYPE_CHECKING:
    from deepmd_utils.common import (
        _ACTIVATION,
        _PRECISION,
    )

__all__ = [
    # from deepmd_utils.common
    "data_requirement",
    "add_data_requirement",
    "select_idx_map",
    "make_default_mesh",
    "j_must_have",
    "j_loader",
    "expand_sys_str",
    "get_np_precision",
    # from self
    "PRECISION_DICT",
    "gelu",
    "gelu_tf",
    "ACTIVATION_FN_DICT",
    "get_activation_func",
    "get_precision",
    "safe_cast_tensor",
    "cast_precision",
    "clear_session",
]

# define constants
PRECISION_DICT = {
    "default": GLOBAL_TF_FLOAT_PRECISION,
    "float16": tf.float16,
    "float32": tf.float32,
    "float64": tf.float64,
    "bfloat16": tf.bfloat16,
}


[docs]def gelu(x: tf.Tensor) -> tf.Tensor: """Gaussian Error Linear Unit. This is a smoother version of the RELU, implemented by custom operator. Parameters ---------- x : tf.Tensor float Tensor to perform activation Returns ------- tf.Tensor `x` with the GELU activation applied References ---------- Original paper https://arxiv.org/abs/1606.08415 """ return op_module.gelu_custom(x)
[docs]def gelu_tf(x: tf.Tensor) -> tf.Tensor: """Gaussian Error Linear Unit. This is a smoother version of the RELU, implemented by TF. Parameters ---------- x : tf.Tensor float Tensor to perform activation Returns ------- tf.Tensor `x` with the GELU activation applied References ---------- Original paper https://arxiv.org/abs/1606.08415 """ def gelu_wrapper(x): try: return tensorflow.nn.gelu(x, approximate=True) except AttributeError: warnings.warn( "TensorFlow does not provide an implementation of gelu, please upgrade your TensorFlow version. Fallback to the custom gelu operator." ) return op_module.gelu_custom(x) return (lambda x: gelu_wrapper(x))(x)
ACTIVATION_FN_DICT = { "relu": tf.nn.relu, "relu6": tf.nn.relu6, "softplus": tf.nn.softplus, "sigmoid": tf.sigmoid, "tanh": tf.nn.tanh, "gelu": gelu, "gelu_tf": gelu_tf, "None": None, "none": None, }
[docs]def get_activation_func( activation_fn: Union["_ACTIVATION", None], ) -> Union[Callable[[tf.Tensor], tf.Tensor], None]: """Get activation function callable based on string name. Parameters ---------- activation_fn : _ACTIVATION one of the defined activation functions Returns ------- Callable[[tf.Tensor], tf.Tensor] correspondingg TF callable Raises ------ RuntimeError if unknown activation function is specified """ if activation_fn is None: return None if activation_fn not in ACTIVATION_FN_DICT: raise RuntimeError(f"{activation_fn} is not a valid activation function") return ACTIVATION_FN_DICT[activation_fn]
[docs]def get_precision(precision: "_PRECISION") -> Any: """Convert str to TF DType constant. Parameters ---------- precision : _PRECISION one of the allowed precisions Returns ------- tf.python.framework.dtypes.DType appropriate TF constant Raises ------ RuntimeError if supplied precision string does not have acorresponding TF constant """ if precision not in PRECISION_DICT: raise RuntimeError(f"{precision} is not a valid precision") return PRECISION_DICT[precision]
[docs]def safe_cast_tensor( input: tf.Tensor, from_precision: tf.DType, to_precision: tf.DType ) -> tf.Tensor: """Convert a Tensor from a precision to another precision. If input is not a Tensor or without the specific precision, the method will not cast it. Parameters ---------- input : tf.Tensor input tensor from_precision : tf.DType Tensor data type that is casted from to_precision : tf.DType Tensor data type that casts to Returns ------- tf.Tensor casted Tensor """ if tensor_util.is_tensor(input) and input.dtype == from_precision: return tf.cast(input, to_precision) return input
[docs]def cast_precision(func: Callable) -> Callable: """A decorator that casts and casts back the input and output tensor of a method. The decorator should be used in a classmethod. The decorator will do the following thing: (1) It casts input Tensors from `GLOBAL_TF_FLOAT_PRECISION` to precision defined by property `precision`. (2) It casts output Tensors from `precision` to `GLOBAL_TF_FLOAT_PRECISION`. (3) It checks inputs and outputs and only casts when input or output is a Tensor and its dtype matches `GLOBAL_TF_FLOAT_PRECISION` and `precision`, respectively. If it does not match (e.g. it is an integer), the decorator will do nothing on it. Returns ------- Callable a decorator that casts and casts back the input and output tensor of a method Examples -------- >>> class A: ... @property ... def precision(self): ... return tf.float32 ... ... @cast_precision ... def f(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: ... return x ** 2 + y """ @wraps(func) def wrapper(self, *args, **kwargs): # only convert tensors returned_tensor = func( self, *[ safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for vv in args ], **{ kk: safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for kk, vv in kwargs.items() }, ) if isinstance(returned_tensor, tuple): return tuple( safe_cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION) for vv in returned_tensor ) else: return safe_cast_tensor( returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION ) return wrapper
[docs]def clear_session(): """Reset all state generated by DeePMD-kit.""" tf.reset_default_graph() # TODO: remove this line when data_requirement is not a global variable data_requirement.clear()