# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
TYPE_CHECKING,
Optional,
)
from deepmd.infer.deep_tensor import (
DeepTensor,
)
if TYPE_CHECKING:
from pathlib import (
Path,
)
class DeepWFC(DeepTensor):
    """Deep WFC model: evaluates a tensor-type frozen model whose output
    tensor is the Wannier function centers (graph node ``o_wfc:0``).

    Parameters
    ----------
    model_file : Path
        The name of the frozen model file.
    load_prefix : str
        The prefix in the load computational graph
    default_tf_graph : bool
        If uses the default tf graph, otherwise build a new tf graph for evaluation
    input_map : dict, optional
        The input map for tf.import_graph_def. Only work with default tf graph

    Warnings
    --------
    For developers: `DeepTensor` initializer must be called at the end after
    `self.tensors` are modified because it uses the data in `self.tensors` dict.
    Do not change the order!
    """

    def __init__(
        self,
        model_file: "Path",
        load_prefix: str = "load",
        default_tf_graph: bool = False,
        input_map: Optional[dict] = None,
    ) -> None:
        # use this in favor of dict update to move attribute from class to
        # instance namespace
        self.tensors = dict(
            {
                # output tensor
                "t_tensor": "o_wfc:0",
            },
            **self.tensors,
        )
        # Base-class init must run AFTER self.tensors is set up (see class
        # Warnings) because it reads the names registered in self.tensors.
        DeepTensor.__init__(
            self,
            model_file,
            load_prefix=load_prefix,
            default_tf_graph=default_tf_graph,
            input_map=input_map,
        )

    def get_dim_fparam(self) -> int:
        """Unsupported in this model."""
        raise NotImplementedError("This model type does not support this attribute")

    def get_dim_aparam(self) -> int:
        """Unsupported in this model."""
        raise NotImplementedError("This model type does not support this attribute")