# Source code for deepmd.pt.model.model.dos_model

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Dict,
    Optional,
)

import torch

from deepmd.pt.model.atomic_model import (
    DPDOSAtomicModel,
)
from deepmd.pt.model.model.model import (
    BaseModel,
)

from .dp_model import (
    DPModelCommon,
)
from .make_model import (
    make_model,
)

# Model class produced by the ``make_model`` factory from the DOS atomic model.
# ``DOSModel`` below inherits from it and forwards all constructor arguments to it.
DPDOSModel_ = make_model(DPDOSAtomicModel)
@BaseModel.register("dos")
class DOSModel(DPModelCommon, DPDOSModel_):
    """Density-of-states (DOS) model for the PyTorch backend.

    Combines ``DPModelCommon`` with the class that ``make_model`` builds from
    ``DPDOSAtomicModel``. It is registered with ``BaseModel`` under the key
    ``"dos"``.
    """

    model_type = "dos"

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Both bases are initialised explicitly. DPModelCommon is initialised
        # without arguments, and every constructor argument goes to DPDOSModel_.
        DPModelCommon.__init__(self)
        DPDOSModel_.__init__(self, *args, **kwargs)

    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run the model on local coordinates and return DOS predictions.

        Parameters
        ----------
        coord
            Atomic coordinates (shape defined by ``forward_common``; TODO confirm).
        atype
            Atom types.
        box
            Optional simulation cell.
        fparam
            Optional frame parameters, passed through to ``forward_common``.
        aparam
            Optional atomic parameters, passed through to ``forward_common``.
        do_atomic_virial
            Passed through to ``forward_common``.

        Returns
        -------
        Dict[str, torch.Tensor]
            With a fitting net, the keys are:

            - ``"atom_dos"``: taken from ``model_ret["dos"]``.
            - ``"dos"``: taken from ``model_ret["dos_redu"]``, presumably the
              per-frame reduction; TODO confirm.
            - ``"mask"``: only when ``forward_common`` returned one.

            Without a fitting net, the raw ``forward_common`` output is
            returned, with ``"updated_coord"`` shifted by ``coord``.
        """
        model_ret = self.forward_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            # Rename the generic fitting outputs to DOS-specific keys.
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]
            if "mask" in model_ret:
                model_predict["mask"] = model_ret["mask"]
        else:
            model_predict = model_ret
            # NOTE(review): "updated_coord" appears to be predicted relative to
            # the input, so the input coordinates are added back; confirm.
            # This add is out-of-place so it does not mutate a tensor that
            # autograd may have saved or that another reference may alias.
            model_predict["updated_coord"] = model_predict["updated_coord"] + coord
        return model_predict

    @torch.jit.export
    def get_numb_dos(self) -> int:
        """Get the number of DOS for DOSFittingNet.

        Returns the fitting net's ``dim_out``. A fitting net must be present;
        this method does not handle the case where ``get_fitting_net()``
        returns ``None``.
        """
        return self.get_fitting_net().dim_out

    @torch.jit.export
    def forward_lower(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run the model on extended coordinates and a precomputed neighbor list.

        Parameters
        ----------
        extended_coord
            Coordinates of the extended system (local plus ghost atoms,
            presumably; TODO confirm).
        extended_atype
            Atom types of the extended system.
        nlist
            Neighbor list, passed through to ``forward_common_lower``.
        mapping
            Optional extended-to-local index mapping, passed through to
            ``forward_common_lower``.
        fparam
            Optional frame parameters, passed through to
            ``forward_common_lower``.
        aparam
            Optional atomic parameters, passed through to
            ``forward_common_lower``.
        do_atomic_virial
            Passed through to ``forward_common_lower``.

        Returns
        -------
        Dict[str, torch.Tensor]
            With a fitting net, the keys are ``"atom_dos"`` (from
            ``model_ret["dos"]``) and ``"dos"`` (from ``model_ret["dos_redu"]``).
            Unlike ``forward``, no ``"mask"`` key is forwarded.

            Without a fitting net, the raw ``forward_common_lower`` output is
            returned unchanged.
        """
        model_ret = self.forward_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]
        else:
            model_predict = model_ret
        return model_predict