Source code for deepmd.tf.utils.batch_size

# SPDX-License-Identifier: LGPL-3.0-or-later
import os

from packaging.version import (
    Version,
)

from deepmd.tf.env import (
    TF_VERSION,
    tf,
)
from deepmd.tf.utils.errors import (
    OutOfMemoryError,
)
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase
from deepmd.utils.batch_size import (
    log,
)


[docs] class AutoBatchSize(AutoBatchSizeBase): def __init__( self, initial_batch_size: int = 1024, factor: float = 2.0, *, silent: bool = False, ) -> None: super().__init__(initial_batch_size, factor, silent=silent) DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0)) if not DP_INFER_BATCH_SIZE > 0 and not self.silent: if self.is_gpu_available(): log.info( "If you encounter the error 'an illegal memory access was encountered', this may be due to a TensorFlow issue. " "To avoid this, set the environment variable DP_INFER_BATCH_SIZE to a smaller value than the last adjusted batch size. " "The environment variable DP_INFER_BATCH_SIZE controls the inference batch size (nframes * natoms). " )
[docs] def is_gpu_available(self) -> bool: """Check if GPU is available. Returns ------- bool True if GPU is available """ return ( Version(TF_VERSION) >= Version("1.14") and tf.config.experimental.get_visible_devices("GPU") ) or tf.test.is_gpu_available()
[docs] def is_oom_error(self, e: Exception) -> bool: """Check if the exception is an OOM error. Parameters ---------- e : Exception Exception """ return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError))