Source code for deepmd.dpmodel.model.model

# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
    Any,
)

from deepmd.dpmodel.atomic_model.dp_atomic_model import (
    DPAtomicModel,
)
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
    PairTabAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
    BaseDescriptor,
)
from deepmd.dpmodel.fitting.base_fitting import (
    BaseFitting,
)
from deepmd.dpmodel.fitting.ener_fitting import (
    EnergyFittingNet,
)
from deepmd.dpmodel.model.base_model import (
    BaseModel,
)
from deepmd.dpmodel.model.dipole_model import (
    DipoleModel,
)
from deepmd.dpmodel.model.dos_model import (
    DOSModel,
)
from deepmd.dpmodel.model.dp_zbl_model import (
    DPZBLModel,
)
from deepmd.dpmodel.model.ener_model import (
    EnergyModel,
)
from deepmd.dpmodel.model.polar_model import (
    PolarModel,
)
from deepmd.dpmodel.model.property_model import (
    PropertyModel,
)
from deepmd.dpmodel.model.spin_model import (
    SpinModel,
)
from deepmd.utils.spin import (
    Spin,
)


def _get_standard_model_components(
    data: dict[str, Any], ntypes: int
) -> tuple[BaseDescriptor, BaseFitting, str]:
    """Build the descriptor and the fitting net of a standard model.

    Parameters
    ----------
    data : dict[str, Any]
        The model parameters. ``data["descriptor"]`` and ``data["type_map"]``
        are required; ``data["fitting_net"]`` is optional. The descriptor
        parameters (and the fitting parameters, when present) are filled in
        place, so callers should pass a copy they own.
    ntypes : int
        The number of types, written into the descriptor parameters.

    Returns
    -------
    descriptor : BaseDescriptor
        The descriptor built from ``data["descriptor"]``.
    fitting : BaseFitting
        The fitting net built from the completed fitting parameters.
    fitting_type : str
        The fitting type; ``"ener"`` when the input does not specify one.
    """
    type_map = data["type_map"]

    # descriptor
    descriptor_params = data["descriptor"]
    descriptor_params["ntypes"] = ntypes
    descriptor_params["type_map"] = copy.deepcopy(type_map)
    descriptor = BaseDescriptor(**descriptor_params)

    # fitting
    fitting_params = data.get("fitting_net", {})
    fitting_type = fitting_params.setdefault("type", "ener")
    fitting_params.update(
        ntypes=descriptor.get_ntypes(),
        type_map=copy.deepcopy(type_map),
        mixed_types=descriptor.mixed_types(),
    )
    if fitting_type in ("dipole", "polar"):
        fitting_params["embedding_width"] = descriptor.get_dim_emb()
    fitting_params["dim_descrpt"] = descriptor.get_dim_out()
    if "direct" in fitting_type:
        # "direct" fitting types receive the embedding dimension as out_dim.
        fitting_params["out_dim"] = descriptor.get_dim_emb()
        # NOTE(review): the nesting of this check under the "direct" branch was
        # reconstructed from a source whose indentation was lost -- confirm.
        if "ener" in fitting_type:
            fitting_params["return_energy"] = True
    fitting = BaseFitting(**fitting_params)
    return descriptor, fitting, fitting_type
def get_standard_model(data: dict) -> BaseModel:
    """Get a standard model from a dictionary.

    The concrete model class is selected by the fitting net type:
    ``"dipole"``, ``"polar"``, ``"dos"``, ``"property"``, or
    ``"ener"`` / ``"direct_force_ener"`` (both give an ``EnergyModel``).
    The return annotation is ``BaseModel``, the same annotation used by
    ``get_model``, because the result is not always an ``EnergyModel``.

    Parameters
    ----------
    data : dict
        The data to construct the model. It is deep-copied before use, so the
        caller's dictionary is not modified.

    Returns
    -------
    BaseModel
        The model built from the descriptor and fitting net described by
        ``data``.

    Raises
    ------
    ValueError
        If ``data`` contains a model-level ``"type_embedding"`` key.
    RuntimeError
        If the fitting net type has no corresponding model class.
    """
    if "type_embedding" in data:
        raise ValueError(
            "In the DP backend, type_embedding is not at the model level, but within the descriptor. See type embedding documentation for details."
        )
    data = copy.deepcopy(data)
    ntypes = len(data["type_map"])
    descriptor, fitting, fitting_net_type = _get_standard_model_components(data, ntypes)
    atom_exclude_types = data.get("atom_exclude_types", [])
    pair_exclude_types = data.get("pair_exclude_types", [])

    # Lookup table replacing the if/elif chain: fitting type -> model class.
    model_classes = {
        "dipole": DipoleModel,
        "polar": PolarModel,
        "dos": DOSModel,
        "ener": EnergyModel,
        "direct_force_ener": EnergyModel,
        "property": PropertyModel,
    }
    modelcls = model_classes.get(fitting_net_type)
    if modelcls is None:
        raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")

    return modelcls(
        descriptor=descriptor,
        fitting=fitting,
        type_map=data["type_map"],
        atom_exclude_types=atom_exclude_types,
        pair_exclude_types=pair_exclude_types,
    )
def get_zbl_model(data: dict) -> DPZBLModel:
    """Get a DPZBLModel from a dictionary.

    A ``DPAtomicModel`` (descriptor plus energy fitting net) is combined with
    a ``PairTabAtomicModel`` built from ``data["use_srtab"]``.

    Parameters
    ----------
    data : dict
        The data to construct the model. It is deep-copied before use. The
        keys ``"type_map"``, ``"descriptor"``, ``"fitting_net"`` (with a
        ``"type"`` entry), ``"use_srtab"``, ``"sw_rmin"`` and ``"sw_rmax"``
        are read unconditionally; ``"atom_exclude_types"`` and
        ``"pair_exclude_types"`` default to empty lists.

    Returns
    -------
    DPZBLModel
        The combined model.

    Raises
    ------
    ValueError
        If the fitting net type is not ``"ener"``.
    """
    data = copy.deepcopy(data)
    type_map = data["type_map"]

    descriptor_params = data["descriptor"]
    descriptor_params["ntypes"] = len(type_map)
    descriptor_params["type_map"] = type_map
    descriptor = BaseDescriptor(**descriptor_params)

    fitting_params = data["fitting_net"]
    fitting_type = fitting_params.pop("type")
    fitting_params["type_map"] = type_map
    if fitting_type != "ener":
        raise ValueError(f"Unknown fitting type {fitting_type}")
    fitting = EnergyFittingNet(
        ntypes=descriptor.get_ntypes(),
        dim_descrpt=descriptor.get_dim_out(),
        mixed_types=descriptor.mixed_types(),
        **fitting_params,
    )

    dp_model = DPAtomicModel(descriptor, fitting, type_map=type_map)
    # pairtab
    pt_model = PairTabAtomicModel(
        data["use_srtab"],
        descriptor.get_rcut(),
        descriptor.get_sel(),
        type_map=type_map,
    )
    return DPZBLModel(
        dp_model,
        pt_model,
        data["sw_rmin"],
        data["sw_rmax"],
        type_map=type_map,
        atom_exclude_types=data.get("atom_exclude_types", []),
        pair_exclude_types=data.get("pair_exclude_types", []),
    )
def get_spin_model(data: dict) -> SpinModel:
    """Get a spin model from a dictionary.

    The input is deep-copied, extended with ``"<name>_spin"`` entries in the
    type map and with the exclusion lists provided by ``Spin``, and then
    passed to ``get_standard_model`` to build the backbone model.

    Parameters
    ----------
    data : dict
        The data to construct the model. ``data["spin"]`` must provide
        ``"use_spin"`` and ``"virtual_scale"``.

    Returns
    -------
    SpinModel
        The backbone model wrapped together with the ``Spin`` settings.
    """
    data = copy.deepcopy(data)
    # include virtual spin and placeholder types
    data["type_map"] += [name + "_spin" for name in data["type_map"]]

    spin_params = data["spin"]
    spin = Spin(
        use_spin=spin_params["use_spin"],
        virtual_scale=spin_params["virtual_scale"],
    )

    descriptor_params = data["descriptor"]
    pair_exclude_types = spin.get_pair_exclude_types(
        exclude_types=data.get("pair_exclude_types")
    )
    data["pair_exclude_types"] = pair_exclude_types
    # for descriptor data stat
    descriptor_params["exclude_types"] = pair_exclude_types

    data["atom_exclude_types"] = spin.get_atom_exclude_types(
        exclude_types=data.get("atom_exclude_types")
    )

    # keep a user-provided env_protection; otherwise fall back to 1e-6
    descriptor_params.setdefault("env_protection", 1e-6)
    if descriptor_params["type"] == "se_e2_a":
        # only expand sel for se_e2_a
        descriptor_params["sel"] += descriptor_params["sel"]

    return SpinModel(backbone_model=get_standard_model(data), spin=spin)
def get_model(data: dict) -> BaseModel:
    """Get a model from a dictionary.

    Parameters
    ----------
    data : dict
        The data to construct the model. ``data["type"]`` selects the model
        type and defaults to ``"standard"``.

    Returns
    -------
    BaseModel
        For the ``"standard"`` type: a spin model when ``"spin"`` is in
        ``data``, otherwise a ZBL model when ``"use_srtab"`` is in ``data``,
        otherwise a standard model. For any other type, the result of
        ``get_model`` on the class that ``BaseModel.get_class_by_type``
        returns.
    """
    model_type = data.get("type", "standard")
    if model_type != "standard":
        return BaseModel.get_class_by_type(model_type).get_model(data)

    # "spin" takes precedence over "use_srtab" when both are present.
    if "spin" in data:
        return get_spin_model(data)
    if "use_srtab" in data:
        return get_zbl_model(data)
    return get_standard_model(data)