# Source code for deepmd.pt.infer.inference

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from copy import (
    deepcopy,
)

import torch

from deepmd.pt.model.model import (
    get_model,
)
from deepmd.pt.train.wrapper import (
    ModelWrapper,
)
from deepmd.pt.utils.env import (
    DEVICE,
    JIT,
)

# On PyTorch 2.x, make torch._dynamo importable up front — presumably so that
# dynamo/compile configuration is available later; confirm against downstream use.
if torch.__version__.startswith("2"):
    import torch._dynamo
# Module-level logger, named after this module per the logging convention.
# (Removed stray "[docs]" Sphinx anchor residue that made the line invalid Python.)
log = logging.getLogger(__name__)
class Tester:
    """Load a trained DeePMD PyTorch checkpoint and build an inference-only wrapper."""

    def __init__(
        self,
        model_ckpt,
        head=None,
    ):
        """Construct a DeePMD tester.

        Parameters
        ----------
        model_ckpt : str
            Path to the checkpoint file to load. The file must contain the
            model state dict with an ``"_extra_state"`` entry holding
            ``"model_params"``.
        head : str, optional
            In multitask mode (checkpoint contains ``"model_dict"``), the
            model head to select. Required in that case; the state-dict keys
            of the chosen head are remapped onto the ``"Default"`` head.

        Raises
        ------
        AssertionError
            If the checkpoint is multitask and ``head`` is ``None`` or not
            present in the checkpoint's ``model_dict``.
        """
        # Model.
        # NOTE(security): torch.load uses pickle under the hood — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_ckpt, map_location=DEVICE)
        if "model" in state_dict:
            # Some checkpoints nest the model state under a "model" key.
            state_dict = state_dict["model"]
        model_params = state_dict["_extra_state"]["model_params"]
        self.multi_task = "model_dict" in model_params
        if self.multi_task:
            assert head is not None, "Head must be specified in multitask mode!"
            self.head = head
            assert head in model_params["model_dict"], (
                f"Specified head {head} not found in model {model_ckpt}! "
                f"Available ones are {list(model_params['model_dict'].keys())}."
            )
            model_params = model_params["model_dict"][head]
            # Keep the shared extra state, and copy only the selected head's
            # parameters, renamed onto the single-task "Default" head.
            state_dict_head = {"_extra_state": state_dict["_extra_state"]}
            for item in state_dict:
                if f"model.{head}." in item:
                    state_dict_head[
                        item.replace(f"model.{head}.", "model.Default.")
                    ] = state_dict[item].clone()
            state_dict = state_dict_head

        # Preserve the original params; mark the working copy as resuming so
        # model construction restores rather than re-initializes statistics.
        self.model_params = deepcopy(model_params)
        model_params["resuming"] = True
        self.model = get_model(model_params).to(DEVICE)

        # Model Wrapper
        self.wrapper = ModelWrapper(self.model)  # inference only
        if JIT:
            self.wrapper = torch.jit.script(self.wrapper)
        self.wrapper.load_state_dict(state_dict)