Source code for deepmd.pt.utils.preprocess

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
    Union,
)

import torch

from deepmd.pt.utils import (
    env,
)

log = logging.getLogger(__name__)
class Region3D:
    """A triclinic simulation box described by a 3x3 matrix.

    Internal coordinates are mapped to physical ones by right-multiplying
    with the box matrix (``coord @ boxt``); the inverse matrix performs the
    opposite conversion.
    """

    def __init__(self, boxt):
        """Construct a simulation box.

        Args:
        - boxt: box tensor with 9 elements; it is reshaped to [3, 3]
        """
        boxt = boxt.reshape([3, 3])
        self.boxt = boxt  # converts internal coordinates to physical ones
        self.rec_boxt = torch.linalg.inv(
            self.boxt
        )  # converts physical coordinates to internal ones

        # Determinant of the box matrix. It is signed, and the face
        # distances below inherit that sign.
        self.volume = torch.linalg.det(self.boxt)

        # Distance between two opposite faces = volume / area of that face.
        # ``dim`` is given explicitly: calling torch.cross without it relies
        # on a deprecated implicit-dimension lookup.
        c_yz = torch.cross(boxt[1], boxt[2], dim=-1)
        self._h2yz = self.volume / torch.linalg.norm(c_yz)
        c_zx = torch.cross(boxt[2], boxt[0], dim=-1)
        self._h2zx = self.volume / torch.linalg.norm(c_zx)
        c_xy = torch.cross(boxt[0], boxt[1], dim=-1)
        self._h2xy = self.volume / torch.linalg.norm(c_xy)

    def phys2inter(self, coord):
        """Convert physical coordinates to internal ones."""
        return coord @ self.rec_boxt

    def inter2phys(self, coord):
        """Convert internal coordinates to physical ones."""
        return coord @ self.boxt

    def get_face_distance(self):
        """Return face distances to each surface of YZ, ZX, XY."""
        return torch.stack([self._h2yz, self._h2zx, self._h2xy])
def normalize_coord(coord, region: Region3D, nloc: int):
    """Wrap atoms lying outside the region back into it.

    The coordinates are converted to internal form, reduced modulo 1 and
    converted back to physical form. The input tensor is not modified.

    Args:
    - coord: physical coordinates; the last dimension must be 3 because they
      are right-multiplied by the region's 3x3 matrices
    - region: the simulation box
    - nloc: not used by this function
    """
    fractional = region.phys2inter(coord.clone())
    wrapped = torch.remainder(fractional, 1.0)
    return region.inter2phys(wrapped)
def compute_serial_cid(cell_offset, ncell):
    """Tell the sequential cell ID in its 3D space.

    The ID of cell (x, y, z) is ``x * ncell[1] * ncell[2] + y * ncell[2] + z``.
    The input tensor is left untouched (earlier versions scaled its first two
    columns in place, which corrupted the caller's data).

    Args:
    - cell_offset: integer cell indices, shape is [..., 3]
    - ncell: number of cells along each axis, shape is [3]

    Returns
    -------
    Sequential cell IDs, with the shape of cell_offset minus its last dimension.
    """
    stride_x = ncell[1] * ncell[2]
    stride_y = ncell[2]
    return (
        cell_offset[..., 0] * stride_x
        + cell_offset[..., 1] * stride_y
        + cell_offset[..., 2]
    )
def compute_pbc_shift(cell_offset, ncell):
    """Tell shift count to move the atom into region.

    For every cell index, return the number of whole boxes that must be
    added (in units of ``ncell``) so that ``cell_offset + shift * ncell``
    falls in ``[0, ncell)``.
    """
    below = cell_offset < 0
    above = cell_offset >= ncell

    # Cells on the negative side need a positive shift.
    shift_up = -torch.div(cell_offset, ncell, rounding_mode="floor")
    # Cells on or beyond the upper bound need a negative shift.
    shift_down = -(
        torch.div((cell_offset - ncell), ncell, rounding_mode="floor") + 1
    )

    shift = torch.zeros_like(cell_offset)
    shift = shift + below * shift_up
    shift = shift + above * shift_down

    assert torch.all(cell_offset + shift * ncell >= 0)
    assert torch.all(cell_offset + shift * ncell < ncell)
    return shift
def build_inside_clist(coord, region: Region3D, ncell):
    """Build cell list on atoms inside region.

    Args:
    - coord: atom coordinates; viewed as [-1, 3] internally
    - region: the simulation box
    - ncell: number of cells along each axis, tensor of shape [3]

    Returns
    -------
    a2c: sequential cell ID of every atom, [nloc]
    lst: list with one entry per cell; entry i holds the indices (ascending)
        of the atoms located in cell i
    """
    loc_ncell = int(torch.prod(ncell))  # num of local cells
    inter_cell_size = 1.0 / ncell

    inter_coord = region.phys2inter(coord.view(-1, 3))
    cell_offset = torch.floor(inter_coord / inter_cell_size).to(torch.long)
    # Numerical error brought by the conversion from physical to internal
    # coordinates and back may push an index slightly out of range on either
    # side: a negative value, or exactly ncell when the internal coordinate
    # rounds to 1.0. Clamp both ends so every atom lands in a valid cell.
    cell_offset[cell_offset < 0] = 0
    cell_offset = torch.minimum(cell_offset, ncell - 1)

    a2c = compute_serial_cid(cell_offset, ncell)  # cell id of atoms

    # A stable sort groups atoms by cell while keeping atom indices ascending
    # inside each cell; this avoids materialising a [ncell, natoms] one-hot
    # matrix.
    order = torch.sort(a2c, stable=True).indices
    counts = torch.bincount(a2c, minlength=loc_ncell).tolist()

    lst = []
    start = 0
    for n in counts:
        lst.append(order[start : start + n])
        start += n
    return a2c, lst
def append_neighbors(coord, region: Region3D, atype, rcut: float):
    """Make ghost atoms who are valid neighbors.

    The box is divided into cells no smaller than ``rcut`` where possible.
    Every cell in the layers surrounding the box is mapped back to its
    periodic image inside the box, and the atoms of that image become ghost
    atoms.

    Args:
    - coord: local atom coordinates, indexed per atom along the first
      dimension (shape is [nloc, 3])
    - region: the simulation box
    - atype: shape is [nloc]
    - rcut: cutoff radius

    Returns
    -------
    merged_coord_shift: shift of local + ghost atoms, [nall, 3]; zero for
        local atoms. The physical ghost position is
        ``coord[merged_mapping] - merged_coord_shift``.
    merged_atype: types of local + ghost atoms, [nall]
    merged_mapping: index of the local atom each entry originates from, [nall]
    """
    to_face = region.get_face_distance()

    # compute num and size of local cells
    ncell = torch.floor(to_face / rcut).to(torch.long)
    ncell[ncell == 0] = 1
    cell_size = to_face / ncell
    # num of cells out of local, which contain ghost atoms
    ngcell = torch.floor(rcut / cell_size).to(torch.long) + 1

    # add ghost atoms
    _, c2a = build_inside_clist(coord, region, ncell)
    xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1)
    yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1)
    zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1)
    xyz = xi.view(-1, 1, 1, 1) * torch.tensor([1, 0, 0], dtype=torch.long)
    xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor([0, 1, 0], dtype=torch.long)
    xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor([0, 0, 1], dtype=torch.long)
    xyz = xyz.view(-1, 3)

    # keep only the cells lying outside the local box
    inside = torch.logical_and((xyz >= 0).all(dim=-1), (xyz < ncell).all(dim=-1))
    xyz = xyz[~inside]  # cell coord

    shift = compute_pbc_shift(xyz, ncell)
    coord_shift = region.inter2phys(shift.to(env.GLOBAL_PT_FLOAT_PRECISION))
    mirrored = shift * ncell + xyz
    cid = compute_serial_cid(mirrored, ncell)

    # Encode (ghost cell index, atom index) into one integer so that a single
    # concatenation yields both; they are decoded right after.
    n_atoms = coord.shape[0]
    encoded = torch.cat([c2a[ci] + i * n_atoms for i, ci in enumerate(cid)])
    ghost_cell = torch.div(encoded, n_atoms, rounding_mode="trunc")
    aid = encoded % n_atoms

    # merge local and ghost atoms
    merged_coord_shift = torch.cat([torch.zeros_like(coord), coord_shift[ghost_cell]])
    merged_atype = torch.cat([atype, atype[aid]])
    merged_mapping = torch.cat([torch.arange(atype.numel()), aid])
    return merged_coord_shift, merged_atype, merged_mapping
def build_neighbor_list(
    nloc: int, coord, atype, rcut: float, sec, mapping, type_split=True, min_check=False
):
    """For each atom inside region, build its neighbor list.

    Args:
    - nloc: number of local atoms; they are the first nloc entries of coord
    - coord: coordinates of local + ghost atoms, viewed as [nall, 3]
      (converted to float32 internally)
    - atype: shape is [nall]
    - rcut: cutoff radius
    - sec: cumulative number of neighbor slots per type; the last entry is
      the total number of slots. Tensor or sequence of integers.
    - mapping: index of the local atom each of the nall atoms maps to, [nall]
    - type_split: if True, slots ``sec[i-1]:sec[i]`` hold only neighbors of
      type i; if False, all types share the ``sec[-1]`` slots
    - min_check: if True, raise when two distinct atoms are closer than 1e-6

    Returns
    -------
    nlist: neighbor indices into the nall atoms, [nloc, sec[-1]], -1 padded
    nlist_loc: the same neighbors mapped through ``mapping``, -1 padded
    nlist_type: types of the neighbors, -1 padded

    Raises
    ------
    RuntimeError
        If min_check is set and two atoms are too close.
    """
    coord = coord.float()
    coord_l = coord.view(-1, 1, 3)[:nloc]
    coord_r = coord.view(1, -1, 3)
    distance = torch.linalg.norm(coord_l - coord_r, dim=-1)  # [nloc, nall]
    nall = distance.shape[1]

    # Any distance pushed beyond this value can never pass the cutoff test.
    DISTANCE_INF = distance.max().detach() + rcut
    # An atom is not its own neighbor.
    distance[:nloc, :nloc] += torch.eye(nloc, dtype=torch.bool) * DISTANCE_INF
    if min_check and distance.min().abs() < 1e-6:
        raise RuntimeError("Atom dist too close!")

    if not type_split:
        sec = sec[-1:]
    nsel = int(sec[-1])
    nlist = torch.zeros((nloc, nsel)).long() - 1
    nlist_loc = torch.zeros((nloc, nsel)).long() - 1
    nlist_type = torch.zeros((nloc, nsel)).long() - 1

    start = 0
    for itype, sec_end in enumerate(sec):
        end = int(sec_end)
        nnei = end - start  # number of slots reserved for this section
        if type_split:
            # push atoms of other types out of reach
            tmp = distance + (atype.unsqueeze(0) != itype) * DISTANCE_INF
        else:
            tmp = distance

        if nall >= nnei:
            _sorted, indices = torch.topk(tmp, nnei, dim=1, largest=False)
        else:
            # more slots than atoms: take every atom and pad the remainder
            indices = torch.zeros((nloc, nnei)).long() - 1
            _sorted = torch.ones((nloc, nnei)) * DISTANCE_INF
            _sorted_nall, indices_nall = torch.topk(tmp, nall, dim=1, largest=False)
            _sorted[:, :nall] = _sorted_nall
            indices[:, :nall] = indices_nall

        mask = (_sorted < rcut).to(torch.long)
        padding = -1 * (1 - mask)  # -1 where the candidate is beyond rcut
        # Padded entries may index position -1 below; that is harmless
        # because the looked-up value is multiplied by a zero mask.
        indices_loc = mapping[indices] * mask + padding
        indices = indices * mask + padding

        nlist[:, start:end] = indices
        nlist_loc[:, start:end] = indices_loc
        nlist_type[:, start:end] = atype[indices] * mask + padding
        start = end
    return nlist, nlist_loc, nlist_type
def compute_smooth_weight(distance, rmin: float, rmax: float):
    """Compute smooth weight for descriptor elements.

    The weight is 1 for ``distance <= rmin``, 0 for ``distance >= rmax`` and
    follows the polynomial ``u^3 * (-6u^2 + 15u - 10) + 1`` in between, with
    ``u = (distance - rmin) / (rmax - rmin)``.

    Raises
    ------
    ValueError
        If rmin is not smaller than rmax.
    """
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")

    inner = distance <= rmin
    outer = distance >= rmax
    between = ~(inner | outer)

    uu = (distance - rmin) / (rmax - rmin)
    poly = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
    # the masks select the polynomial in the switching zone and 1 inside rmin
    return poly * between + inner
def make_env_mat(
    coord,
    atype,
    region,
    rcut: Union[float, list],
    sec,
    pbc=True,
    type_split=True,
    min_check=False,
):
    """Based on atom coordinates, return environment matrix.

    When ``rcut`` is a list (hybrid mode), ghost atoms are generated once
    with the largest cutoff and one neighbor list is built per cutoff, using
    ``sec[ii]`` for the ii-th cutoff; the first three return values are then
    lists.

    Returns
    -------
    nlist: nlist, [nloc, nnei]
    nlist_loc: nlist with neighbors mapped to local atom indices
    nlist_type: types of the neighbors in nlist
    merged_coord_shift: shift on nall atoms, [nall, 3]
    merged_mapping: mapping from nall index to nloc index, [nall]
    """
    hybrid = isinstance(rcut, list)
    # ghost atoms must cover the largest cutoff in use
    _rcut = max(rcut) if hybrid else rcut

    if pbc:
        merged_coord_shift, merged_atype, merged_mapping = append_neighbors(
            coord, region, atype, _rcut
        )
        merged_coord = coord[merged_mapping] - merged_coord_shift
        if merged_coord.shape[0] <= coord.shape[0]:
            log.warning("No ghost atom is added for system ")
    else:
        # without periodic images the merged system is just the local one
        merged_coord_shift = torch.zeros_like(coord)
        merged_atype = atype.clone()
        merged_mapping = torch.arange(atype.numel())
        merged_coord = coord.clone()

    def _nlist_for(cutoff, sel):
        # neighbor list of the local atoms for one cutoff / selection pair
        return build_neighbor_list(
            coord.shape[0],
            merged_coord,
            merged_atype,
            cutoff,
            sel,
            merged_mapping,
            type_split=type_split,
            min_check=min_check,
        )

    # build nlist
    if hybrid:
        nlist, nlist_loc, nlist_type = [], [], []
        for ii, single_rcut in enumerate(rcut):
            one_nlist, one_loc, one_type = _nlist_for(single_rcut, sec[ii])
            nlist.append(one_nlist)
            nlist_loc.append(one_loc)
            nlist_type.append(one_type)
    else:
        nlist, nlist_loc, nlist_type = _nlist_for(rcut, sec)
    return nlist, nlist_loc, nlist_type, merged_coord_shift, merged_mapping