Source code for deepmd.pt.model.descriptor.repformer_layer
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
List,
)
import torch
from deepmd.pt.model.network.network import (
SimpleLinear,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
def torch_linear(*args, **kwargs):
return torch.nn.Linear(
*args, **kwargs, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
def _make_nei_g1(
g1_ext: torch.Tensor,
nlist: torch.Tensor,
) -> torch.Tensor:
# nlist: nb x nloc x nnei
nb, nloc, nnei = nlist.shape
# g1_ext: nb x nall x ng1
ng1 = g1_ext.shape[-1]
# index: nb x (nloc x nnei) x ng1
index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
# gg1 : nb x (nloc x nnei) x ng1
gg1 = torch.gather(g1_ext, dim=1, index=index)
# gg1 : nb x nloc x nnei x ng1
gg1 = gg1.view(nb, nloc, nnei, ng1)
return gg1
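
A minimal sketch of the gather pattern used by _make_nei_g1, with made-up sizes and plain PyTorch only:

import torch

nb, nloc, nall, nnei, ng1 = 2, 4, 6, 3, 8
g1_ext = torch.randn(nb, nall, ng1)               # extended single-atom features
nlist = torch.randint(0, nall, (nb, nloc, nnei))  # neighbor indices into the nall axis
index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
gg1 = torch.gather(g1_ext, dim=1, index=index).view(nb, nloc, nnei, ng1)
# each neighbor slot now holds the g1 features of the atom it points to
assert torch.equal(gg1[0, 1, 2], g1_ext[0, nlist[0, 1, 2]])
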
def _apply_nlist_mask(
gg: torch.Tensor,
nlist_mask: torch.Tensor,
) -> torch.Tensor:
# gg: nf x nloc x nnei x ng
# msk: nf x nloc x nnei
return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0)
def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor:
# gg: nf x nloc x nnei x ng
# sw: nf x nloc x nnei
return gg * sw.unsqueeze(-1)
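
A toy illustration, assuming plain PyTorch, of how the two helpers above zero out padded neighbors and damp features by the switch function:

import torch

nf, nloc, nnei, ng = 1, 2, 3, 4
gg = torch.ones(nf, nloc, nnei, ng)
nlist_mask = torch.tensor([[[True, True, False], [True, False, False]]])  # False marks padded neighbors
sw = torch.tensor([[[1.0, 0.5, 0.0], [0.3, 0.0, 0.0]]])                   # switch function values
gg = gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0)  # _apply_nlist_mask
gg = gg * sw.unsqueeze(-1)                           # _apply_switch
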
def _apply_h_norm(
hh: torch.Tensor, # nf x nloc x nnei x 3
) -> torch.Tensor:
"""Normalize h by the std of vector length.
do not have an idea if this is a good way.
"""
nf, nl, nnei, _ = hh.shape
# nf x nloc x nnei
normh = torch.linalg.norm(hh, dim=-1)
# nf x nloc
std = torch.std(normh, dim=-1)
# nf x nloc x nnei x 3
hh = hh[:, :, :, :] / (1.0 + std[:, :, None, None])
return hh
class Atten2Map(torch.nn.Module):
def __init__(
self,
ni: int,
nd: int,
nh: int,
has_gate: bool = False, # apply gate to attn map
smooth: bool = True,
attnw_shift: float = 20.0,
):
super().__init__()
self.ni = ni
self.nd = nd
self.nh = nh
self.mapqk = SimpleLinear(ni, nd * 2 * nh, bias=False)
self.has_gate = has_gate
self.smooth = smooth
self.attnw_shift = attnw_shift
def forward(
self,
g2: torch.Tensor, # nb x nloc x nnei x ng2
h2: torch.Tensor, # nb x nloc x nnei x 3
nlist_mask: torch.Tensor, # nb x nloc x nnei
sw: torch.Tensor, # nb x nloc x nnei
) -> torch.Tensor:
(
nb,
nloc,
nnei,
_,
) = g2.shape
nd, nh = self.nd, self.nh
# nb x nloc x nnei x nd x (nh x 2)
g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2)
# nb x nloc x (nh x 2) x nnei x nd
g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3))
# nb x nloc x nh x nnei x nd
g2q, g2k = torch.split(g2qk, nh, dim=2)
# g2q = torch.nn.functional.normalize(g2q, dim=-1)
# g2k = torch.nn.functional.normalize(g2k, dim=-1)
# nb x nloc x nh x nnei x nnei
attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5
if self.has_gate:
gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3)
attnw = attnw * gate
# mask for the attention map, nb x nloc x 1 x 1 x nnei
attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2)
# mask for the attention map, nb x nloc x 1 x nnei x 1
attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1)
if self.smooth:
attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[
:, :, None, None, :
] - self.attnw_shift
else:
attnw = attnw.masked_fill(
attnw_mask,
float("-inf"),
)
attnw = torch.softmax(attnw, dim=-1)
attnw = attnw.masked_fill(
attnw_mask,
0.0,
)
# nb x nloc x nh x nnei x nnei
attnw = attnw.masked_fill(
attnw_mask_c,
0.0,
)
if self.smooth:
attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :]
# nb x nloc x nnei x nnei
h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5
# nb x nloc x nh x nnei x nnei
ret = attnw * h2h2t[:, :, None, :, :]
# ret = torch.softmax(g2qk, dim=-1)
# nb x nloc x nnei x nnei x nh
ret = torch.permute(ret, (0, 1, 3, 4, 2))
return ret
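
A self-contained sketch (toy tensors, plain PyTorch) of the smooth-attention trick used above: instead of masking the logits with -inf, they are shifted by attnw_shift, scaled by the switch function of both neighbors, and shifted back, so the attention weights decay smoothly to zero at the cutoff:

import torch

nb, nloc, nh, nnei = 1, 1, 2, 4
attnw_shift = 20.0
logits = torch.randn(nb, nloc, nh, nnei, nnei)  # plays the role of q k^T / nd**0.5
sw = torch.tensor([[[1.0, 0.8, 0.2, 0.0]]])     # nb x nloc x nnei switch values
smoothed = (logits + attnw_shift) * sw[:, :, None, :, None] * sw[
    :, :, None, None, :
] - attnw_shift
attnw = torch.softmax(smoothed, dim=-1)
attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :]
# neighbors with sw == 0 receive (numerically) zero attention weight without any hard mask
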
class Atten2MultiHeadApply(torch.nn.Module):
def __init__(
self,
ni: int,
nh: int,
):
super().__init__()
self.ni = ni
self.nh = nh
self.mapv = SimpleLinear(ni, ni * nh, bias=False)
self.head_map = SimpleLinear(ni * nh, ni)
def forward(
self,
AA: torch.Tensor, # nf x nloc x nnei x nnei x nh
g2: torch.Tensor, # nf x nloc x nnei x ng2
) -> torch.Tensor:
nf, nloc, nnei, ng2 = g2.shape
nh = self.nh
# nf x nloc x nnei x ng2 x nh
g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh)
# nf x nloc x nh x nnei x ng2
g2v = torch.permute(g2v, (0, 1, 4, 2, 3))
# g2v = torch.nn.functional.normalize(g2v, dim=-1)
# nf x nloc x nh x nnei x nnei
AA = torch.permute(AA, (0, 1, 4, 2, 3))
# nf x nloc x nh x nnei x ng2
ret = torch.matmul(AA, g2v)
# nf x nloc x nnei x ng2 x nh
ret = torch.permute(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh))
# nf x nloc x nnei x ng2
return self.head_map(ret)
class Atten2EquiVarApply(torch.nn.Module):
def __init__(
self,
ni: int,
nh: int,
):
super().__init__()
self.ni = ni
self.nh = nh
self.head_map = SimpleLinear(nh, 1, bias=False)
def forward(
self,
AA: torch.Tensor, # nf x nloc x nnei x nnei x nh
h2: torch.Tensor, # nf x nloc x nnei x 3
) -> torch.Tensor:
nf, nloc, nnei, _ = h2.shape
nh = self.nh
# nf x nloc x nh x nnei x nnei
AA = torch.permute(AA, (0, 1, 4, 2, 3))
h2m = torch.unsqueeze(h2, dim=2)
# nf x nloc x nh x nnei x 3
h2m = torch.tile(h2m, [1, 1, nh, 1, 1])
# nf x nloc x nh x nnei x 3
ret = torch.matmul(AA, h2m)
# nf x nloc x nnei x 3 x nh
ret = torch.permute(ret, (0, 1, 3, 4, 2)).view(nf, nloc, nnei, 3, nh)
# nf x nloc x nnei x 3
return torch.squeeze(self.head_map(ret), dim=-1)
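
A hedged usage sketch of how the three modules above fit together (they are defined in this very module, so the import path below matches the page title; this assumes deepmd-kit with the PyTorch backend is importable, and all shapes are made up for illustration):

import torch
from deepmd.pt.utils import env
from deepmd.pt.model.descriptor.repformer_layer import (
    Atten2EquiVarApply,
    Atten2Map,
    Atten2MultiHeadApply,
)

nb, nloc, nnei, ng2, nd, nh = 1, 2, 4, 16, 8, 2
kw = {"dtype": env.GLOBAL_PT_FLOAT_PRECISION, "device": env.DEVICE}
g2 = torch.randn(nb, nloc, nnei, ng2, **kw)
h2 = torch.randn(nb, nloc, nnei, 3, **kw)
nlist_mask = torch.ones(nb, nloc, nnei, dtype=torch.bool, device=env.DEVICE)
sw = torch.ones(nb, nloc, nnei, **kw)
AA = Atten2Map(ng2, nd, nh)(g2, h2, nlist_mask, sw)  # nb x nloc x nnei x nnei x nh
g2_new = Atten2MultiHeadApply(ng2, nh)(AA, g2)       # nb x nloc x nnei x ng2
h2_new = Atten2EquiVarApply(ng2, nh)(AA, h2)         # nb x nloc x nnei x 3
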
class LocalAtten(torch.nn.Module):
def __init__(
self,
ni: int,
nd: int,
nh: int,
smooth: bool = True,
attnw_shift: float = 20.0,
):
super().__init__()
self.ni = ni
self.nd = nd
self.nh = nh
self.mapq = SimpleLinear(ni, nd * 1 * nh, bias=False)
self.mapkv = SimpleLinear(ni, (nd + ni) * nh, bias=False)
self.head_map = SimpleLinear(ni * nh, ni)
self.smooth = smooth
self.attnw_shift = attnw_shift
def forward(
self,
g1: torch.Tensor, # nb x nloc x ng1
gg1: torch.Tensor, # nb x nloc x nnei x ng1
nlist_mask: torch.Tensor, # nb x nloc x nnei
sw: torch.Tensor, # nb x nloc x nnei
) -> torch.Tensor:
nb, nloc, nnei = nlist_mask.shape
ni, nd, nh = self.ni, self.nd, self.nh
assert ni == g1.shape[-1]
assert ni == gg1.shape[-1]
# nb x nloc x nd x nh
g1q = self.mapq(g1).view(nb, nloc, nd, nh)
# nb x nloc x nh x nd
g1q = torch.permute(g1q, (0, 1, 3, 2))
# nb x nloc x nnei x (nd+ni) x nh
gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh)
gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3))
# nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1
gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1)
# nb x nloc x nh x 1 x nnei
attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5
# nb x nloc x nh x nnei
attnw = attnw.squeeze(-2)
# mask for the attention map, nb x nloc x 1 x nnei
attnw_mask = ~nlist_mask.unsqueeze(-2)
# nb x nloc x nh x nnei
if self.smooth:
attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift
else:
attnw = attnw.masked_fill(
attnw_mask,
float("-inf"),
)
attnw = torch.softmax(attnw, dim=-1)
attnw = attnw.masked_fill(
attnw_mask,
0.0,
)
if self.smooth:
attnw = attnw * sw.unsqueeze(-2)
# nb x nloc x nh x ng1
ret = (
torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
)
# nb x nloc x ng1
ret = self.head_map(ret)
return ret
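
Similarly, a hedged sketch of LocalAtten, which lets each local atom's g1 attend over the gathered g1 features of its neighbors (again assuming deepmd-kit's PyTorch backend is importable; toy shapes):

import torch
from deepmd.pt.utils import env
from deepmd.pt.model.descriptor.repformer_layer import LocalAtten

nb, nloc, nnei, ng1 = 1, 3, 5, 32
kw = {"dtype": env.GLOBAL_PT_FLOAT_PRECISION, "device": env.DEVICE}
loc_attn = LocalAtten(ni=ng1, nd=8, nh=2)
g1 = torch.randn(nb, nloc, ng1, **kw)
gg1 = torch.randn(nb, nloc, nnei, ng1, **kw)
nlist_mask = torch.ones(nb, nloc, nnei, dtype=torch.bool, device=env.DEVICE)
sw = torch.ones(nb, nloc, nnei, **kw)
g1_delta = loc_attn(g1, gg1, nlist_mask, sw)  # nb x nloc x ng1
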
class RepformerLayer(torch.nn.Module):
def __init__(
self,
rcut,
rcut_smth,
sel: int,
ntypes: int,
g1_dim=128,
g2_dim=16,
axis_dim: int = 4,
update_chnnl_2: bool = True,
do_bn_mode: str = "no",
bn_momentum: float = 0.1,
update_g1_has_conv: bool = True,
update_g1_has_drrd: bool = True,
update_g1_has_grrg: bool = True,
update_g1_has_attn: bool = True,
update_g2_has_g1g1: bool = True,
update_g2_has_attn: bool = True,
update_h2: bool = False,
attn1_hidden: int = 64,
attn1_nhead: int = 4,
attn2_hidden: int = 16,
attn2_nhead: int = 4,
attn2_has_gate: bool = False,
activation_function: str = "tanh",
update_style: str = "res_avg",
set_davg_zero: bool = True, # TODO
smooth: bool = True,
):
super().__init__()
self.epsilon = 1e-4  # protects 1./nnei against division by zero
self.rcut = rcut
self.rcut_smth = rcut_smth
self.ntypes = ntypes
sel = [sel] if isinstance(sel, int) else sel
self.nnei = sum(sel)
assert len(sel) == 1
self.sel = torch.tensor(sel, device=env.DEVICE)
self.sec = self.sel
self.axis_dim = axis_dim
self.set_davg_zero = set_davg_zero
self.do_bn_mode = do_bn_mode
self.bn_momentum = bn_momentum
self.act = ActivationFn(activation_function)
self.update_g1_has_grrg = update_g1_has_grrg
self.update_g1_has_drrd = update_g1_has_drrd
self.update_g1_has_conv = update_g1_has_conv
self.update_g1_has_attn = update_g1_has_attn
self.update_chnnl_2 = update_chnnl_2
self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False
self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False
self.update_h2 = update_h2 if self.update_chnnl_2 else False
del update_g2_has_g1g1, update_g2_has_attn, update_h2
self.update_style = update_style
self.smooth = smooth
self.g1_dim = g1_dim
self.g2_dim = g2_dim
g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_dim)
self.linear1 = SimpleLinear(g1_in_dim, g1_dim)
self.linear2 = None
self.proj_g1g2 = None
self.proj_g1g1g2 = None
self.attn2g_map = None
self.attn2_mh_apply = None
self.attn2_lm = None
self.attn2h_map = None
self.attn2_ev_apply = None
self.loc_attn = None
if self.update_chnnl_2:
self.linear2 = SimpleLinear(g2_dim, g2_dim)
if self.update_g1_has_conv:
self.proj_g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False)
if self.update_g2_has_g1g1:
self.proj_g1g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False)
if self.update_g2_has_attn:
self.attn2g_map = Atten2Map(
g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth
)
self.attn2_mh_apply = Atten2MultiHeadApply(g2_dim, attn2_nhead)
self.attn2_lm = torch.nn.LayerNorm(
g2_dim,
elementwise_affine=True,
device=env.DEVICE,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
)
if self.update_h2:
self.attn2h_map = Atten2Map(
g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth
)
self.attn2_ev_apply = Atten2EquiVarApply(g2_dim, attn2_nhead)
if self.update_g1_has_attn:
self.loc_attn = LocalAtten(g1_dim, attn1_hidden, attn1_nhead, self.smooth)
if self.do_bn_mode == "uniform":
self.bn1 = self._bn_layer()
self.bn2 = self._bn_layer()
elif self.do_bn_mode == "component":
self.bn1 = self._bn_layer(nf=g1_dim)
self.bn2 = self._bn_layer(nf=g2_dim)
elif self.do_bn_mode == "no":
self.bn1, self.bn2 = None, None
else:
raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}")
def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int:
ret = g1d
if self.update_g1_has_grrg:
ret += g2d * ax
if self.update_g1_has_drrd:
ret += g1d * ax
if self.update_g1_has_conv:
ret += g2d
return ret
def _update_h2(
self,
g2: torch.Tensor,
h2: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
) -> torch.Tensor:
assert self.attn2h_map is not None
assert self.attn2_ev_apply is not None
nb, nloc, nnei, _ = g2.shape
# # nb x nloc x nnei x nh2
# h2_1 = self.attn2_ev_apply(AA, h2)
# h2_update.append(h2_1)
# nb x nloc x nnei x nnei x nh
AAh = self.attn2h_map(g2, h2, nlist_mask, sw)
# nb x nloc x nnei x nh2
h2_1 = self.attn2_ev_apply(AAh, h2)
return h2_1
def _update_g1_conv(
self,
gg1: torch.Tensor,
g2: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
) -> torch.Tensor:
assert self.proj_g1g2 is not None
nb, nloc, nnei, _ = g2.shape
ng1 = gg1.shape[-1]
ng2 = g2.shape[-1]
# gg1 : nb x nloc x nnei x ng2
gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2)
# nb x nloc x nnei x ng2
gg1 = _apply_nlist_mask(gg1, nlist_mask)
if not self.smooth:
# normalized by number of neighbors, not smooth
# nb x nloc x 1
invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)).unsqueeze(-1)
else:
gg1 = _apply_switch(gg1, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device
)
# nb x nloc x ng2
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
return g1_11
def _cal_h2g2(
self,
g2: torch.Tensor,
h2: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
) -> torch.Tensor:
# g2: nf x nloc x nnei x ng2
# h2: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nb, nloc, nnei, _ = g2.shape
ng2 = g2.shape[-1]
# nb x nloc x nnei x ng2
g2 = _apply_nlist_mask(g2, nlist_mask)
if not self.smooth:
# nb x nloc
invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1))
# nb x nloc x 1 x 1
invnnei = invnnei.unsqueeze(-1).unsqueeze(-1)
else:
g2 = _apply_switch(g2, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device
)
# nb x nloc x 3 x ng2
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei
return h2g2
def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor:
# nb x nloc x 3 x ng2
nb, nloc, _, ng2 = h2g2.shape
# nb x nloc x 3 x axis
h2g2m = torch.split(h2g2, self.axis_dim, dim=-1)[0]
# nb x nloc x axis x ng2
g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
# nb x nloc x (axisxng2)
g1_13 = g1_13.view(nb, nloc, self.axis_dim * ng2)
return g1_13
def _update_g1_grrg(
self,
g2: torch.Tensor,
h2: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
) -> torch.Tensor:
# g2: nf x nloc x nnei x ng2
# h2: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nb, nloc, nnei, _ = g2.shape
ng2 = g2.shape[-1]
# nb x nloc x 3 x ng2
h2g2 = self._cal_h2g2(g2, h2, nlist_mask, sw)
# nb x nloc x (axisxng2)
g1_13 = self._cal_grrg(h2g2)
return g1_13
def _update_g2_g1g1(
self,
g1: torch.Tensor, # nb x nloc x ng1
gg1: torch.Tensor, # nb x nloc x nnei x ng1
nlist_mask: torch.Tensor, # nb x nloc x nnei
sw: torch.Tensor, # nb x nloc x nnei
) -> torch.Tensor:
ret = g1.unsqueeze(-2) * gg1
# nb x nloc x nnei x ng1
ret = _apply_nlist_mask(ret, nlist_mask)
if self.smooth:
ret = _apply_switch(ret, sw)
return ret
def _apply_bn(
self,
bn_number: int,
gg: torch.Tensor,
):
if self.do_bn_mode == "uniform":
return self._apply_bn_uni(bn_number, gg)
elif self.do_bn_mode == "component":
return self._apply_bn_comp(bn_number, gg)
else:
return gg
def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor:
nb, nl, nf = gg.shape
gg = gg.view([nb, 1, nl * nf])
if bn_number == 1:
assert self.bn1 is not None
gg = self.bn1(gg)
else:
assert self.bn2 is not None
gg = self.bn2(gg)
return gg.view([nb, nl, nf])
def _apply_nb_2(
self,
bn_number: int,
gg: torch.Tensor,
) -> torch.Tensor:
nb, nl, nnei, nf = gg.shape
gg = gg.view([nb, 1, nl * nnei * nf])
if bn_number == 1:
assert self.bn1 is not None
gg = self.bn1(gg)
else:
assert self.bn2 is not None
gg = self.bn2(gg)
return gg.view([nb, nl, nnei, nf])
def _apply_bn_uni(
self,
bn_number: int,
gg: torch.Tensor,
mode: str = "1",
) -> torch.Tensor:
if len(gg.shape) == 3:
return self._apply_nb_1(bn_number, gg)
elif len(gg.shape) == 4:
return self._apply_nb_2(bn_number, gg)
else:
raise RuntimeError(f"unsupported input shape {gg.shape}")
def _apply_bn_comp(
self,
bn_number: int,
gg: torch.Tensor,
) -> torch.Tensor:
ss = gg.shape
nf = ss[-1]
gg = gg.view([-1, nf])
if bn_number == 1:
assert self.bn1 is not None
gg = self.bn1(gg).view(ss)
else:
assert self.bn2 is not None
gg = self.bn2(gg).view(ss)
return gg
def forward(
self,
g1_ext: torch.Tensor, # nf x nall x ng1
g2: torch.Tensor, # nf x nloc x nnei x ng2
h2: torch.Tensor, # nf x nloc x nnei x 3
nlist: torch.Tensor, # nf x nloc x nnei
nlist_mask: torch.Tensor, # nf x nloc x nnei
sw: torch.Tensor, # switch func, nf x nloc x nnei
):
"""
Parameters
----------
g1_ext : nf x nall x ng1 extended single-atom chanel
g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant
h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant
nlist : nf x nloc x nnei neighbor list (padded neis are set to 0)
nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0
sw : nf x nloc x nnei switch function
Returns
-------
g1: nf x nloc x ng1 updated single-atom chanel
g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant
h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant
"""
cal_gg1 = (
self.update_g1_has_drrd
or self.update_g1_has_conv
or self.update_g1_has_attn
or self.update_g2_has_g1g1
)
nb, nloc, nnei, _ = g2.shape
nall = g1_ext.shape[1]
g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1)
assert (nb, nloc) == g1.shape[:2]
assert (nb, nloc, nnei) == h2.shape[:3]
ng1 = g1.shape[-1]
ng2 = g2.shape[-1]
nh2 = h2.shape[-1]
if self.bn1 is not None:
g1 = self._apply_bn(1, g1)
if self.bn2 is not None:
g2 = self._apply_bn(2, g2)
if self.update_h2:
h2 = _apply_h_norm(h2)
g2_update: List[torch.Tensor] = [g2]
h2_update: List[torch.Tensor] = [h2]
g1_update: List[torch.Tensor] = [g1]
g1_mlp: List[torch.Tensor] = [g1]
if cal_gg1:
gg1 = _make_nei_g1(g1_ext, nlist)
else:
gg1 = None
if self.update_chnnl_2:
# nb x nloc x nnei x ng2
assert self.linear2 is not None
g2_1 = self.act(self.linear2(g2))
g2_update.append(g2_1)
if self.update_g2_has_g1g1:
assert gg1 is not None
assert self.proj_g1g1g2 is not None
g2_update.append(
self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw))
)
if self.update_g2_has_attn:
assert self.attn2g_map is not None
assert self.attn2_mh_apply is not None
assert self.attn2_lm is not None
# nb x nloc x nnei x nnei x nh
AAg = self.attn2g_map(g2, h2, nlist_mask, sw)
# nb x nloc x nnei x ng2
g2_2 = self.attn2_mh_apply(AAg, g2)
g2_2 = self.attn2_lm(g2_2)
g2_update.append(g2_2)
if self.update_h2:
h2_update.append(self._update_h2(g2, h2, nlist_mask, sw))
if self.update_g1_has_conv:
assert gg1 is not None
g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw))
if self.update_g1_has_grrg:
g1_mlp.append(self._update_g1_grrg(g2, h2, nlist_mask, sw))
if self.update_g1_has_drrd:
assert gg1 is not None
g1_mlp.append(self._update_g1_grrg(gg1, h2, nlist_mask, sw))
# nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)]
# conv grrg drrd
g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1)))
g1_update.append(g1_1)
if self.update_g1_has_attn:
assert gg1 is not None
assert self.loc_attn is not None
g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw))
# update
if self.update_chnnl_2:
g2_new = self.list_update(g2_update)
h2_new = self.list_update(h2_update)
else:
g2_new, h2_new = g2, h2
g1_new = self.list_update(g1_update)
return g1_new, g2_new, h2_new
@torch.jit.export
def list_update_res_avg(
self,
update_list: List[torch.Tensor],
) -> torch.Tensor:
nitem = len(update_list)
uu = update_list[0]
for ii in range(1, nitem):
uu = uu + update_list[ii]
return uu / (float(nitem) ** 0.5)
@torch.jit.export
def list_update_res_incr(self, update_list: List[torch.Tensor]) -> torch.Tensor:
nitem = len(update_list)
uu = update_list[0]
scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0
for ii in range(1, nitem):
uu = uu + scale * update_list[ii]
return uu
@torch.jit.export
def list_update(self, update_list: List[torch.Tensor]) -> torch.Tensor:
if self.update_style == "res_avg":
return self.list_update_res_avg(update_list)
elif self.update_style == "res_incr":
return self.list_update_res_incr(update_list)
else:
raise RuntimeError(f"unknown update style {self.update_style}")
def _bn_layer(
self,
nf: int = 1,
) -> Callable:
return torch.nn.BatchNorm1d(
nf,
eps=1e-5,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
device=env.DEVICE,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
)
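
Finally, a hedged end-to-end sketch of a single RepformerLayer call (assuming deepmd-kit's PyTorch backend is installed; all sizes and hyper-parameters below are made up for illustration):

import torch
from deepmd.pt.utils import env
from deepmd.pt.model.descriptor.repformer_layer import RepformerLayer

nf, nloc, nall, nnei = 1, 4, 10, 6
g1_dim, g2_dim = 128, 16
layer = RepformerLayer(
    rcut=6.0, rcut_smth=5.0, sel=nnei, ntypes=2, g1_dim=g1_dim, g2_dim=g2_dim
)
kw = {"dtype": env.GLOBAL_PT_FLOAT_PRECISION, "device": env.DEVICE}
g1_ext = torch.randn(nf, nall, g1_dim, **kw)
g2 = torch.randn(nf, nloc, nnei, g2_dim, **kw)
h2 = torch.randn(nf, nloc, nnei, 3, **kw)
nlist = torch.randint(0, nall, (nf, nloc, nnei), device=env.DEVICE)
nlist_mask = torch.ones(nf, nloc, nnei, dtype=torch.bool, device=env.DEVICE)
sw = torch.ones(nf, nloc, nnei, **kw)
g1_new, g2_new, h2_new = layer(g1_ext, g2, h2, nlist, nlist_mask, sw)
# g1_new: nf x nloc x g1_dim; g2_new: nf x nloc x nnei x g2_dim; h2_new: nf x nloc x nnei x 3
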