Source code for deepmd.jax.entrypoints.main
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeePMD-Kit entry point module."""
import argparse
from pathlib import (
Path,
)
from deepmd.jax.entrypoints.freeze import (
freeze,
)
from deepmd.jax.entrypoints.train import (
train,
)
from deepmd.loggers.loggers import (
set_log_handles,
)
from deepmd.main import (
parse_args,
)
# Public API of this module: only the CLI entry point ``main`` is exported.
__all__ = ["main"]
[docs]
def main(args: list[str] | argparse.Namespace | None = None) -> None:
    """DeePMD-Kit entry point for the JAX backend.

    Parses (or accepts pre-parsed) command line arguments, configures
    logging, then dispatches to the requested sub-command.

    Parameters
    ----------
    args : list[str] or argparse.Namespace, optional
        List of command line arguments, used to avoid invoking the CLI in a
        subprocess (re-importing the backend there is slow). If an
        ``argparse.Namespace`` is given, it is used directly; otherwise the
        value (including ``None``) is handed to ``deepmd.main.parse_args``.

    Raises
    ------
    RuntimeError
        If the command is not one handled by this backend (``train`` or
        ``freeze``). A missing command (``args.command is None``) is not an
        error; the function simply returns without doing anything.
    """
    if not isinstance(args, argparse.Namespace):
        args = parse_args(args=args)

    # Every parsed option is forwarded to the sub-command as a keyword
    # argument, so the sub-commands must accept the full option set.
    dict_args = vars(args)

    set_log_handles(
        args.log_level,
        Path(args.log_path) if args.log_path else None,
        mpi_log=None,
    )

    if args.command == "train":
        train(**dict_args)
    elif args.command == "freeze":
        freeze(**dict_args)
    elif args.command is None:
        # No sub-command supplied: nothing to run (deliberately not an error).
        pass
    else:
        raise RuntimeError(f"unknown command {args.command}")