# SPDX-License-Identifier: LGPL-3.0-or-later
import types
from typing import (
Any,
)
import torch
from torch.fx.experimental.proxy_tensor import (
make_fx,
)
from deepmd.dpmodel.atomic_model import (
DPPolarAtomicModel,
)
from deepmd.dpmodel.model.dp_model import (
DPModelCommon,
)
from .make_model import (
_pad_nlist_for_export,
make_model,
)
from .model import (
BaseModel,
)
# Model base class produced by ``make_model`` around ``DPPolarAtomicModel``,
# with ``BaseModel`` supplied as an extra base via ``T_Bases``. It is used
# below as a base class of ``PolarModel``.
DPPolarModel_ = make_model(DPPolarAtomicModel, T_Bases=(BaseModel,))
@BaseModel.register("polar")
class PolarModel(DPModelCommon, DPPolarModel_):
    """Polarizability model registered in ``BaseModel`` under the key ``"polar"``.

    Wraps the model produced by ``make_model(DPPolarAtomicModel, ...)`` and
    renames its outputs for callers:

    - ``"polarizability"``      -> ``"polar"``
    - ``"polarizability_redu"`` -> ``"global_polar"``
    - ``"mask"`` is passed through unchanged when present.
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Both bases are initialised explicitly rather than through a single
        # cooperative ``super().__init__`` chain; presumably because the two
        # bases take different constructor arguments -- TODO confirm.
        DPModelCommon.__init__(self)
        DPPolarModel_.__init__(self, *args, **kwargs)

    @staticmethod
    def _translate_model_output(src: dict[str, Any]) -> dict[str, Any]:
        """Rename internal polarizability output keys to the public names.

        Works on any mapping keyed by the internal output names, so it is used
        both for prediction tensors and for output definitions.

        Parameters
        ----------
        src
            Mapping containing ``"polarizability"`` and ``"polarizability_redu"``
            and, optionally, ``"mask"``.

        Returns
        -------
        dict
            ``{"polar": ..., "global_polar": ...}`` plus ``"mask"`` if present.

        Raises
        ------
        KeyError
            If ``src`` lacks one of the two polarizability keys.
        """
        translated = {
            "polar": src["polarizability"],
            "global_polar": src["polarizability_redu"],
        }
        if "mask" in src:
            translated["mask"] = src["mask"]
        return translated

    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        box: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
        charge_spin: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Evaluate the model on local coordinates via ``call_common``.

        Parameters
        ----------
        coord, atype, box
            Coordinates, atom types and (optional) simulation box, forwarded
            positionally to ``call_common``.
        fparam, aparam, charge_spin
            Optional extra inputs, forwarded unchanged.
        do_atomic_virial
            Forwarded unchanged to ``call_common``.

        Returns
        -------
        dict[str, torch.Tensor]
            Keys ``"polar"``, ``"global_polar"`` and, if produced, ``"mask"``.
        """
        model_ret = self.call_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            charge_spin=charge_spin,
            do_atomic_virial=do_atomic_virial,
        )
        return self._translate_model_output(model_ret)

    def forward_lower(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
        charge_spin: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Evaluate the model on extended coordinates via ``call_common_lower``.

        Parameters
        ----------
        extended_coord, extended_atype, nlist, mapping
            Extended-region inputs and neighbor list, forwarded positionally
            to ``call_common_lower``.
        fparam, aparam, charge_spin
            Optional extra inputs, forwarded unchanged.
        do_atomic_virial
            Forwarded unchanged to ``call_common_lower``.

        Returns
        -------
        dict[str, torch.Tensor]
            Keys ``"polar"``, ``"global_polar"`` and, if produced, ``"mask"``.
        """
        model_ret = self.call_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam=fparam,
            aparam=aparam,
            charge_spin=charge_spin,
            do_atomic_virial=do_atomic_virial,
        )
        return self._translate_model_output(model_ret)

    def translated_output_def(self) -> dict[str, Any]:
        """Return the output definitions under the public key names.

        Reads ``self.model_output_def().get_data()`` and renames its entries
        the same way ``forward`` renames prediction tensors.
        """
        out_def_data = self.model_output_def().get_data()
        return self._translate_model_output(out_def_data)

    def forward_lower_exportable(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
        charge_spin: torch.Tensor | None = None,
        **make_fx_kwargs: Any,
    ) -> torch.nn.Module:
        """Trace :meth:`forward_lower` with ``make_fx`` and return the graph.

        The example inputs are used only for tracing. ``do_atomic_virial`` is
        not an input of the traced function; it is captured from this call and
        therefore fixed in the resulting graph.

        Parameters
        ----------
        extended_coord, extended_atype, nlist, mapping, fparam, aparam, charge_spin
            Example inputs passed to the traced function.
        do_atomic_virial
            Baked into the trace as a constant.
        **make_fx_kwargs
            Extra keyword arguments forwarded to ``make_fx``.

        Returns
        -------
        torch.nn.Module
            The module returned by ``make_fx(...)(...)``.
        """
        model = self

        def fn(
            extended_coord: torch.Tensor,
            extended_atype: torch.Tensor,
            nlist: torch.Tensor,
            mapping: torch.Tensor | None,
            fparam: torch.Tensor | None,
            aparam: torch.Tensor | None,
            charge_spin: torch.Tensor | None,
        ) -> dict[str, torch.Tensor]:
            # Make the coordinates a fresh leaf that requires grad inside the
            # traced function.
            extended_coord = extended_coord.detach().requires_grad_(True)
            nlist = _pad_nlist_for_export(nlist)
            return model.forward_lower(
                extended_coord,
                extended_atype,
                nlist,
                mapping,
                fparam=fparam,
                aparam=aparam,
                charge_spin=charge_spin,
                do_atomic_virial=do_atomic_virial,
            )

        # See make_model.py for the rationale of the pad + monkeypatch.
        # Only an *instance-level* override is saved: if none existed, the
        # temporary override is deleted afterwards so the class method is
        # visible again, instead of pinning a bound method (and a
        # self-reference cycle) onto the instance.
        missing = object()
        prev_override = model.__dict__.get("need_sorted_nlist_for_lower", missing)
        model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
        try:
            traced = make_fx(fn, **make_fx_kwargs)(
                extended_coord,
                extended_atype,
                nlist,
                mapping,
                fparam,
                aparam,
                charge_spin,
            )
        finally:
            if prev_override is missing:
                del model.need_sorted_nlist_for_lower
            else:
                model.need_sorted_nlist_for_lower = prev_override
        return traced