# Source code for deepmd.pt.model.model.dp_zbl_model

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Dict,
    Optional,
)

import torch

from deepmd.pt.model.atomic_model import (
    DPZBLLinearEnergyAtomicModel,
)
from deepmd.pt.model.model.model import (
    BaseModel,
)

from .dp_model import (
    DPModelCommon,
)
from .make_model import (
    make_model,
)

# Result of applying `make_model` to the DP+ZBL linear energy atomic model;
# it serves as the base class of `DPZBLModel`.
DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel)
@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_):
    """Energy model registered with ``BaseModel`` under the name ``"zbl"``.

    The class derives from ``DPZBLModel_`` and only repackages the
    dictionaries returned by ``forward_common`` / ``forward_common_lower``
    under the key names used by callers of ``forward`` / ``forward_lower``.
    """

    # Model-type tag exposed as a class attribute.
    model_type = "ener"

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Pass every argument unchanged to the base-class constructor."""
        super().__init__(*args, **kwargs)

    def forward(
        self,
        coord,
        atype,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Evaluate the model and return the predictions under public key names.

        Every argument is handed to ``forward_common``; this method only
        re-keys the dictionary that call returns.

        Parameters
        ----------
        coord
            Passed to ``forward_common`` as its first positional argument.
        atype
            Passed to ``forward_common`` as its second positional argument.
        box : torch.Tensor, optional
            Passed to ``forward_common`` as its third positional argument.
        fparam : torch.Tensor, optional
            Passed to ``forward_common`` as ``fparam``.
        aparam : torch.Tensor, optional
            Passed to ``forward_common`` as ``aparam``.
        do_atomic_virial : bool
            Passed to ``forward_common``; when true and
            ``do_grad_c("energy")`` holds, ``atom_virial`` is added to the
            result.

        Returns
        -------
        Dict[str, torch.Tensor]
            Always contains ``atom_energy`` and ``energy``. Depending on
            ``do_grad_r("energy")`` and ``do_grad_c("energy")`` it also
            contains ``force``, ``virial`` and ``atom_virial``. ``mask`` is
            copied over when ``forward_common`` returned one.
        """
        model_ret = self.forward_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )

        model_predict = {}
        model_predict["atom_energy"] = model_ret["energy"]
        model_predict["energy"] = model_ret["energy_redu"]
        if self.do_grad_r("energy"):
            model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
        if self.do_grad_c("energy"):
            model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
            if do_atomic_virial:
                model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
        else:
            # NOTE(review): this fallback is paired with the do_grad_c check,
            # not the do_grad_r check above, so it overwrites any "force"
            # already stored -- confirm the pairing is intended.
            model_predict["force"] = model_ret["dforce"]
        if "mask" in model_ret:
            model_predict["mask"] = model_ret["mask"]
        return model_predict

    @torch.jit.export
    def forward_lower(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ):
        """Evaluate the model through ``forward_common_lower`` and re-key the result.

        Every argument is handed to ``forward_common_lower``; this method
        only renames the entries of the dictionary that call returns.

        Parameters
        ----------
        extended_coord
            Passed to ``forward_common_lower`` as its first positional argument.
        extended_atype
            Passed to ``forward_common_lower`` as its second positional argument.
        nlist
            Passed to ``forward_common_lower`` as its third positional argument.
        mapping : torch.Tensor, optional
            Passed to ``forward_common_lower`` as ``mapping``.
        fparam : torch.Tensor, optional
            Passed to ``forward_common_lower`` as ``fparam``.
        aparam : torch.Tensor, optional
            Passed to ``forward_common_lower`` as ``aparam``.
        do_atomic_virial : bool
            Passed to ``forward_common_lower``; when true and
            ``do_grad_c("energy")`` holds, ``extended_virial`` is added to
            the result.

        Returns
        -------
        dict
            Always contains ``atom_energy`` and ``energy``. Depending on
            ``do_grad_r("energy")`` and ``do_grad_c("energy")`` it also
            contains ``extended_force``, ``virial`` and ``extended_virial``,
            or ``dforce`` when ``do_grad_c("energy")`` is false.

        Raises
        ------
        AssertionError
            If ``do_grad_c("energy")`` is false and the ``dforce`` entry of
            the lower-level result is ``None``.
        """
        model_ret = self.forward_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )

        model_predict = {}
        model_predict["atom_energy"] = model_ret["energy"]
        model_predict["energy"] = model_ret["energy_redu"]
        if self.do_grad_r("energy"):
            model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
        if self.do_grad_c("energy"):
            model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
            if do_atomic_virial:
                model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
                    -3
                )
        else:
            # NOTE(review): this fallback is paired with the do_grad_c check,
            # not the do_grad_r check above -- confirm the pairing is intended.
            # The assert rejects a missing value before it is stored;
            # presumably it also narrows an Optional for TorchScript -- confirm.
            assert model_ret["dforce"] is not None
            model_predict["dforce"] = model_ret["dforce"]
        return model_predict

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        The work is delegated to ``DPModelCommon.update_sel``, applied to the
        ``"dpmodel"`` entry of ``local_jdata``.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            A shallow copy of ``local_jdata`` whose ``"dpmodel"`` entry is
            replaced by the value returned from ``DPModelCommon.update_sel``.
            The ``local_jdata`` mapping itself is not reassigned.
        """
        local_jdata_cpy = local_jdata.copy()
        local_jdata_cpy["dpmodel"] = DPModelCommon.update_sel(
            global_jdata, local_jdata["dpmodel"]
        )
        return local_jdata_cpy