# SPDX-License-Identifier: LGPL-3.0-or-later
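"""Entry point of the PyTorch backend of DeepMD-kit.

Implements the ``train`` and ``freeze`` sub-commands dispatched from the
``dp`` command line.
"""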
import argparse
import json
import logging
import os
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
from typing import (
List,
Optional,
Union,
)
import h5py
import torch
import torch.distributed as dist
import torch.version
from torch.distributed.elastic.multiprocessing.errors import (
record,
)
from deepmd import (
__version__,
)
from deepmd.loggers.loggers import (
set_log_handles,
)
from deepmd.main import (
parse_args,
)
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)
from deepmd.pt.infer import (
inference,
)
from deepmd.pt.model.model import (
BaseModel,
)
from deepmd.pt.train import (
training,
)
from deepmd.pt.utils.dataloader import (
DpLoaderSet,
)
from deepmd.pt.utils.env import (
DEVICE,
)
from deepmd.pt.utils.finetune import (
change_finetune_model_params,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.utils.argcheck import (
normalize,
)
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.data_system import (
process_systems,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.summary import (
    SummaryPrinter as BaseSummaryPrinter,
)
log = logging.getLogger(__name__)
def get_trainer(
config,
init_model=None,
restart_model=None,
finetune_model=None,
model_branch="",
force_load=False,
init_frz_model=None,
shared_links=None,
):
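    """Build a :class:`training.Trainer` from a normalized configuration.

    Initializes DDP when ``LOCAL_RANK`` is set, resolves fine-tuning
    parameters, and prepares the training/validation loaders and the
    statistics file paths (per model branch in the multi-task case).
    """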
multi_task = "model_dict" in config.get("model", {})
# Initialize DDP
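    # LOCAL_RANK is injected by the distributed launcher. A hedged launch
    # sketch, assuming torchrun and an installed ``dp`` entry point (the
    # exact invocation depends on your environment):
    #
    #     torchrun --nproc_per_node=2 --no-python dp --pt train input.json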
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
        assert dist.is_nccl_available(), "NCCL backend is required for distributed training"
dist.init_process_group(backend="nccl")
ckpt = init_model if init_model is not None else restart_model
finetune_links = None
if finetune_model is not None:
config["model"], finetune_links = change_finetune_model_params(
finetune_model,
config["model"],
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
def prepare_trainer_input_single(
model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
):
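        """Build the training/validation loaders and the statistics-file path
        for one model branch. Ranks other than 0 drop the stat file path so
        that statistics are written only once.
        """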
training_dataset_params = data_dict_single["training_data"]
type_split = False
if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
type_split = True
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
training_systems = process_systems(training_systems)
if validation_systems is not None:
validation_systems = process_systems(validation_systems)
# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
if rank != 0:
stat_file_path_single = None
elif stat_file_path_single is not None:
if not Path(stat_file_path_single).exists():
if stat_file_path_single.endswith((".h5", ".hdf5")):
                    # Create an empty HDF5 stat file to be filled later.
                    with h5py.File(stat_file_path_single, "w"):
                        pass
else:
Path(stat_file_path_single).mkdir()
stat_file_path_single = DPPath(stat_file_path_single, "a")
# validation and training data
validation_data_single = (
DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single["type_map"],
)
if validation_systems
else None
)
        # Training data loader; resuming (init/restart/finetune) and
        # from-scratch runs load the training data the same way.
        train_data_single = DpLoaderSet(
            training_systems,
            training_dataset_params["batch_size"],
            model_params_single["type_map"],
        )
return (
train_data_single,
validation_data_single,
stat_file_path_single,
)
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
if not multi_task:
(
train_data,
validation_data,
stat_file_path,
) = prepare_trainer_input_single(
config["model"],
config["training"],
config["loss"],
rank=rank,
)
else:
train_data, validation_data, stat_file_path = {}, {}, {}
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
validation_data[model_key],
stat_file_path[model_key],
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
config["training"]["data_dict"][model_key],
config["loss_dict"][model_key],
suffix=f"_{model_key}",
rank=rank,
)
trainer = training.Trainer(
config,
train_data,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)
return trainer
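# A minimal usage sketch for ``get_trainer``, assuming ``input.json`` already
# follows the normalized schema that ``train`` below produces (placeholder
# path):
#
#     with open("input.json") as f:
#         cfg = json.load(f)
#     get_trainer(cfg).run()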
class SummaryPrinter(BaseSummaryPrinter):
"""Summary printer for PyTorch."""
def is_built_with_cuda(self) -> bool:
"""Check if the backend is built with CUDA."""
return torch.version.cuda is not None
def is_built_with_rocm(self) -> bool:
"""Check if the backend is built with ROCm."""
return torch.version.hip is not None
def get_compute_device(self) -> str:
"""Get Compute device."""
return str(DEVICE)
def get_ngpus(self) -> int:
"""Get the number of GPUs."""
return torch.cuda.device_count()
def get_backend_info(self) -> dict:
"""Get backend information."""
return {
"Backend": "PyTorch",
"PT ver": f"v{torch.__version__}-g{torch.version.git_version[:11]}",
"Enable custom OP": ENABLE_CUSTOMIZED_OP,
}
def train(FLAGS):
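    """Load and normalize the training configuration, optionally update the
    neighbor-list selection via neighbor statistics, dump the final config to
    ``FLAGS.output``, then build a trainer and run it.
    """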
log.info("Configuration path: %s", FLAGS.INPUT)
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
config = json.load(fin)
# update multitask config
multi_task = "model_dict" in config["model"]
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])
# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)
# do neighbor stat
if not FLAGS.skip_neighbor_stat:
        log.info(
            "Calculating neighbor statistics... (add --skip-neighbor-stat to skip this step)"
        )
if not multi_task:
config["model"] = BaseModel.update_sel(config, config["model"])
else:
training_jdata = deepcopy(config["training"])
training_jdata.pop("data_dict", {})
training_jdata.pop("model_prob", {})
for model_item in config["model"]["model_dict"]:
fake_global_jdata = {
"model": deepcopy(config["model"]["model_dict"][model_item]),
"training": deepcopy(config["training"]["data_dict"][model_item]),
}
fake_global_jdata["training"].update(training_jdata)
config["model"]["model_dict"][model_item] = BaseModel.update_sel(
fake_global_jdata, config["model"]["model_dict"][model_item]
)
with open(FLAGS.output, "w") as fp:
json.dump(config, fp, indent=4)
trainer = get_trainer(
config,
FLAGS.init_model,
FLAGS.restart,
FLAGS.finetune,
FLAGS.model_branch,
FLAGS.force_load,
FLAGS.init_frz_model,
shared_links=shared_links,
)
trainer.run()
def freeze(FLAGS):
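    """Compile the trained model with TorchScript and save it to
    ``FLAGS.output``, embedding the model type as an extra file.
    """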
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.model_def_script:
extra_files = {"type": "dpa2"}
else:
extra_files = {"type": "else"}
torch.jit.save(
model,
FLAGS.output,
extra_files,
)
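# A hedged CLI sketch (flag spellings assumed from the argparse dests used in
# ``main`` below; paths are placeholders):
#
#     dp --pt freeze -c ./checkpoint_dir -o frozen_model.pth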
@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
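    """Dispatch parsed command-line arguments to ``train`` or ``freeze``.

    Accepts either a raw argument list or a pre-parsed
    ``argparse.Namespace``.
    """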
if not isinstance(args, argparse.Namespace):
FLAGS = parse_args(args=args)
else:
FLAGS = args
set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None)
log.debug("Log handles were successfully set")
log.info("DeepMD version: %s", __version__)
if FLAGS.command == "train":
train(FLAGS)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
else:
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
freeze(FLAGS)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")
if __name__ == "__main__":
main()