Source code for deepmd.dpmodel.utils.exclude_mask

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
    Optional,
    Tuple,
)

import numpy as np


class AtomExcludeMask:
    """Computes the type exclusion mask for atoms.

    Parameters
    ----------
    ntypes
        Number of atom types.
    exclude_types
        Atom types to exclude. ``None`` (the default) excludes nothing.
        The sequence is copied on construction, so later changes to the
        caller's list cannot desynchronize ``exclude_types`` from the
        precomputed ``type_mask``.
    """

    def __init__(
        self,
        ntypes: int,
        exclude_types: Optional[List[int]] = None,
    ):
        self.ntypes = ntypes
        # Copy rather than alias. A ``= []`` default is one list shared by
        # every instance, and get_exclude_types() hands it out, so mutating
        # the returned list would leak into all later default-built masks.
        self.exclude_types = [] if exclude_types is None else list(exclude_types)
        # 1 keeps the type, 0 excludes it.
        self.type_mask = np.array(
            [1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)],
            dtype=np.int32,
        )
        # (ntypes)
        self.type_mask = self.type_mask.reshape([-1])

    def get_exclude_types(self):
        """Return the list of excluded atom types."""
        return self.exclude_types

    def get_type_mask(self):
        """Return the int32 per-type mask of shape (ntypes,); 0 marks an excluded type."""
        return self.type_mask

    def build_type_exclude_mask(
        self,
        atype: np.ndarray,
    ):
        """Compute type exclusion mask for atoms.

        Parameters
        ----------
        atype
            The extended atom types. shape: nf x natom

        Returns
        -------
        mask
            The type exclusion mask for atoms. shape: nf x natom
            Element [ff,ii] being 0 if type(ii) is excluded,
            otherwise being 1.

        """
        nf, natom = atype.shape
        # NOTE(review): negative type ids (e.g. -1 padding) would wrap around to
        # the last type's entry under numpy indexing; confirm callers never pass them.
        return self.type_mask[atype].reshape(nf, natom)
class PairExcludeMask:
    """Computes the type exclusion mask for atom pairs.

    Parameters
    ----------
    ntypes
        Number of atom types.
    exclude_types
        Pairs of atom types to exclude. Each pair is stored in both orders,
        so excluding ``(a, b)`` also excludes ``(b, a)``. ``None`` (the
        default) excludes nothing.
    """

    def __init__(
        self,
        ntypes: int,
        exclude_types: Optional[List[Tuple[int, int]]] = None,
    ):
        self.ntypes = ntypes
        if exclude_types is None:
            # Explicit None check (not ``or``) so numpy-array inputs still work.
            exclude_types = []
        self.exclude_types = set()
        for tt in exclude_types:
            assert len(tt) == 2
            self.exclude_types.add((tt[0], tt[1]))
            self.exclude_types.add((tt[1], tt[0]))
        # ntypes + 1 for nlist masks: the extra type index ``ntypes`` is the
        # virtual type that build_type_exclude_mask assigns to -1 neighbors.
        self.type_mask = np.array(
            [
                [
                    1 if (tt_i, tt_j) not in self.exclude_types else 0
                    for tt_i in range(ntypes + 1)
                ]
                for tt_j in range(ntypes + 1)
            ],
            dtype=np.int32,
        )
        # (ntypes+1 x ntypes+1), flattened. exclude_types is symmetric, so the
        # matrix is symmetric and a pair (ti, tj) can be looked up at flat
        # index ti * (ntypes + 1) + tj.
        self.type_mask = self.type_mask.reshape([-1])

    def get_exclude_types(self):
        """Return the set of excluded type pairs (both orders included)."""
        return self.exclude_types

    def build_type_exclude_mask(
        self,
        nlist: np.ndarray,
        atype_ext: np.ndarray,
    ):
        """Compute type exclusion mask for atom pairs.

        Parameters
        ----------
        nlist
            The neighbor list. shape: nf x nloc x nnei
        atype_ext
            The extended atom types. shape: nf x nall

        Returns
        -------
        mask
            The type exclusion mask for pair atoms of shape: nf x nloc x nnei.
            Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded,
            otherwise being 1.

        """
        if len(self.exclude_types) == 0:
            # safely return 1 if nothing is excluded.
            return np.ones_like(nlist, dtype=np.int32)
        nf, nloc, nnei = nlist.shape
        nall = atype_ext.shape[1]
        # add virtual atom of type ntypes. nf x nall+1
        ae = np.concatenate(
            [atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
        )
        # Row offset into the flattened (ntypes+1) x (ntypes+1) table.
        type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1)
        # nf x nloc x nnei; -1 entries point at the virtual atom at column nall.
        index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei)
        type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
        type_ij = type_i[:, :, None] + type_j
        # nf x (nloc x nnei)
        type_ij = type_ij.reshape(nf, nloc * nnei)
        mask = self.type_mask[type_ij].reshape(nf, nloc, nnei)
        return mask

    def __contains__(self, item):
        """Return whether the type pair ``item`` is excluded (in either order)."""
        return item in self.exclude_types