Source code for deepmd.pt.utils.exclude_mask

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
    Set,
    Tuple,
)

import numpy as np
import torch

from deepmd.pt.utils.utils import (
    to_torch_tensor,
)


class AtomExcludeMask(torch.nn.Module):
    """Computes the type exclusion mask for atoms."""

    def __init__(
        self,
        ntypes: int,
        exclude_types: List[int] = [],
    ):
        super().__init__()
        self.reinit(ntypes, exclude_types)

    def reinit(
        self,
        ntypes: int,
        exclude_types: List[int] = [],
    ):
        self.ntypes = ntypes
        self.exclude_types = exclude_types
        self.type_mask = np.array(
            [1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)],
            dtype=np.int32,
        )
        self.type_mask = to_torch_tensor(self.type_mask).view([-1])

    def get_exclude_types(self):
        return self.exclude_types

    def get_type_mask(self):
        return self.type_mask

    def forward(
        self,
        atype: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the type exclusion mask for atoms.

        Parameters
        ----------
        atype
            The extended atom types. shape: nf x natom

        Returns
        -------
        mask
            The type exclusion mask for atoms. shape: nf x natom
            Element [ff,ii] is 0 if type(ii) is excluded,
            otherwise 1.
        """
        nf, natom = atype.shape
        return self.type_mask[atype].view(nf, natom)
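

# Usage sketch (illustrative values only, not part of the module): with
# ntypes=3 and type 1 excluded, forward() returns 0 for every atom of type 1
# and 1 otherwise, for example:
#
#     am = AtomExcludeMask(3, exclude_types=[1])
#     am(torch.tensor([[0, 1, 2, 1]]))  # -> values [[1, 0, 1, 0]]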


class PairExcludeMask(torch.nn.Module):
    """Computes the type exclusion mask for atom pairs."""

    def __init__(
        self,
        ntypes: int,
        exclude_types: List[Tuple[int, int]] = [],
    ):
        super().__init__()
        self.reinit(ntypes, exclude_types)

    def reinit(
        self,
        ntypes: int,
        exclude_types: List[Tuple[int, int]] = [],
    ):
        self.ntypes = ntypes
        self._exclude_types: Set[Tuple[int, int]] = set()
        for tt in exclude_types:
            assert len(tt) == 2
            self._exclude_types.add((tt[0], tt[1]))
            self._exclude_types.add((tt[1], tt[0]))
        # ntypes + 1 for nlist masks
        self.type_mask = np.array(
            [
                [
                    1 if (tt_i, tt_j) not in self._exclude_types else 0
                    for tt_i in range(ntypes + 1)
                ]
                for tt_j in range(ntypes + 1)
            ],
            dtype=np.int32,
        )
        # (ntypes+1 x ntypes+1)
        self.type_mask = to_torch_tensor(self.type_mask).view([-1])
        self.no_exclusion = len(self._exclude_types) == 0

    def get_exclude_types(self):
        return self._exclude_types

    # may have a better place for this method...
    def forward(
        self,
        nlist: torch.Tensor,
        atype_ext: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the type exclusion mask for atom pairs.

        Parameters
        ----------
        nlist
            The neighbor list. shape: nf x nloc x nnei
        atype_ext
            The extended atom types. shape: nf x nall

        Returns
        -------
        mask
            The type exclusion mask of shape: nf x nloc x nnei.
            Element [ff,ii,jj] is 0 if the pair (type(ii), type(nlist[ff,ii,jj]))
            is excluded, otherwise 1.
        """
        if self.no_exclusion:
            # safely return 1 if nothing is excluded.
            return torch.ones_like(nlist, dtype=torch.int32, device=nlist.device)
        nf, nloc, nnei = nlist.shape
        nall = atype_ext.shape[1]
        # add a virtual atom of type ntypes. nf x nall+1
        ae = torch.cat(
            [
                atype_ext,
                self.ntypes
                * torch.ones([nf, 1], dtype=atype_ext.dtype, device=atype_ext.device),
            ],
            dim=-1,
        )
        type_i = atype_ext[:, :nloc].view(nf, nloc) * (self.ntypes + 1)
        # nf x nloc x nnei
        index = torch.where(nlist == -1, nall, nlist).view(nf, nloc * nnei)
        type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei)
        type_ij = type_i[:, :, None] + type_j
        # nf x (nloc x nnei)
        type_ij = type_ij.view(nf, nloc * nnei)
        mask = self.type_mask[type_ij].view(nf, nloc, nnei)
        return mask
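

if __name__ == "__main__":
    # Minimal usage sketch with made-up shapes and values; it only illustrates
    # how PairExcludeMask is meant to be called and is not part of the library.
    ntypes = 2
    pair_mask = PairExcludeMask(ntypes, exclude_types=[(0, 1)])
    # keep the inputs on the same device as the precomputed mask table
    device = pair_mask.type_mask.device

    # one frame: two local atoms of types [0, 1] plus one ghost atom of type 0
    atype_ext = torch.tensor([[0, 1, 0]], device=device)  # nf x nall = 1 x 3
    # each local atom has two neighbor slots; -1 marks an empty slot
    nlist = torch.tensor([[[1, 2], [0, -1]]], device=device)  # nf x nloc x nnei

    # pairs (0, 1) and (1, 0) are excluded; the empty slot maps to the virtual
    # type `ntypes`, which is never excluded, so its mask entry stays 1
    print(pair_mask(nlist, atype_ext))  # expected values [[[0, 1], [0, 1]]]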