# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
from deepmd.dpmodel.model.spin_model import (
SpinModel,
)
from deepmd.utils.spin import (
Spin,
)
[docs]
def get_standard_model(data: dict) -> EnergyModel:
"""Get a EnergyModel from a dictionary.
Parameters
----------
data : dict
The data to construct the model.
"""
descriptor_type = data["descriptor"].pop("type")
fitting_type = data["fitting_net"].pop("type")
if descriptor_type == "se_e2_a":
descriptor = DescrptSeA(
**data["descriptor"],
)
else:
raise ValueError(f"Unknown descriptor type {descriptor_type}")
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")
return EnergyModel(
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
atom_exclude_types=data.get("atom_exclude_types", []),
pair_exclude_types=data.get("pair_exclude_types", []),
)
[docs]
def get_spin_model(data: dict) -> SpinModel:
"""Get a spin model from a dictionary.
Parameters
----------
data : dict
The data to construct the model.
"""
# include virtual spin and placeholder types
data["type_map"] += [item + "_spin" for item in data["type_map"]]
spin = Spin(
use_spin=data["spin"]["use_spin"],
virtual_scale=data["spin"]["virtual_scale"],
)
pair_exclude_types = spin.get_pair_exclude_types(
exclude_types=data.get("pair_exclude_types", None)
)
data["pair_exclude_types"] = pair_exclude_types
# for descriptor data stat
data["descriptor"]["exclude_types"] = pair_exclude_types
atom_exclude_types = spin.get_atom_exclude_types(
exclude_types=data.get("atom_exclude_types", None)
)
data["atom_exclude_types"] = atom_exclude_types
if "env_protection" not in data["descriptor"]:
data["descriptor"]["env_protection"] = 1e-6
if data["descriptor"]["type"] in ["se_e2_a"]:
# only expand sel for se_e2_a
data["descriptor"]["sel"] += data["descriptor"]["sel"]
backbone_model = get_standard_model(data)
return SpinModel(backbone_model=backbone_model, spin=spin)
[docs]
def get_model(data: dict):
"""Get a model from a dictionary.
Parameters
----------
data : dict
The data to construct the model.
"""
if "spin" in data:
return get_spin_model(data)
else:
return get_standard_model(data)