Source code for deepmd.pt.model.model

# SPDX-License-Identifier: LGPL-3.0-or-later
"""The model that takes the coordinates, cell and atom types as input
and predicts some property. The models are automatically generated from
atomic models by the `deepmd.dpmodel.make_model` method.

The `make_model` method does the reduction, auto-differentiation and
communication of the atomic properties according to output variable
definition `deepmd.dpmodel.OutputVariableDef`.

All models should be inherited from :class:`deepmd.pt.model.model.model.BaseModel`.
Models generated by `make_model` have already done it.
"""

import copy
import json

import numpy as np

from deepmd.pt.model.atomic_model import (
    DPAtomicModel,
    PairTabAtomicModel,
)
from deepmd.pt.model.descriptor.base_descriptor import (
    BaseDescriptor,
)
from deepmd.pt.model.task import (
    BaseFitting,
)
from deepmd.utils.spin import (
    Spin,
)

from .dipole_model import (
    DipoleModel,
)
from .dos_model import (
    DOSModel,
)
from .dp_model import (
    DPModelCommon,
)
from .dp_zbl_model import (
    DPZBLModel,
)
from .ener_model import (
    EnergyModel,
)
from .frozen import (
    FrozenModel,
)
from .make_hessian_model import (
    make_hessian_model,
)
from .make_model import (
    make_model,
)
from .model import (
    BaseModel,
)
from .polar_model import (
    PolarModel,
)
from .spin_model import (
    SpinEnergyModel,
    SpinModel,
)


def get_spin_model(model_params):
    model_params = copy.deepcopy(model_params)
    if not model_params["spin"]["use_spin"] or isinstance(
        model_params["spin"]["use_spin"][0], int
    ):
        use_spin = np.full(len(model_params["type_map"]), False)
        use_spin[model_params["spin"]["use_spin"]] = True
        model_params["spin"]["use_spin"] = use_spin.tolist()
    # include virtual spin and placeholder types
    model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]]
    spin = Spin(
        use_spin=model_params["spin"]["use_spin"],
        virtual_scale=model_params["spin"]["virtual_scale"],
    )
    pair_exclude_types = spin.get_pair_exclude_types(
        exclude_types=model_params.get("pair_exclude_types", None)
    )
    model_params["pair_exclude_types"] = pair_exclude_types
    # for descriptor data stat
    model_params["descriptor"]["exclude_types"] = pair_exclude_types
    atom_exclude_types = spin.get_atom_exclude_types(
        exclude_types=model_params.get("atom_exclude_types", None)
    )
    model_params["atom_exclude_types"] = atom_exclude_types
    if (
        "env_protection" not in model_params["descriptor"]
        or model_params["descriptor"]["env_protection"] == 0.0
    ):
        model_params["descriptor"]["env_protection"] = 1e-6
    if model_params["descriptor"]["type"] in ["se_e2_a"]:
        # only expand sel for se_e2_a
        model_params["descriptor"]["sel"] += model_params["descriptor"]["sel"]
    backbone_model = get_standard_model(model_params)
    return SpinEnergyModel(backbone_model=backbone_model, spin=spin)


def get_zbl_model(model_params):
    model_params = copy.deepcopy(model_params)
    ntypes = len(model_params["type_map"])
    # descriptor
    model_params["descriptor"]["ntypes"] = ntypes
    descriptor = BaseDescriptor(**model_params["descriptor"])
    # fitting
    fitting_net = model_params.get("fitting_net", None)
    fitting_net["type"] = fitting_net.get("type", "ener")
    fitting_net["ntypes"] = descriptor.get_ntypes()
    fitting_net["mixed_types"] = descriptor.mixed_types()
    fitting_net["embedding_width"] = descriptor.get_dim_out()
    fitting_net["dim_descrpt"] = descriptor.get_dim_out()
    grad_force = "direct" not in fitting_net["type"]
    if not grad_force:
        fitting_net["out_dim"] = descriptor.get_dim_emb()
        if "ener" in fitting_net["type"]:
            fitting_net["return_energy"] = True
    fitting = BaseFitting(**fitting_net)
    dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
    # pairtab
    filepath = model_params["use_srtab"]
    pt_model = PairTabAtomicModel(
        filepath,
        model_params["descriptor"]["rcut"],
        model_params["descriptor"]["sel"],
        type_map=model_params["type_map"],
    )

    rmin = model_params["sw_rmin"]
    rmax = model_params["sw_rmax"]
    atom_exclude_types = model_params.get("atom_exclude_types", [])
    pair_exclude_types = model_params.get("pair_exclude_types", [])
    return DPZBLModel(
        dp_model,
        pt_model,
        rmin,
        rmax,
        type_map=model_params["type_map"],
        atom_exclude_types=atom_exclude_types,
        pair_exclude_types=pair_exclude_types,
    )


def get_standard_model(model_params):
    model_params_old = model_params
    model_params = copy.deepcopy(model_params)
    ntypes = len(model_params["type_map"])
    # descriptor
    model_params["descriptor"]["ntypes"] = ntypes
    descriptor = BaseDescriptor(**model_params["descriptor"])
    # fitting
    fitting_net = model_params.get("fitting_net", None)
    fitting_net["type"] = fitting_net.get("type", "ener")
    fitting_net["ntypes"] = descriptor.get_ntypes()
    fitting_net["mixed_types"] = descriptor.mixed_types()
    if fitting_net["type"] in ["dipole", "polar"]:
        fitting_net["embedding_width"] = descriptor.get_dim_emb()
    fitting_net["dim_descrpt"] = descriptor.get_dim_out()
    grad_force = "direct" not in fitting_net["type"]
    if not grad_force:
        fitting_net["out_dim"] = descriptor.get_dim_emb()
        if "ener" in fitting_net["type"]:
            fitting_net["return_energy"] = True
    fitting = BaseFitting(**fitting_net)
    atom_exclude_types = model_params.get("atom_exclude_types", [])
    pair_exclude_types = model_params.get("pair_exclude_types", [])

    if fitting_net["type"] == "dipole":
        modelcls = DipoleModel
    elif fitting_net["type"] == "polar":
        modelcls = PolarModel
    elif fitting_net["type"] == "dos":
        modelcls = DOSModel
    elif fitting_net["type"] in ["ener", "direct_force_ener"]:
        modelcls = EnergyModel
    else:
        raise RuntimeError(f"Unknown fitting type: {fitting_net['type']}")

    model = modelcls(
        descriptor=descriptor,
        fitting=fitting,
        type_map=model_params["type_map"],
        atom_exclude_types=atom_exclude_types,
        pair_exclude_types=pair_exclude_types,
    )
    model.model_def_script = json.dumps(model_params_old)
    return model


[docs] def get_model(model_params): if "spin" in model_params: return get_spin_model(model_params) elif "use_srtab" in model_params: return get_zbl_model(model_params) else: return get_standard_model(model_params)
__all__ = [ "BaseModel", "get_model", "DPModelCommon", "EnergyModel", "FrozenModel", "SpinModel", "SpinEnergyModel", "DPZBLModel", "make_model", "make_hessian_model", ]