# Source code for deepmd.pt_expt.entrypoints.main

# SPDX-License-Identifier: LGPL-3.0-or-later
"""Training entrypoint for the pt_expt backend."""

import argparse
import json
import logging
import os
from pathlib import (
    Path,
)
from typing import (
    Any,
)

import h5py

from deepmd.dpmodel.utils.lmdb_data import (
    is_lmdb,
)
from deepmd.pt_expt.train import (
    training,
)
from deepmd.pt_expt.utils.lmdb_dataset import (
    LmdbDataSystem,
)
from deepmd.utils.argcheck import (
    normalize,
)
from deepmd.utils.compat import (
    update_deepmd_input,
)
from deepmd.utils.data_system import (
    DeepmdDataSystem,
    get_data,
    process_systems,
)
from deepmd.utils.path import (
    DPPath,
)

# Module-level logger used by every entrypoint defined in this file.
log = logging.getLogger(__name__)
[docs] def _detect_lmdb_path(systems_raw: Any) -> str | None: """Return the LMDB path when ``systems_raw`` is a scalar LMDB string. Returns ``None`` for non-LMDB inputs. Raises ``ValueError`` if ``systems_raw`` is a list containing any LMDB path, so both ``_get_neighbor_stat_data`` and ``_build_data_system`` fail with the same clear message instead of the opaque error from :func:`process_systems` / :class:`DeepmdData`. """ if isinstance(systems_raw, str) and is_lmdb(systems_raw): return systems_raw if isinstance(systems_raw, list) and any( isinstance(s, str) and is_lmdb(s) for s in systems_raw ): raise ValueError( "LMDB datasets must be passed as a scalar 'systems' string " "(e.g. 'systems': '/path/to/data.lmdb'); list-form systems " "with LMDB paths are not supported." ) return None
def _get_neighbor_stat_data(
    dataset_params: dict[str, Any],
    type_map: list[str] | None,
) -> Any:
    """Build the data object handed to ``BaseModel.update_sel``.

    When ``dataset_params["systems"]`` is a scalar LMDB path, the data comes
    from dpmodel's ``make_neighbor_stat_data``. Any other value is loaded
    through the legacy ``get_data`` helper (npy/HDF5 directories).
    """
    lmdb_location = _detect_lmdb_path(dataset_params.get("systems"))

    if lmdb_location is None:
        # Not an LMDB dataset: use the legacy loader.
        return get_data(dataset_params, 0, type_map, None)

    # Imported lazily, only when an LMDB dataset is actually used.
    from deepmd.dpmodel.utils.lmdb_data import (
        make_neighbor_stat_data,
    )

    return make_neighbor_stat_data(lmdb_location, type_map)
def _build_data_system(
    dataset_params: dict[str, Any],
    type_map: list[str],
    seed: int | None = None,
) -> DeepmdDataSystem | LmdbDataSystem:
    """Create the data system described by one dataset config section.

    If ``dataset_params["systems"]`` is a scalar LMDB path, an
    :class:`LmdbDataSystem` is returned. Otherwise the systems are expanded
    with :func:`process_systems` and wrapped in the legacy
    :class:`DeepmdDataSystem`. ``seed`` is only forwarded on the LMDB route.
    """
    raw_systems = dataset_params["systems"]
    lmdb_location = _detect_lmdb_path(raw_systems)

    if lmdb_location is None:
        # Legacy route: expand the system list, then build the data system.
        expanded = process_systems(
            raw_systems,
            patterns=dataset_params.get("rglob_patterns", None),
        )
        return DeepmdDataSystem(
            systems=expanded,
            batch_size=dataset_params["batch_size"],
            test_size=1,
            type_map=type_map,
            trn_all_set=True,
            sys_probs=dataset_params.get("sys_probs", None),
            auto_prob_style=dataset_params.get("auto_prob", "prob_sys_size"),
        )

    # LMDB route: a single path handled by the dedicated adapter.
    return LmdbDataSystem(
        lmdb_path=lmdb_location,
        type_map=type_map,
        batch_size=dataset_params["batch_size"],
        auto_prob_style=dataset_params.get("auto_prob"),
        seed=seed,
    )
def _open_stat_file(stat_file: str | None) -> Any:
    """Make sure the statistics store exists and open it in append mode.

    Returns ``None`` when ``stat_file`` is ``None``. Otherwise a missing path
    is created first: an empty HDF5 file when the name ends with ``.h5`` or
    ``.hdf5``, a directory (with parents) in every other case. The path is
    then wrapped in ``DPPath(stat_file, "a")``.
    """
    if stat_file is None:
        return None
    if not Path(stat_file).exists():
        if stat_file.endswith((".h5", ".hdf5")):
            # Opening in "w" mode and closing right away leaves an empty file.
            with h5py.File(stat_file, "w"):
                pass
        else:
            Path(stat_file).mkdir(parents=True, exist_ok=True)
    return DPPath(stat_file, "a")


def get_trainer(
    config: dict[str, Any],
    init_model: str | None = None,
    restart_model: str | None = None,
    finetune_model: str | None = None,
    finetune_links: dict | None = None,
    shared_links: dict | None = None,
) -> training.Trainer:
    """Build a :class:`training.Trainer` from a normalised config.

    Parameters
    ----------
    config : dict[str, Any]
        Full training configuration. ``config["model"]`` and
        ``config["training"]`` are read here; the whole dict is passed on to
        the trainer.
    init_model, restart_model, finetune_model : str or None
        Forwarded unchanged to :class:`training.Trainer`.
    finetune_links, shared_links : dict or None
        Forwarded unchanged to :class:`training.Trainer`.

    Returns
    -------
    training.Trainer
        The constructed trainer.

    Notes
    -----
    In single-task mode the training data, validation data and stat file are
    single objects. In multi-task mode (``"model_dict"`` present in the model
    config) each of them is a dict keyed by model name, with ``None`` entries
    where a task has no validation data or no stat file.
    """
    model_params = config["model"]
    training_params = config["training"]
    multi_task = "model_dict" in model_params
    data_seed = training_params.get("seed", None)

    if not multi_task:
        type_map = model_params["type_map"]

        # ----- training data ------------------------------------------------
        train_data = _build_data_system(
            training_params["training_data"], type_map, seed=data_seed
        )

        # ----- validation data ----------------------------------------------
        validation_data = None
        validation_dataset_params = training_params.get("validation_data", None)
        if validation_dataset_params is not None:
            validation_data = _build_data_system(
                validation_dataset_params, type_map, seed=data_seed
            )

        # ----- stat file path -----------------------------------------------
        stat_file_path = _open_stat_file(training_params.get("stat_file", None))
    else:
        # Multi-task: build per-task data systems
        train_data = {}
        validation_data = {}
        stat_file_path = {}
        for model_key in model_params["model_dict"]:
            type_map = model_params["model_dict"][model_key]["type_map"]
            data_params = training_params["data_dict"][model_key]

            # training data
            train_data[model_key] = _build_data_system(
                data_params["training_data"], type_map, seed=data_seed
            )

            # validation data
            vd_params = data_params.get("validation_data", None)
            if vd_params is not None:
                validation_data[model_key] = _build_data_system(
                    vd_params, type_map, seed=data_seed
                )
            else:
                validation_data[model_key] = None

            # stat file
            stat_file_path[model_key] = _open_stat_file(
                data_params.get("stat_file", None)
            )

    trainer = training.Trainer(
        config,
        train_data,
        stat_file_path=stat_file_path,
        validation_data=validation_data,
        init_model=init_model,
        restart_model=restart_model,
        finetune_model=finetune_model,
        finetune_links=finetune_links,
        shared_links=shared_links,
    )
    return trainer
def train(
    input_file: str,
    init_model: str | None = None,
    restart: str | None = None,
    finetune: str | None = None,
    model_branch: str = "",
    use_pretrain_script: bool = False,
    skip_neighbor_stat: bool = False,
    output: str = "out.json",
) -> None:
    """Train a model with the pt_expt backend.

    The configuration is loaded, adjusted for multi-task / fine-tuning /
    pretrained-script options, normalised, optionally updated with neighbour
    statistics, written to ``output``, and finally used to build and run a
    trainer.

    Parameters
    ----------
    input_file : str
        Path to the JSON configuration file.
    init_model : str or None
        Checkpoint to initialise weights from; ``.pt`` is appended if missing.
    restart : str or None
        Checkpoint to restart training from; ``.pt`` is appended if missing.
    finetune : str or None
        Pretrained checkpoint to fine-tune from.
    model_branch : str
        Branch to select from a multi-task pretrained model.
    use_pretrain_script : bool
        If True, take descriptor/fitting params from the pretrained model.
    skip_neighbor_stat : bool
        Skip the neighbour statistics calculation.
    output : str
        File the normalised configuration is dumped to.
    """
    import torch

    from deepmd.common import (
        j_loader,
    )
    from deepmd.pt_expt.utils.env import (
        DEVICE,
    )

    log.info("Configuration path: %s", input_file)
    config = j_loader(input_file)

    # Checkpoint arguments may be given without the ".pt" extension.
    if not (init_model is None or init_model.endswith(".pt")):
        init_model = init_model + ".pt"
    if not (restart is None or restart.endswith(".pt")):
        restart = restart + ".pt"

    # Multi-task configs are recognised by a "model_dict" entry; their shared
    # parameters are resolved before anything else touches the model section.
    shared_links = None
    multi_task = "model_dict" in config.get("model", {})
    if multi_task:
        from deepmd.pt_expt.utils.multi_task import (
            preprocess_shared_params,
        )

        config["model"], shared_links = preprocess_shared_params(config["model"])
        assert "RANDOM" not in config["model"]["model_dict"], (
            "Model name can not be 'RANDOM' in multi-task mode!"
        )

    # Fine-tuning may rewrite the model section.
    finetune_links = None
    if finetune is not None:
        from deepmd.pt_expt.utils.finetune import (
            get_finetune_rules,
        )

        config["model"], finetune_links = get_finetune_rules(
            finetune,
            config["model"],
            model_branch=model_branch,
            change_model_params=use_pretrain_script,
        )

    # With --use-pretrain-script the model section comes from the checkpoint.
    if use_pretrain_script and init_model is not None:
        pretrained_state = torch.load(
            init_model, map_location=DEVICE, weights_only=True
        )
        if "model" in pretrained_state:
            pretrained_state = pretrained_state["model"]
        config["model"] = pretrained_state["_extra_state"]["model_params"]

    # Upgrade legacy input and validate/normalise the arguments.
    config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
    config = normalize(config, multi_task=multi_task)

    if not skip_neighbor_stat:
        log.info(
            "Calculate neighbor statistics... "
            "(add --skip-neighbor-stat to skip this step)"
        )
        from deepmd.pt_expt.model import (
            BaseModel,
        )

        if multi_task:
            for model_key in config["model"]["model_dict"]:
                branch_type_map = config["model"]["model_dict"][model_key]["type_map"]
                stat_data = _get_neighbor_stat_data(
                    config["training"]["data_dict"][model_key]["training_data"],
                    branch_type_map,
                )
                config["model"]["model_dict"][model_key], _ = BaseModel.update_sel(
                    stat_data,
                    branch_type_map,
                    config["model"]["model_dict"][model_key],
                )
        else:
            single_type_map = config["model"].get("type_map")
            stat_data = _get_neighbor_stat_data(
                config["training"]["training_data"], single_type_map
            )
            config["model"], _ = BaseModel.update_sel(
                stat_data, single_type_map, config["model"]
            )

    # Persist the configuration that will actually be used.
    with open(output, "w") as out_handle:
        json.dump(config, out_handle, indent=4)

    import torch.distributed as dist

    # A process group is set up only when LOCAL_RANK is present.
    if "LOCAL_RANK" in os.environ:
        dist.init_process_group(backend="cuda:nccl,cpu:gloo")

    try:
        trainer = get_trainer(
            config,
            init_model,
            restart,
            finetune_model=finetune,
            finetune_links=finetune_links,
            shared_links=shared_links,
        )
        trainer.run()
    finally:
        # Tear the process group down even if training raised.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
def freeze(
    model: str,
    output: str = "frozen_model.pte",
    head: str | None = None,
) -> None:
    """Export a pt_expt checkpoint as a frozen model file.

    Parameters
    ----------
    model : str
        Path to the checkpoint file (.pt).
    output : str
        Destination of the exported model.
    head : str or None
        Name of the head to export; mandatory for multi-task checkpoints.

    Raises
    ------
    ValueError
        If the checkpoint has no ``_extra_state.model_params`` entry, or if
        it is multi-task and ``head`` is missing or unknown.
    """
    from copy import (
        deepcopy,
    )

    import torch

    from deepmd.pt_expt.model.get_model import (
        get_model,
    )
    from deepmd.pt_expt.train.wrapper import (
        ModelWrapper,
    )
    from deepmd.pt_expt.utils.env import (
        DEVICE,
    )
    from deepmd.pt_expt.utils.serialization import (
        deserialize_to_file,
    )

    weights = torch.load(model, map_location=DEVICE, weights_only=True)
    # Checkpoints may nest the model state under a "model" key.
    if "model" in weights:
        weights = weights["model"]

    extra_state = weights.get("_extra_state")
    if not (isinstance(extra_state, dict) and "model_params" in extra_state):
        raise ValueError(
            f"Unsupported checkpoint format at '{model}': missing "
            "'_extra_state.model_params' in model state dict."
        )
    model_params = extra_state["model_params"]

    if "model_dict" not in model_params:
        # Single-task checkpoint: one model, one set of parameters.
        m = get_model(model_params)
        wrapper = ModelWrapper(m)
        wrapper.load_state_dict(weights)
        single_model_params = model_params
    else:
        if head is None:
            raise ValueError(
                "Multi-task model requires --head to specify which model to freeze. "
                f"Available heads: {list(model_params['model_dict'].keys())}"
            )
        if head not in model_params["model_dict"]:
            raise ValueError(
                f"Head '{head}' not found. "
                f"Available: {list(model_params['model_dict'].keys())}"
            )
        # Build the full multi-task wrapper so the state dict loads, then
        # keep only the requested head.
        all_heads = {
            key: get_model(deepcopy(model_params["model_dict"][key]))
            for key in model_params["model_dict"]
        }
        wrapper = ModelWrapper(all_heads)
        wrapper.load_state_dict(weights)
        m = wrapper.model[head]
        single_model_params = model_params["model_dict"][head]

    m.eval()
    serialized = m.serialize()
    deserialize_to_file(
        output,
        {"model": serialized, "model_def_script": single_model_params},
    )
    log.info("Saved frozen model to %s", output)
[docs] def change_bias( input_file: str, mode: str = "change", bias_value: list | None = None, datafile: str | None = None, system: str = ".", numb_batch: int = 0, model_branch: str | None = None, output: str | None = None, ) -> None: """Change the output bias of a pt_expt model. Parameters ---------- input_file : str Path to the model file (.pt checkpoint or .pte frozen model). mode : str ``"change"`` or ``"set"``. bias_value : list or None User-defined bias values (one per type). datafile : str or None File listing data system paths. system : str Data system path (used when *datafile* is None). numb_batch : int Number of batches for statistics (0 = all). model_branch : str or None Branch name for multi-task models. output : str or None Output file path. """ import torch from deepmd.common import ( expand_sys_str, ) from deepmd.dpmodel.common import ( to_numpy_array, ) from deepmd.pt_expt.model.get_model import ( get_model, ) from deepmd.pt_expt.train.training import ( get_additional_data_requirement, get_loss, model_change_out_bias, ) from deepmd.pt_expt.train.wrapper import ( ModelWrapper, ) from deepmd.pt_expt.utils.env import ( DEVICE, ) from deepmd.pt_expt.utils.serialization import ( deserialize_to_file, serialize_from_file, ) from deepmd.pt_expt.utils.stat import ( make_stat_input, ) if input_file.endswith(".pt"): old_state_dict = torch.load(input_file, map_location=DEVICE, weights_only=True) if "model" in old_state_dict: model_state_dict = old_state_dict["model"] else: model_state_dict = old_state_dict extra_state = model_state_dict.get("_extra_state") if not isinstance(extra_state, dict) or "model_params" not in extra_state: raise ValueError( f"Unsupported checkpoint format at '{input_file}': missing " "'_extra_state.model_params' in model state dict." 
) model_params = extra_state["model_params"] elif input_file.endswith((".pte", ".pt2")): pte_data = serialize_from_file(input_file) from deepmd.pt_expt.model.model import ( BaseModel, ) model_to_change = BaseModel.deserialize(pte_data["model"]) model_params = pte_data.get("model_def_script") else: raise RuntimeError( "The model provided must be a checkpoint file with a .pt extension " "or a frozen model with a .pte/.pt2 extension" ) if mode == "change": bias_adjust_mode = "change-by-statistic" elif mode == "set": bias_adjust_mode = "set-by-statistic" else: raise ValueError(f"Unsupported mode '{mode}'. Expected 'change' or 'set'.") if input_file.endswith(".pt"): multi_task = "model_dict" in model_params if multi_task: raise NotImplementedError( "Multi-task change-bias is not yet supported for the pt_expt backend." ) type_map = model_params["type_map"] model = get_model(model_params) wrapper = ModelWrapper(model) wrapper.load_state_dict(model_state_dict) model_to_change = model if input_file.endswith((".pte", ".pt2")): type_map = model_to_change.get_type_map() if bias_value is not None: if "energy" not in model_to_change.model_output_type(): raise ValueError("User-defined bias is only available for energy models!") if len(bias_value) != len(type_map): raise ValueError( f"The number of elements in the bias ({len(bias_value)}) must match " f"the number of types in type_map ({len(type_map)}): {type_map}." ) old_bias = model_to_change.get_out_bias() bias_to_set = torch.tensor( bias_value, dtype=old_bias.dtype, device=old_bias.device ).view(old_bias.shape) model_to_change.set_out_bias(bias_to_set) log.info( f"Change output bias of {type_map!s} " f"from {to_numpy_array(old_bias).reshape(-1)!s} " f"to {to_numpy_array(bias_to_set).reshape(-1)!s}." 
) else: if datafile is not None: with open(datafile) as datalist: all_sys = datalist.read().splitlines() else: all_sys = expand_sys_str(system) data_systems = process_systems(all_sys) data = DeepmdDataSystem( systems=data_systems, batch_size=1, test_size=1, rcut=model_to_change.get_rcut(), type_map=type_map, ) mock_loss = get_loss({"inference": True}, 1.0, len(type_map), model_to_change) data.add_data_requirements(mock_loss.label_requirement) data.add_data_requirements(get_additional_data_requirement(model_to_change)) if numb_batch != 0: nbatches = numb_batch else: # Cap at the minimum across systems so no system wraps and # overweights short systems (matching PT behavior). nbatches = min(data.get_nbatches()) sampled_data = make_stat_input(data, nbatches) model_to_change = model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode ) if input_file.endswith(".pt"): output_path = ( output if output is not None else input_file.replace(".pt", "_updated.pt") ) wrapper = ModelWrapper(model_to_change) if "model" in old_state_dict: old_state_dict["model"] = wrapper.state_dict() old_state_dict["model"]["_extra_state"] = extra_state else: old_state_dict = wrapper.state_dict() old_state_dict["_extra_state"] = extra_state torch.save(old_state_dict, output_path) elif input_file.endswith((".pte", ".pt2")): output_path = ( output if output is not None else input_file.replace(".pte", "_updated.pte").replace( ".pt2", "_updated.pt2" ) ) model_dict = model_to_change.serialize() deserialize_to_file( output_path, {"model": model_dict, "model_def_script": model_params} ) log.info(f"Saved model to {output_path}")
def _resolve_freeze_checkpoint(checkpoint_folder: str) -> str:
    """Return the checkpoint file to freeze for a folder or file argument.

    A directory must contain ``model.ckpt.pt``; any other path must exist and
    is used as given. ``FileNotFoundError`` is raised otherwise.
    """
    location = Path(checkpoint_folder)
    if not location.is_dir():
        if not location.exists():
            raise FileNotFoundError(
                f"Checkpoint path '{location}' does not exist."
            )
        return str(location)

    # pt_expt training saves a symlink "model.ckpt.pt" → latest ckpt
    default_ckpt = location / "model.ckpt.pt"
    if not default_ckpt.exists():
        raise FileNotFoundError(
            f"Cannot find checkpoint in '{location}'. "
            "Expected 'model.ckpt.pt' (created by pt_expt training)."
        )
    return str(default_ckpt)


def _as_exported_name(filename: str) -> str:
    """Return ``filename`` with a ``.pte`` suffix unless it ends in .pte/.pt2."""
    if filename.endswith((".pte", ".pt2")):
        return filename
    return str(Path(filename).with_suffix(".pte"))


def main(args: list[str] | argparse.Namespace | None = None) -> None:
    """Run the pt_expt backend command-line interface.

    Parameters
    ----------
    args : list[str] | argparse.Namespace | None
        Raw command-line arguments, or an already parsed namespace which is
        then used as is.

    Raises
    ------
    FileNotFoundError
        For ``freeze`` when no checkpoint can be located.
    RuntimeError
        If the requested command is not supported by this backend.
    """
    from deepmd.loggers.loggers import (
        set_log_handles,
    )
    from deepmd.main import (
        parse_args,
    )

    FLAGS = args if isinstance(args, argparse.Namespace) else parse_args(args=args)

    log_file = Path(FLAGS.log_path) if FLAGS.log_path else None
    set_log_handles(
        FLAGS.log_level,
        log_file,
        mpi_log=None,
    )
    log.info("DeePMD-kit backend: pt_expt (PyTorch Exportable)")

    command = FLAGS.command

    if command == "train":
        train(
            input_file=FLAGS.INPUT,
            init_model=FLAGS.init_model,
            restart=FLAGS.restart,
            finetune=FLAGS.finetune,
            model_branch=FLAGS.model_branch,
            use_pretrain_script=FLAGS.use_pretrain_script,
            skip_neighbor_stat=FLAGS.skip_neighbor_stat,
            output=FLAGS.output,
        )
        return

    if command == "freeze":
        # Resolve the checkpoint first so a missing one fails before the
        # output name is touched.
        FLAGS.model = _resolve_freeze_checkpoint(FLAGS.checkpoint_folder)
        FLAGS.output = _as_exported_name(FLAGS.output)
        freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
        return

    if command == "change-bias":
        change_bias(
            input_file=FLAGS.INPUT,
            mode=FLAGS.mode,
            bias_value=FLAGS.bias_value,
            datafile=FLAGS.datafile,
            system=FLAGS.system,
            numb_batch=FLAGS.numb_batch,
            model_branch=FLAGS.model_branch,
            output=FLAGS.output,
        )
        return

    if command == "compress":
        from deepmd.pt_expt.entrypoints.compress import (
            enable_compression,
        )

        FLAGS.input = _as_exported_name(FLAGS.input)
        FLAGS.output = _as_exported_name(FLAGS.output)
        enable_compression(
            input_file=FLAGS.input,
            output=FLAGS.output,
            stride=FLAGS.step,
            extrapolate=FLAGS.extrapolate,
            check_frequency=FLAGS.frequency,
            training_script=FLAGS.training_script,
        )
        return

    raise RuntimeError(
        f"Unsupported command '{FLAGS.command}' for the pt_expt backend."
    )