Source code for deepmd.pt.utils.neighbor_stat

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Iterator,
    Optional,
    Tuple,
)

import numpy as np
import torch

from deepmd.pt.utils.auto_batch_size import (
    AutoBatchSize,
)
from deepmd.pt.utils.env import (
    DEVICE,
)
from deepmd.pt.utils.nlist import (
    extend_coord_with_ghosts,
)
from deepmd.utils.data_system import (
    DeepmdDataSystem,
)
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat


[docs] class NeighborStatOP(torch.nn.Module): """Class for getting neighbor statics data information. Parameters ---------- ntypes The num of atom types rcut The cut-off radius mixed_types : bool, optional If True, treat neighbors of all types as a single type. """ def __init__( self, ntypes: int, rcut: float, mixed_types: bool, ) -> None: super().__init__() self.rcut = rcut self.ntypes = ntypes self.mixed_types = mixed_types
[docs] def forward( self, coord: torch.Tensor, atype: torch.Tensor, cell: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """Calculate the neareest neighbor distance between atoms, maximum nbor size of atoms and the output data range of the environment matrix. Parameters ---------- coord The coordinates of atoms. atype The atom types. cell The cell. Returns ------- torch.Tensor The minimal squared distance between two atoms, in the shape of (nframes,) torch.Tensor The maximal number of neighbors """ nframes = coord.shape[0] coord = coord.view(nframes, -1, 3) nloc = coord.shape[1] coord = coord.view(nframes, nloc * 3) extend_coord, extend_atype, _ = extend_coord_with_ghosts( coord, atype, cell, self.rcut ) coord1 = extend_coord.reshape(nframes, -1) nall = coord1.shape[1] // 3 coord0 = coord1[:, : nloc * 3] diff = ( coord1.reshape([nframes, -1, 3])[:, None, :, :] - coord0.reshape([nframes, -1, 3])[:, :, None, :] ) assert list(diff.shape) == [nframes, nloc, nall, 3] # remove the diagonal elements mask = torch.eye(nloc, nall, dtype=torch.bool, device=diff.device) diff[:, mask] = torch.inf rr2 = torch.sum(torch.square(diff), dim=-1) min_rr2, _ = torch.min(rr2, dim=-1) # count the number of neighbors if not self.mixed_types: mask = rr2 < self.rcut**2 nnei = torch.zeros( (nframes, nloc, self.ntypes), dtype=torch.int32, device=mask.device ) for ii in range(self.ntypes): nnei[:, :, ii] = torch.sum( mask & extend_atype.eq(ii)[:, None, :], dim=-1 ) else: mask = rr2 < self.rcut**2 # virtual types (<0) are not counted nnei = torch.sum(mask & extend_atype.ge(0)[:, None, :], dim=-1).view( nframes, nloc, 1 ) max_nnei, _ = torch.max(nnei, dim=1) return min_rr2, max_nnei
[docs] class NeighborStat(BaseNeighborStat): """Neighbor statistics using pure NumPy. Parameters ---------- ntypes : int The num of atom types rcut : float The cut-off radius mixed_type : bool, optional, default=False Treat all types as a single type. """ def __init__( self, ntypes: int, rcut: float, mixed_type: bool = False, ) -> None: super().__init__(ntypes, rcut, mixed_type) op = NeighborStatOP(ntypes, rcut, mixed_type) self.op = torch.jit.script(op) self.auto_batch_size = AutoBatchSize()
[docs] def iterator( self, data: DeepmdDataSystem ) -> Iterator[Tuple[np.ndarray, float, str]]: """Abstract method for producing data. Yields ------ np.ndarray The maximal number of neighbors float The squared minimal distance between two atoms str The directory of the data system """ for ii in range(len(data.system_dirs)): for jj in data.data_systems[ii].dirs: data_set = data.data_systems[ii] data_set_data = data_set._load_set(jj) minrr2, max_nnei = self.auto_batch_size.execute_all( self._execute, data_set_data["coord"].shape[0], data_set.get_natoms(), data_set_data["coord"], data_set_data["type"], data_set_data["box"] if data_set.pbc else None, ) yield np.max(max_nnei, axis=0), np.min(minrr2), jj
[docs] def _execute( self, coord: np.ndarray, atype: np.ndarray, cell: Optional[np.ndarray], ): """Execute the operation. Parameters ---------- coord The coordinates of atoms. atype The atom types. cell The cell. """ minrr2, max_nnei = self.op( torch.from_numpy(coord).to(DEVICE), torch.from_numpy(atype).to(DEVICE), torch.from_numpy(cell).to(DEVICE) if cell is not None else None, ) minrr2 = minrr2.detach().cpu().numpy() max_nnei = max_nnei.detach().cpu().numpy() return minrr2, max_nnei