# SPDX-License-Identifier: LGPL-3.0-or-later
"""Module providing compatibility between `0.x.x` and `1.x.x` input versions."""
import json
import warnings
from collections.abc import (
Sequence,
)
from pathlib import (
Path,
)
from typing import (
Any,
)
import numpy as np
from deepmd.common import (
j_deprecated,
)
[docs]
def _model(jdata: dict[str, Any], smooth: bool) -> dict[str, dict[str, Any]]:
    """Build the v1 ``model`` section from flat v0 input data.

    Parameters
    ----------
    jdata : dict[str, Any]
        parsed input json/yaml data
    smooth : bool
        if True the descriptor is produced by `_smth_descriptor`,
        otherwise by `_nonsmth_descriptor`

    Returns
    -------
    dict[str, dict[str, Any]]
        dictionary holding the ``descriptor`` and ``fitting_net``
        sub-dictionaries
    """
    # Only the builder that is actually selected gets looked up and called.
    build_descriptor = _smth_descriptor if smooth else _nonsmth_descriptor
    return {
        "descriptor": build_descriptor(jdata),
        "fitting_net": _fitting_net(jdata),
    }
[docs]
def _nonsmth_descriptor(jdata: dict[str, Any]) -> dict[str, Any]:
    """Build the v1 descriptor section for the non-smooth (``loc_frame``) case.

    Parameters
    ----------
    jdata : dict[str, Any]
        parsed input json/yaml data

    Returns
    -------
    dict[str, Any]
        descriptor parameters; ``type`` is always ``"loc_frame"`` and the
        remaining keys are copied from `jdata` only when present there
    """
    descriptor: dict[str, Any] = {"type": "loc_frame"}
    optional_keys = ("sel_a", "sel_r", "rcut", "axis_rule")
    _jcopy(jdata, descriptor, optional_keys)
    return descriptor
[docs]
def _smth_descriptor(jdata: dict[str, Any]) -> dict[str, Any]:
    """Convert data to v1 input for smooth descriptor.

    Parameters
    ----------
    jdata : dict[str, Any]
        parsed input json/yaml data; ``sel_a``, ``rcut`` and
        ``filter_neuron`` are read unconditionally

    Returns
    -------
    dict[str, Any]
        dict with descriptor parameters

    Raises
    ------
    KeyError
        if ``sel_a``, ``rcut`` or ``filter_neuron`` is missing from `jdata`
    """
    descriptor: dict[str, Any] = {}
    seed = jdata.get("seed", None)
    if seed is not None:
        descriptor["seed"] = seed
    descriptor["type"] = "se_a"
    descriptor["sel"] = jdata["sel_a"]
    _jcopy(jdata, descriptor, ("rcut",))
    # The fallback is evaluated eagerly, so ``rcut`` must have been copied
    # above even when ``rcut_smth`` is given explicitly.
    descriptor["rcut_smth"] = jdata.get("rcut_smth", descriptor["rcut"])
    descriptor["neuron"] = jdata["filter_neuron"]
    # NOTE(review): j_deprecated presumably resolves ``axis_neuron`` with
    # ``n_axis_neuron`` as a deprecated alias -- confirm in deepmd.common.
    descriptor["axis_neuron"] = j_deprecated(jdata, "axis_neuron", ["n_axis_neuron"])
    # The membership test must be on the same key that is read. The previous
    # code tested for "resnet_dt" but read "filter_resnet_dt", so a config
    # with only ``filter_resnet_dt`` was ignored and a config with only
    # ``resnet_dt`` raised KeyError.
    descriptor["resnet_dt"] = jdata.get("filter_resnet_dt", False)
    return descriptor
[docs]
def _fitting_net(jdata: dict[str, Any]) -> dict[str, Any]:
    """Build the v1 ``fitting_net`` section from flat v0 input data.

    Parameters
    ----------
    jdata : dict[str, Any]
        parsed input json/yaml data

    Returns
    -------
    dict[str, Any]
        fitting net parameters: optional ``seed``, ``neuron`` and
        ``resnet_dt``
    """
    fitting_net: dict[str, Any] = {}
    if (seed := jdata.get("seed")) is not None:
        fitting_net["seed"] = seed
    fitting_net["neuron"] = j_deprecated(jdata, "fitting_neuron", ["n_neuron"])

    # Defaults to True; ``resnet_dt`` overrides the default and
    # ``fitting_resnet_dt`` overrides both, so the last key present wins.
    use_resnet_dt = True
    for key in ("resnet_dt", "fitting_resnet_dt"):
        if key in jdata:
            use_resnet_dt = jdata[key]
    fitting_net["resnet_dt"] = use_resnet_dt
    return fitting_net
[docs]
def _learning_rate(jdata: dict[str, Any]) -> dict[str, Any]:
    """Build the v1 ``learning_rate`` section from flat v0 input data.

    Parameters
    ----------
    jdata : dict[str, Any]
        parsed input json/yaml data

    Returns
    -------
    dict[str, Any]
        learning rate parameters; ``type`` is always ``"exp"`` and the
        schedule keys are copied from `jdata` only when present there
    """
    learning_rate: dict[str, Any] = {"type": "exp"}
    schedule_keys = ("decay_steps", "decay_rate", "start_lr")
    _jcopy(jdata, learning_rate, schedule_keys)
    return learning_rate
[docs]
def _loss(jdata: dict[str, Any]) -> dict[str, Any]:
    """Build the v1 ``loss`` section from flat v0 input data.

    Parameters
    ----------
    jdata : dict[str, Any]
        parsed input json/yaml data

    Returns
    -------
    dict[str, Any]
        loss prefactors; each key is copied from `jdata` only when present
        there
    """
    # Every prefactor is optional, so a single conditional copy covers them
    # all. The ``*_pref_ae`` pair stays last to keep the insertion order.
    prefactor_keys = (
        "start_pref_e",
        "limit_pref_e",
        "start_pref_f",
        "limit_pref_f",
        "start_pref_v",
        "limit_pref_v",
        "start_pref_ae",
        "limit_pref_ae",
    )
    loss: dict[str, Any] = {}
    _jcopy(jdata, loss, prefactor_keys)
    return loss
[docs]
def _training(jdata: dict[str, Any]) -> dict[str, Any]:
    """Build the v1 ``training`` section from flat v0 input data.

    Parameters
    ----------
    jdata : dict[str, Any]
        parsed input json/yaml data

    Returns
    -------
    dict[str, Any]
        training parameters

    Raises
    ------
    KeyError
        if any of ``disp_freq``, ``numb_test``, ``save_freq``, ``save_ckpt``,
        ``disp_training`` or ``time_training`` is missing from `jdata`, or if
        ``profiling`` is truthy and ``profiling_file`` is missing
    """
    training: dict[str, Any] = {}
    if (seed := jdata.get("seed")) is not None:
        training["seed"] = seed

    # Optional keys: copied only when present in the input.
    _jcopy(jdata, training, ("systems", "set_prefix", "stop_batch", "batch_size"))
    training["disp_file"] = jdata.get("disp_file", "lcurve.out")

    # Mandatory keys: a missing entry raises KeyError.
    for required_key in (
        "disp_freq",
        "numb_test",
        "save_freq",
        "save_ckpt",
        "disp_training",
        "time_training",
    ):
        training[required_key] = jdata[required_key]

    if "profiling" in jdata:
        profiling = jdata["profiling"]
        training["profiling"] = profiling
        # The output file is only read when profiling is switched on.
        if profiling:
            training["profiling_file"] = jdata["profiling_file"]
    return training
[docs]
def _jcopy(src: dict[str, Any], dst: dict[str, Any], keys: Sequence[str]) -> None:
"""Copy specified keys from one dict to another.
Parameters
----------
src : dict[str, Any]
source dictionary
dst : dict[str, Any]
destination dictionary, will be modified in place
keys : Sequence[str]
list of keys to copy
"""
for k in keys:
if k in src:
dst[k] = src[k]
[docs]
def remove_decay_rate(jdata: dict[str, Any]) -> None:
    """Replace ``learning_rate.decay_rate`` with the equivalent ``stop_lr``.

    Nothing is changed when ``decay_rate`` is absent from the learning rate
    section.

    Parameters
    ----------
    jdata : dict[str, Any]
        input data, modified in place; must contain ``learning_rate`` and,
        when ``decay_rate`` is set, ``training.stop_batch``
    """
    lr_cfg = jdata["learning_rate"]
    if "decay_rate" not in lr_cfg:
        return

    rate = lr_cfg["decay_rate"]
    initial_lr = lr_cfg["start_lr"]
    total_steps = jdata["training"]["stop_batch"]
    interval = lr_cfg["decay_steps"]

    # stop_lr = start_lr * decay_rate ** (stop_batch / decay_steps),
    # evaluated through exp/log.
    n_decays = total_steps / interval
    lr_cfg["stop_lr"] = np.exp(np.log(rate) * n_decays) * initial_lr
    del lr_cfg["decay_rate"]
[docs]
def deprecate_numb_test(
jdata: dict[str, Any], warning: bool = True, dump: str | Path | None = None
) -> dict[str, Any]:
"""Deprecate `numb_test` since v2.1. It has taken no effect since v2.0.
See `#1243 <https://github.com/deepmodeling/deepmd-kit/discussions/1243>`_.
Parameters
----------
jdata : dict[str, Any]
loaded json/yaml file
warning : bool, optional
whether to show deprecation warning, by default True
dump : Optional[Union[str, Path]], optional
whether to dump converted file, by default None
Returns
-------
dict[str, Any]
converted output
"""
try:
jdata.get("training", {}).pop("numb_test")
except KeyError:
pass
else:
if warning:
warnings.warn(
"The argument training->numb_test has been deprecated since v2.0.0. "
"Use training->validation_data->batch_size instead."
)
if dump is not None:
with open(dump, "w") as fp:
json.dump(jdata, fp, indent=4)
return jdata
[docs]
def migrate_training_warmup(
    jdata: dict[str, Any], warning: bool = True
) -> dict[str, Any]:
    """Move legacy warmup settings from ``training`` into ``learning_rate``.

    Parameters
    ----------
    jdata : dict[str, Any]
        Input configuration dictionary, modified in place.
    warning : bool, optional
        Whether to emit a warning when keys are moved or dropped,
        by default True.

    Returns
    -------
    dict[str, Any]
        The same `jdata` object after migration.

    Raises
    ------
    ValueError
        If a warmup key is defined in both ``training`` and
        ``learning_rate``. Nothing is modified in that case.
    """
    training = jdata.get("training")
    if not isinstance(training, dict):
        return jdata

    legacy_keys = [
        key
        for key in ("warmup_steps", "warmup_ratio", "warmup_start_factor")
        if key in training
    ]
    if not legacy_keys:
        return jdata

    lr = jdata.get("learning_rate")
    if not isinstance(lr, dict):
        # There is no usable destination, so the legacy keys are dropped.
        for key in legacy_keys:
            del training[key]
        if warning:
            warnings.warn(
                "Found legacy warmup settings under training, but learning_rate "
                "is missing or invalid. The warmup keys were removed from training."
            )
        return jdata

    # Detect clashes before touching either section so that a failure
    # leaves the configuration unmodified.
    conflict_keys = [key for key in legacy_keys if key in lr]
    if conflict_keys:
        raise ValueError(
            "Conflicting warmup settings found in both 'training' and "
            f"'learning_rate': {', '.join(conflict_keys)}. "
            "Please define warmup settings only in 'learning_rate'."
        )

    lr.update((key, training.pop(key)) for key in legacy_keys)
    if warning:
        warnings.warn(
            "Legacy warmup settings under training were moved to learning_rate: "
            f"{', '.join(legacy_keys)}."
        )
    return jdata
[docs]
def convert_optimizer_v31_to_v32(
    jdata: dict[str, Any], warning: bool = True
) -> dict[str, Any]:
    """Convert optimizer format from v3.1 to v3.2.

    v3.1 format: optimizer parameters (opt_type, kf_blocksize, etc.) in
    training section.
    v3.2 format: separate optimizer section with type field.

    Parameters
    ----------
    jdata : dict[str, Any]
        loaded json/yaml file, modified in place
    warning : bool, optional
        whether to show a deprecation warning when legacy keys are found,
        by default True

    Returns
    -------
    dict[str, Any]
        the same `jdata` object, now always carrying an ``optimizer`` section
    """
    # Default optimizer values (must match argcheck.py defaults)
    adam_defaults = {
        "type": "Adam",
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "weight_decay": 0.0,
    }
    legacy_keys = (
        "opt_type",
        "kf_blocksize",
        "kf_start_pref_e",
        "kf_limit_pref_e",
        "kf_start_pref_f",
        "kf_limit_pref_f",
        "weight_decay",
        "momentum",
        "muon_momentum",
        "adam_beta1",
        "adam_beta2",
        "lr_adjust",
        "lr_adjust_coeff",
        "muon_2d_only",
        "min_2d_dim",
    )

    training_cfg = jdata.get("training", {})
    optimizer_cfg = jdata.get("optimizer", {})

    # Pull every legacy optimizer key out of the training section.
    extracted = {
        key: training_cfg.pop(key) for key in legacy_keys if key in training_cfg
    }
    if extracted:
        # The new format names the selector ``type`` instead of ``opt_type``.
        if "opt_type" in extracted:
            extracted["type"] = extracted.pop("opt_type")
        # Values taken from the training section override any that already
        # exist in the optimizer section.
        merged = dict(optimizer_cfg)
        merged.update(extracted)
        optimizer_cfg = merged
        if warning:
            warnings.warn(
                "Placing optimizer parameters (opt_type, kf_blocksize, etc.) in the training section "
                "is deprecated. Use a separate 'optimizer' section with 'type' field instead.",
                DeprecationWarning,
                stacklevel=2,
            )

    # Fill in whatever is still missing.
    optimizer_cfg.setdefault("type", adam_defaults["type"])
    if optimizer_cfg["type"] in ("Adam", "AdamW"):
        for key, value in adam_defaults.items():
            optimizer_cfg.setdefault(key, value)

    jdata["optimizer"] = optimizer_cfg
    return jdata