Source code for deepmd.backend.tensorflow

# SPDX-License-Identifier: LGPL-3.0-or-later
from importlib.util import (
    find_spec,
)
from typing import (
    TYPE_CHECKING,
    Callable,
    ClassVar,
    List,
    Type,
)

from deepmd.backend.backend import (
    Backend,
)

if TYPE_CHECKING:
    from argparse import (
        Namespace,
    )

    from deepmd.infer.deep_eval import (
        DeepEvalBackend,
    )
    from deepmd.utils.neighbor_stat import (
        NeighborStat,
    )


@Backend.register("tf")
@Backend.register("tensorflow")
class TensorFlowBackend(Backend):
    """Backend descriptor for TensorFlow, registered as ``tf`` and ``tensorflow``.

    Every hook imports its implementation from ``deepmd.tf`` lazily, inside
    the accessor, rather than at module import time.
    """

    name = "TensorFlow"
    """The formal name of the backend."""

    features: ClassVar[Backend.Feature] = (
        Backend.Feature.ENTRY_POINT
        | Backend.Feature.DEEP_EVAL
        | Backend.Feature.NEIGHBOR_STAT
        | Backend.Feature.IO
    )
    """The features of the backend."""

    suffixes: ClassVar[List[str]] = [".pb"]
    """The suffixes of the backend."""

    def is_available(self) -> bool:
        """Tell whether this backend can be used.

        Returns
        -------
        bool
            True when a ``tensorflow`` module spec can be found and the
            ``enable_tensorflow`` entry of ``GLOBAL_CONFIG`` is not ``"0"``.
        """
        # deepmd.env pulls in numpy, which is expensive to import, so the
        # import is kept local to this method instead of at module level.
        from deepmd.env import (
            GLOBAL_CONFIG,
        )

        if find_spec("tensorflow") is None:
            return False
        return GLOBAL_CONFIG["enable_tensorflow"] != "0"

    @property
    def entry_point_hook(self) -> Callable[["Namespace"], None]:
        """Command-line entry point of the backend.

        Returns
        -------
        Callable[[Namespace], None]
            The ``main`` function of ``deepmd.tf.entrypoints.main``.
        """
        from deepmd.tf.entrypoints.main import main as tf_entry_point

        return tf_entry_point

    @property
    def deep_eval(self) -> Type["DeepEvalBackend"]:
        """Deep Eval implementation of the backend.

        Returns
        -------
        type[DeepEvalBackend]
            The ``DeepEval`` class of ``deepmd.tf.infer.deep_eval``.
        """
        from deepmd.tf.infer.deep_eval import DeepEval as tf_deep_eval_cls

        return tf_deep_eval_cls

    @property
    def neighbor_stat(self) -> Type["NeighborStat"]:
        """Neighbor statistics implementation of the backend.

        Returns
        -------
        type[NeighborStat]
            The ``NeighborStat`` class of ``deepmd.tf.utils.neighbor_stat``.
        """
        from deepmd.tf.utils.neighbor_stat import (
            NeighborStat as tf_neighbor_stat_cls,
        )

        return tf_neighbor_stat_cls

    @property
    def serialize_hook(self) -> Callable[[str], dict]:
        """Hook that converts a model file to a dictionary.

        Returns
        -------
        Callable[[str], dict]
            ``serialize_from_file`` of ``deepmd.tf.utils.serialization``.
        """
        from deepmd.tf.utils.serialization import (
            serialize_from_file as to_dict_hook,
        )

        return to_dict_hook

    @property
    def deserialize_hook(self) -> Callable[[str, dict], None]:
        """Hook that converts a dictionary to a model file.

        Returns
        -------
        Callable[[str, dict], None]
            ``deserialize_to_file`` of ``deepmd.tf.utils.serialization``.
        """
        from deepmd.tf.utils.serialization import (
            deserialize_to_file as to_file_hook,
        )

        return to_file_hook