Source code for deepmd.common

# SPDX-License-Identifier: LGPL-3.0-or-later
import glob
import json
import os
import platform
import shutil
import warnings
from hashlib import (
    sha1,
)
from pathlib import (
    Path,
)
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Set,
    TypeVar,
    Union,
    get_args,
)

try:
    from typing import Literal  # python >=3.8
except ImportError:
    from typing_extensions import Literal  # type: ignore

import numpy as np
import yaml

from deepmd.env import (
    GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.path import (
    DPPath,
)

__all__ = [
    "data_requirement",
    "add_data_requirement",
    "select_idx_map",
    "make_default_mesh",
    "j_must_have",
    "j_loader",
    "expand_sys_str",
    "get_np_precision",
    "VALID_PRECISION",
    "VALID_ACTIVATION",
]

_PRECISION = Literal["default", "float16", "float32", "float64"]
_ACTIVATION = Literal[
    "relu",
    "relu6",
    "softplus",
    "sigmoid",
    "tanh",
    "gelu",
    "gelu_tf",
    "none",
    "linear",
]
# get_args is new in py38
VALID_PRECISION: Set[_PRECISION] = set(get_args(_PRECISION))
VALID_ACTIVATION: Set[_ACTIVATION] = set(get_args(_ACTIVATION))
if TYPE_CHECKING:
    _DICT_VAL = TypeVar("_DICT_VAL")
    __all__.extend(
        [
            "_DICT_VAL",
            "_PRECISION",
            "_ACTIVATION",
        ]
    )


# TODO: refactor data_requirement to make it not a global variable
# this is not a good way to do things. This is some global variable to which
# anyone can write and there is no good way to keep track of the changes
data_requirement = {}

def add_data_requirement(
    key: str,
    ndof: int,
    atomic: bool = False,
    must: bool = False,
    high_prec: bool = False,
    type_sel: Optional[bool] = None,
    repeat: int = 1,
    default: float = 0.0,
    dtype: Optional[np.dtype] = None,
    output_natoms_for_type_sel: bool = False,
):
    """Specify data requirements for training.

    Parameters
    ----------
    key : str
        type of data stored in the corresponding `*.npy` file, e.g. `forces` or `energy`
    ndof : int
        number of degrees of freedom; this is tied to the `atomic` parameter,
        e.g. forces have `atomic=True` and `ndof=3`
    atomic : bool, optional
        specifies whether the `ndof` keyword applies to a per-atom quantity,
        by default False
    must : bool, optional
        specifies if the `*.npy` data file must exist, by default False
    high_prec : bool, optional
        if True load data as `np.float64`, else `np.float32`, by default False
    type_sel : bool, optional
        select only certain types of atoms, by default None
    repeat : int, optional
        if specified, repeat the data `repeat` times, by default 1
    default : float, optional, default=0.
        default value of data
    dtype : np.dtype, optional
        the dtype of data, overwrites `high_prec` if provided
    output_natoms_for_type_sel : bool, optional
        if True and type_sel is True, the atomic dimension will be natoms
        instead of nsel
    """
    data_requirement[key] = {
        "ndof": ndof,
        "atomic": atomic,
        "must": must,
        "high_prec": high_prec,
        "type_sel": type_sel,
        "repeat": repeat,
        "default": default,
        "dtype": dtype,
        "output_natoms_for_type_sel": output_natoms_for_type_sel,
    }
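
# A minimal usage sketch (illustrative only; the "energy" and "force" keys are
# chosen here for demonstration and simply populate the global dict above):
#
#     add_data_requirement("energy", 1, atomic=False, must=False, high_prec=True)
#     add_data_requirement("force", 3, atomic=True, must=False, high_prec=False)
#     # data_requirement["force"]["ndof"] == 3
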

def select_idx_map(atom_types: np.ndarray, select_types: np.ndarray) -> np.ndarray:
    """Build a map of indices for the supplied element types from the full atom list.

    Parameters
    ----------
    atom_types : np.ndarray
        array specifying the type of each atom as an integer
    select_types : np.ndarray
        types of atoms you want to find indices for

    Returns
    -------
    np.ndarray
        indices of the atoms whose types are listed in `select_types`, taken
        from the `atom_types` array

    Warnings
    --------
    The `select_types` array will be sorted before finding indices in `atom_types`.
    """
    sort_select_types = np.sort(select_types)
    idx_map = []
    for ii in sort_select_types:
        idx_map.append(np.where(atom_types == ii)[0])
    return np.concatenate(idx_map)
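
# A small illustrative example (values chosen arbitrarily):
#
#     atom_types = np.array([0, 1, 1, 0, 2])
#     select_idx_map(atom_types, np.array([1]))     # -> array([1, 2])
#     select_idx_map(atom_types, np.array([2, 0]))  # -> array([0, 3, 4]); types are sorted first
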

def make_default_mesh(pbc: bool, mixed_type: bool) -> np.ndarray:
    """Make mesh.

    Only the size of the mesh matters, not the values:
    * 6 for PBC, no mixed types
    * 0 for no PBC, no mixed types
    * 7 for PBC, mixed types
    * 1 for no PBC, mixed types

    Parameters
    ----------
    pbc : bool
        if True, the mesh will be made for periodic boundary conditions
    mixed_type : bool
        if True, the mesh will be made for mixed types

    Returns
    -------
    np.ndarray
        mesh
    """
    mesh_size = int(pbc) * 6 + int(mixed_type)
    default_mesh = np.zeros(mesh_size, dtype=np.int32)
    return default_mesh
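
# Illustrative shapes only (the array contents are zeros by construction):
#
#     make_default_mesh(True, False).shape   # (6,)  PBC, no mixed types
#     make_default_mesh(False, True).shape   # (1,)  no PBC, mixed types
#     make_default_mesh(True, True).shape    # (7,)  PBC, mixed types
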
# TODO: rename j_must_have to j_deprecated and only warn about deprecated keys;
# if the deprecated_key argument is left empty, the function's only purpose is a
# custom error message, since dict[key] already raises KeyError when the key is
# missing
def j_must_have(
    jdata: Dict[str, "_DICT_VAL"], key: str, deprecated_key: List[str] = []
) -> "_DICT_VAL":
    """Assert that the supplied dictionary contains the specified key.

    Returns
    -------
    _DICT_VAL
        value that was stored under the supplied key

    Raises
    ------
    RuntimeError
        if the key is not present
    """
    if key not in jdata.keys():
        for ii in deprecated_key:
            if ii in jdata.keys():
                warnings.warn(f"the key {ii} is deprecated, please use {key} instead")
                return jdata[ii]
        else:
            raise RuntimeError(f"json database must provide key {key}")
    else:
        return jdata[key]
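
# Hedged usage sketch (the key names are hypothetical and only illustrate the
# deprecation fallback):
#
#     jdata = {"stop_batch": 1000}
#     j_must_have(jdata, "numb_steps", ["stop_batch"])
#     # warns that "stop_batch" is deprecated and returns 1000
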

def j_loader(filename: Union[str, Path]) -> Dict[str, Any]:
    """Load yaml or json settings file.

    Parameters
    ----------
    filename : Union[str, Path]
        path to file

    Returns
    -------
    Dict[str, Any]
        loaded dictionary

    Raises
    ------
    TypeError
        if the supplied file is of unsupported type
    """
    filepath = Path(filename)
    if filepath.suffix.endswith("json"):
        with filepath.open() as fp:
            return json.load(fp)
    elif filepath.suffix.endswith(("yml", "yaml")):
        with filepath.open() as fp:
            return yaml.safe_load(fp)
    else:
        raise TypeError("config file must be json, or yaml/yml")
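
# Illustrative call (the file name is hypothetical; both JSON and YAML inputs
# yield a plain dictionary):
#
#     params = j_loader("input.json")   # or "input.yaml"
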
# TODO port expand_sys_str completely to pathlib when all callers are ported
def expand_sys_str(root_dir: Union[str, Path]) -> List[str]:
    """Recursively iterate over directories, taking those that contain a `type.raw` file.

    Parameters
    ----------
    root_dir : Union[str, Path]
        starting directory

    Returns
    -------
    List[str]
        list of strings pointing to system directories
    """
    root_dir = DPPath(root_dir)
    matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()]
    if (root_dir / "type.raw").is_file():
        matches.append(str(root_dir))
    return matches
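
# Illustrative usage (the directory layout is hypothetical):
#
#     systems = expand_sys_str("data")
#     # e.g. ["data/sys_A", "data/sys_B"] if each directory contains a type.raw file
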

def get_np_precision(precision: "_PRECISION") -> np.dtype:
    """Get numpy precision constant from string.

    Parameters
    ----------
    precision : _PRECISION
        string name of a numpy constant, or "default"

    Returns
    -------
    np.dtype
        numpy precision constant

    Raises
    ------
    RuntimeError
        if the string is invalid
    """
    if precision == "default":
        return GLOBAL_NP_FLOAT_PRECISION
    elif precision == "float16":
        return np.float16
    elif precision == "float32":
        return np.float32
    elif precision == "float64":
        return np.float64
    else:
        raise RuntimeError(f"{precision} is not a valid precision")
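
# Illustrative calls:
#
#     get_np_precision("float32")   # -> np.float32
#     get_np_precision("default")   # -> GLOBAL_NP_FLOAT_PRECISION
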

def symlink_prefix_files(old_prefix: str, new_prefix: str):
    """Create symlinks from old checkpoint prefix to new one.

    On Windows this function will copy files instead of creating symlinks.

    Parameters
    ----------
    old_prefix : str
        old checkpoint prefix, all files with this prefix will be symlinked
    new_prefix : str
        new checkpoint prefix
    """
    original_files = glob.glob(old_prefix + ".*")
    for ori_ff in original_files:
        new_ff = new_prefix + ori_ff[len(old_prefix) :]
        try:
            # remove old one
            os.remove(new_ff)
        except OSError:
            pass
        if platform.system() != "Windows":
            # by default one does not have access to create symlinks on Windows
            os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)
        else:
            shutil.copyfile(ori_ff, new_ff)


def get_hash(obj) -> str:
    """Get hash of object.

    Parameters
    ----------
    obj
        object to hash
    """
    return sha1(json.dumps(obj).encode("utf-8")).hexdigest()


def j_get_type(data: dict, class_name: str = "object") -> str:
    """Get the type from the data.

    Parameters
    ----------
    data : dict
        the data
    class_name : str, optional
        the name of the class for the error message, by default "object"

    Returns
    -------
    str
        the type
    """
    try:
        return data["type"]
    except KeyError as e:
        raise KeyError(f"the type of the {class_name} should be set by `type`") from e
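
# Illustrative calls (the dictionary contents are hypothetical; get_hash requires
# a JSON-serializable object):
#
#     get_hash({"type": "se_e2_a", "sel": [46, 92]})   # -> stable sha1 hex digest
#     j_get_type({"type": "se_e2_a"}, "descriptor")    # -> "se_e2_a"
#     # j_get_type({}, "descriptor") raises KeyError with a descriptive message
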