Source code for deepmd.infer.deep_wfc

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    TYPE_CHECKING,
    Optional,
)

from deepmd.infer.deep_tensor import (
    DeepTensor,
)

if TYPE_CHECKING:
    from pathlib import (
        Path,
    )


class DeepWFC(DeepTensor):
    """`DeepTensor` subclass that registers ``o_wfc:0`` as its output tensor.

    Parameters
    ----------
    model_file : Path
        The name of the frozen model file.
    load_prefix : str
        The prefix in the load computational graph.
    default_tf_graph : bool
        If True, use the default tf graph; otherwise build a new tf graph
        for evaluation.
    input_map : dict, optional
        The input map for tf.import_graph_def. Only works with the default
        tf graph.

    Warnings
    --------
    For developers: the `DeepTensor` initializer must be called last, after
    `self.tensors` has been modified, because it uses the data in the
    `self.tensors` dict. Do not change the order!
    """

    def __init__(
        self,
        model_file: "Path",
        load_prefix: str = "load",
        default_tf_graph: bool = False,
        input_map: Optional[dict] = None,
    ) -> None:
        # Assign a freshly built dict instead of updating the existing one
        # in place: the assignment moves the attribute from the class
        # namespace to the instance namespace.
        output_tensor = {"t_tensor": "o_wfc:0"}
        self.tensors = dict(output_tensor, **self.tensors)

        # Keep this call last; the parent initializer uses self.tensors.
        DeepTensor.__init__(
            self,
            model_file,
            load_prefix=load_prefix,
            default_tf_graph=default_tf_graph,
            input_map=input_map,
        )

    def get_dim_fparam(self) -> int:
        """Unsupported in this model.

        Raises
        ------
        NotImplementedError
            Always raised.
        """
        message = "This model type does not support this attribute"
        raise NotImplementedError(message)

    def get_dim_aparam(self) -> int:
        """Unsupported in this model.

        Raises
        ------
        NotImplementedError
            Always raised.
        """
        message = "This model type does not support this attribute"
        raise NotImplementedError(message)