# Source code for deepmd.pt.utils.env

# SPDX-License-Identifier: LGPL-3.0-or-later
import os

import numpy as np
import torch

from deepmd.common import (
    VALID_PRECISION,
)
from deepmd.env import (
    GLOBAL_ENER_FLOAT_PRECISION,
    GLOBAL_NP_FLOAT_PRECISION,
    get_default_nthreads,
    set_default_nthreads,
)

# Switch read from the SAMPLER_RECORD environment variable: the raw string
# when the variable is set, otherwise False.
# NOTE(review): any non-empty value (including "0" or "false") is truthy, so
# consumers that only test truthiness cannot be switched off by setting "0".
SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
# Number of CPUs available to this process.
try:
    # only linux: size of the CPU affinity set of the current process
    ncpus = len(os.sched_getaffinity(0))
except AttributeError:
    # os.cpu_count() returns None when the count cannot be determined;
    # fall back to 1 so that min() below never compares an int with None.
    ncpus = os.cpu_count() or 1

# The NUM_WORKERS environment variable takes precedence; otherwise use at
# most 8 workers, capped by the CPU count computed above.
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(8, ncpus)))
# Make sure DDP uses correct device if applicable.
# LOCAL_RANK comes from the environment and defaults to 0 when unset.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

# Use the CPU when explicitly requested through DEVICE=cpu or when CUDA is
# unavailable; otherwise select the CUDA device indexed by LOCAL_RANK.
if os.environ.get("DEVICE") == "cpu" or not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
else:
    DEVICE = torch.device(f"cuda:{LOCAL_RANK}")
# Module-level switches and limits (plain constants, no environment lookup).
JIT = False
CACHE_PER_SYS = 5  # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True
# Maps precision names to torch dtypes. "half", "single" and "double" are
# aliases of "float16", "float32" and "float64" respectively.
PRECISION_DICT = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
    "half": torch.float16,
    "single": torch.float32,
    "double": torch.float64,
    "int32": torch.int32,
    "int64": torch.int64,
    "bfloat16": torch.bfloat16,
}
# torch dtypes looked up by the NumPy dtype name of the global precisions
# imported from deepmd.env.
GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
    np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
]

# "default" resolves to the global float precision.
PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
# Every precision name accepted by deepmd.common must be resolvable here.
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())

# Reverse lookup from torch dtype to its canonical name. It cannot be
# generated automatically from PRECISION_DICT because several names map to
# the same dtype (e.g. "half" and "float16").
# The spelling "PRECISON" is part of the exported name (it is listed in
# __all__), so it is kept as is.
RESERVED_PRECISON_DICT = {
    torch.float16: "float16",
    torch.float32: "float32",
    torch.float64: "float64",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.bfloat16: "bfloat16",
}
# Both tables must cover exactly the same set of dtypes.
assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISON_DICT.keys())
# Name of the default precision; a key of PRECISION_DICT.
DEFAULT_PRECISION = "float64"
# throw warnings if threads not set
set_default_nthreads()
inter_nthreads, intra_nthreads = get_default_nthreads()
# Only forward positive values to torch; non-positive values leave torch's
# own thread settings untouched.
# NOTE(review): PyTorch documents restrictions on when
# set_num_interop_threads may be called — confirm that importing this module
# late (or more than once per process) cannot trip them.
if inter_nthreads > 0:  # the behavior of 0 is not documented
    torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0:
    torch.set_num_threads(intra_nthreads)

# Public names of this module, including the two precisions re-exported
# from deepmd.env.
__all__ = [
    "GLOBAL_ENER_FLOAT_PRECISION",
    "GLOBAL_NP_FLOAT_PRECISION",
    "GLOBAL_PT_FLOAT_PRECISION",
    "GLOBAL_PT_ENER_FLOAT_PRECISION",
    "DEFAULT_PRECISION",
    "PRECISION_DICT",
    "RESERVED_PRECISON_DICT",
    "SAMPLER_RECORD",
    "NUM_WORKERS",
    "DEVICE",
    "JIT",
    "CACHE_PER_SYS",
    "ENERGY_BIAS_TRAINABLE",
    "LOCAL_RANK",
]