Source code for deepmd.dpmodel.utils.dist_check

# SPDX-License-Identifier: LGPL-3.0-or-later
"""Minimum pairwise distance check for frame validity filtering."""

from __future__ import (
    annotations,
)

import numpy as np

# Upper bound on the number of pairwise entries (block rows x real atoms)
# materialized at once while scanning distances; the row-block size used by
# compute_min_pair_dist_single is derived from it.
_MIN_PAIR_DIST_BLOCK_PAIRS = 262_144


def compute_min_pair_dist_single(
    coord: np.ndarray,
    box: np.ndarray | None,
    atype: np.ndarray,
    stop_below: float | None = None,
) -> float:
    """Compute the minimum pairwise atomic distance for a single frame.

    Parameters
    ----------
    coord : np.ndarray
        Atomic coordinates, flattened with shape (natoms * 3,) or reshaped
        as (natoms, 3). Non-floating dtypes (e.g. integers) are promoted to
        float64 before any distance is computed.
    box : np.ndarray or None
        Box vectors with shape (9,) for PBC, or None for non-PBC.
    atype : np.ndarray
        Atom types with shape (natoms,). Virtual atoms (type < 0) are
        excluded from the distance check.
    stop_below : float or None
        Optional early-stop threshold. If a block has any pair closer than
        this value, the block minimum is returned immediately. Ignored when
        None or not strictly positive.

    Returns
    -------
    float
        Minimum pairwise distance. Returns inf if fewer than 2 real atoms
        exist.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``box`` is given but the 3x3 cell matrix cannot be inverted.
    """
    coord = np.asarray(coord)
    if not np.issubdtype(coord.dtype, np.floating):
        # Self-pairs are masked with +inf further down; an integer distance
        # array cannot hold inf, so promote to a floating dtype up front.
        coord = coord.astype(np.float64)
    coord = coord.reshape(-1, 3)

    # === Step 1. Filter out virtual atoms ===
    real_mask = np.asarray(atype).ravel() >= 0
    real_coord = coord[real_mask]
    n_real = real_coord.shape[0]
    if n_real < 2:
        return float("inf")

    # === Step 2. Prepare minimum image convention for PBC ===
    if box is not None:
        cell = np.asarray(box).reshape(3, 3)
        inv_cell = np.linalg.inv(cell)
    else:
        cell = None
        inv_cell = None

    # === Step 3. Compute distances in bounded row blocks ===
    # Each block holds (block_size x n_real) pairs, so block_size is chosen
    # to keep that product within _MIN_PAIR_DIST_BLOCK_PAIRS (at least 1 row).
    block_size = max(1, min(n_real, _MIN_PAIR_DIST_BLOCK_PAIRS // n_real))
    min_dist_sq = float("inf")
    # Compare squared distances so no sqrt is needed inside the loop.
    stop_dist_sq = (
        float(stop_below) * float(stop_below)
        if stop_below is not None and stop_below > 0.0
        else None
    )

    for start in range(0, n_real, block_size):
        stop = min(start + block_size, n_real)
        # diff[i, j] = r_j - r_(start + i), shape (stop - start, n_real, 3)
        diff = real_coord[np.newaxis, :, :] - real_coord[start:stop, np.newaxis, :]
        if cell is not None and inv_cell is not None:
            # Go to fractional coordinates, wrap every component to the
            # nearest periodic image, then return to Cartesian coordinates.
            frac_diff = diff @ inv_cell
            frac_diff -= np.round(frac_diff)
            diff = frac_diff @ cell
        dist_sq = np.sum(diff * diff, axis=-1)

        # Mask each atom's zero distance to itself so it cannot win the min.
        rows = np.arange(stop - start, dtype=np.int64)
        dist_sq[rows, start + rows] = np.inf

        min_dist_sq = min(min_dist_sq, float(dist_sq.min()))
        # Nothing can beat 0, and callers passing stop_below only need to
        # know that the threshold was violated, so the scan can end early.
        if min_dist_sq == 0.0 or (
            stop_dist_sq is not None and min_dist_sq < stop_dist_sq
        ):
            break

    return float(np.sqrt(min_dist_sq))