Source code for deepmd.pt_expt.model.polar_model

# SPDX-License-Identifier: LGPL-3.0-or-later
import types
from typing import (
    Any,
)

import torch
from torch.fx.experimental.proxy_tensor import (
    make_fx,
)

from deepmd.dpmodel.atomic_model import (
    DPPolarAtomicModel,
)
from deepmd.dpmodel.model.dp_model import (
    DPModelCommon,
)

from .make_model import (
    _pad_nlist_for_export,
    make_model,
)
from .model import (
    BaseModel,
)

# Generic model class built by ``make_model`` around the polarizability
# atomic model; ``BaseModel`` is passed via ``T_Bases`` (presumably the extra
# base classes of the generated class -- see make_model.py).
DPPolarModel_ = make_model(DPPolarAtomicModel, T_Bases=(BaseModel,))
@BaseModel.register("polar")
class PolarModel(DPModelCommon, DPPolarModel_):
    """Polarizability model for the ``pt_expt`` backend.

    Builds on ``DPPolarModel_`` (``make_model`` applied to
    ``DPPolarAtomicModel``) and exposes its outputs under public names:
    ``polarizability`` -> ``polar``, ``polarizability_redu`` ->
    ``global_polar``, and ``mask`` forwarded unchanged when present.
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize both bases; all arguments go to ``DPPolarModel_``."""
        DPModelCommon.__init__(self)
        DPPolarModel_.__init__(self, *args, **kwargs)

    @staticmethod
    def _translate_polar_output(
        model_ret: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Rename internal output keys to the public polar output keys.

        Parameters
        ----------
        model_ret
            Output of ``call_common`` / ``call_common_lower``; must contain
            ``polarizability`` and ``polarizability_redu``, and may contain
            ``mask``.

        Returns
        -------
        dict[str, torch.Tensor]
            Dict with keys ``polar``, ``global_polar`` and, if present in
            ``model_ret``, ``mask``.
        """
        model_predict = {
            "polar": model_ret["polarizability"],
            "global_polar": model_ret["polarizability_redu"],
        }
        if "mask" in model_ret:
            model_predict["mask"] = model_ret["mask"]
        return model_predict

    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        box: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
        charge_spin: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Evaluate the model on local coordinates.

        All arguments are forwarded unchanged to ``call_common``; see that
        method for the expected tensor layouts.

        Returns
        -------
        dict[str, torch.Tensor]
            ``polar`` (atomic polarizability), ``global_polar`` (reduced
            polarizability) and, when produced, ``mask``.
        """
        model_ret = self.call_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            charge_spin=charge_spin,
            do_atomic_virial=do_atomic_virial,
        )
        return self._translate_polar_output(model_ret)

    def forward_lower(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
        charge_spin: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Evaluate the model on extended coordinates and a neighbor list.

        All arguments are forwarded unchanged to ``call_common_lower``; see
        that method for the expected tensor layouts.

        Returns
        -------
        dict[str, torch.Tensor]
            Same keys as :meth:`forward`.
        """
        model_ret = self.call_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam=fparam,
            aparam=aparam,
            charge_spin=charge_spin,
            do_atomic_virial=do_atomic_virial,
        )
        return self._translate_polar_output(model_ret)

    def translated_output_def(self) -> dict[str, Any]:
        """Return the output definitions keyed by the public output names.

        Mirrors the renaming done in :meth:`forward`: ``polar`` and
        ``global_polar`` come from ``polarizability`` and
        ``polarizability_redu``; ``mask`` is included only if defined.
        """
        out_def_data = self.model_output_def().get_data()
        output_def = {
            "polar": out_def_data["polarizability"],
            "global_polar": out_def_data["polarizability_redu"],
        }
        if "mask" in out_def_data:
            output_def["mask"] = out_def_data["mask"]
        return output_def

    def forward_lower_exportable(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
        charge_spin: torch.Tensor | None = None,
        **make_fx_kwargs: Any,
    ) -> torch.nn.Module:
        """Trace :meth:`forward_lower` with ``make_fx`` for export.

        The given tensors serve as example inputs for tracing.
        ``do_atomic_virial`` is baked into the traced graph (captured by
        closure, not a traced input).

        Parameters
        ----------
        make_fx_kwargs
            Extra keyword arguments passed to ``make_fx``.

        Returns
        -------
        torch.nn.Module
            The module returned by ``make_fx``.
        """
        model = self

        def fn(
            extended_coord: torch.Tensor,
            extended_atype: torch.Tensor,
            nlist: torch.Tensor,
            mapping: torch.Tensor | None,
            fparam: torch.Tensor | None,
            aparam: torch.Tensor | None,
            charge_spin: torch.Tensor | None,
        ) -> dict[str, torch.Tensor]:
            # Re-enable grad on a detached copy so derivative outputs can be
            # traced from a fresh leaf tensor.
            extended_coord = extended_coord.detach().requires_grad_(True)
            nlist = _pad_nlist_for_export(nlist)
            return model.forward_lower(
                extended_coord,
                extended_atype,
                nlist,
                mapping,
                fparam=fparam,
                aparam=aparam,
                charge_spin=charge_spin,
                do_atomic_virial=do_atomic_virial,
            )

        # See make_model.py for the rationale of the pad + monkeypatch.
        # Record whether the instance already shadows the method so the
        # exact prior state can be restored. Re-assigning the saved *bound*
        # method would otherwise leave a permanent instance attribute (and a
        # model -> bound-method -> model reference cycle) that hides the
        # class method from then on.
        had_instance_attr = "need_sorted_nlist_for_lower" in vars(model)
        saved_instance_attr = vars(model).get("need_sorted_nlist_for_lower")
        model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
        try:
            traced = make_fx(fn, **make_fx_kwargs)(
                extended_coord,
                extended_atype,
                nlist,
                mapping,
                fparam,
                aparam,
                charge_spin,
            )
        finally:
            if had_instance_attr:
                model.need_sorted_nlist_for_lower = saved_instance_attr
            else:
                # Drop the temporary override so lookup falls back to the
                # class-level method again.
                del model.need_sorted_nlist_for_lower
        return traced