Source code for deepmd.pt.model.task.atten_lcc

# SPDX-License-Identifier: LGPL-3.0-or-later
import torch
import torch.nn as nn

from deepmd.pt.model.network.network import (
    EnergyHead,
    NodeTaskHead,
)
from deepmd.pt.model.task.fitting import (
    Fitting,
)
from deepmd.pt.utils import (
    env,
)


class FittingNetAttenLcc(Fitting):
    """Fitting head that produces a per-frame energy and a per-atom 3-vector.

    The energy branch projects each local atom's embedding to a scalar,
    scales it by a learned factor, adds a per-type bias and sums over the
    local atoms.  The vector branch delegates to a ``NodeTaskHead``.
    """

    def __init__(
        self, embedding_width, bias_atom_e, pair_embed_dim, attention_heads, **kwargs
    ):
        """Build the energy projection, the scaling table and the vector head.

        Parameters
        ----------
        embedding_width
            Width of the per-atom embedding fed to both heads.
        bias_atom_e
            Per-type energy bias; converted with ``torch.tensor`` and stored
            as a (non-trainable) buffer that is indexed by ``atype`` in
            :meth:`forward`.
        pair_embed_dim
            Forwarded to ``NodeTaskHead`` as its second argument.
        attention_heads
            Forwarded to ``NodeTaskHead`` as its third argument.
        **kwargs
            Accepted and ignored.
        """
        super().__init__()
        self.embedding_width = embedding_width

        # NOTE: the attribute names below (including their spelling) are part
        # of the module's state-dict keys and must not be renamed.
        self.engergy_proj = EnergyHead(self.embedding_width, 1)

        # 4-row, 1-column table of energy scaling factors, initialised close
        # to zero (mean 0, std 0.01).
        self.energe_agg_factor = nn.Embedding(
            4, 1, dtype=env.GLOBAL_PT_FLOAT_PRECISION
        )
        nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)

        # Stored as a buffer so it follows the module across devices and is
        # saved in the state dict without being a trainable parameter.
        self.register_buffer("bias_atom_e", torch.tensor(bias_atom_e))

        self.pair_embed_dim = pair_embed_dim
        self.attention_heads = attention_heads
        self.node_proc = NodeTaskHead(
            self.embedding_width, self.pair_embed_dim, self.attention_heads
        )
        self.node_proc.zero_init()

    def forward(self, output, pair, delta_pos, atype, nframes, nloc):
        """Compute the summed energy and the per-atom vector prediction.

        Parameters
        ----------
        output
            Atomic representation; indexed as ``output[:, 0, :]`` and reshaped
            to ``[nframes, nloc, embedding_width]``, so it is presumably
            ``[(nframes x nloc), (1 + nnei2), embedding_width]`` -- confirm
            against the caller.
        pair
            Pair representation, passed through to ``NodeTaskHead`` unchanged.
        delta_pos
            Relative positions, passed through to ``NodeTaskHead`` unchanged.
        atype
            Integer atom types used to index the bias buffer; expected to be
            viewable as ``[nframes, nloc]``.
        nframes
            Number of frames.
        nloc
            Number of local atoms per frame.

        Returns
        -------
        energy_out
            Energy summed over the last (atom) axis of a ``[nframes, nloc]``
            tensor.
        predict_force_nloc
            Vector prediction reshaped to ``[nframes, nloc, 3]``.
        """
        # Entry 0 along the second axis of every row, laid out as
        # [nframes x nloc x embedding_width].
        # Optional: GRRG or mean of gbf TODO
        center_repr = output[:, 0, :].reshape(nframes, nloc, self.embedding_width)

        # ---- energy output ----
        # Per-atom scalar energy: [nframes, nloc]
        atom_energy = self.engergy_proj(center_repr).view(nframes, nloc)
        # The scaling table is looked up with all-zero indices, so only row 0
        # of the embedding is actually used: [nframes, nloc]
        zero_index = torch.zeros_like(atype)
        scale = self.energe_agg_factor(zero_index).view(nframes, nloc)
        # Scale, shift by the per-type bias, then reduce over atoms.
        atom_energy = atom_energy * scale + self.bias_atom_e[atype]
        energy_out = atom_energy.sum(dim=-1)

        # ---- vector output ----
        # Per the original annotation the head returns
        # [(nframes x nloc) x (1 + nnei2) x 3].
        vector_all = self.node_proc(output, pair, delta_pos=delta_pos)
        # Keep entry 0 along the second axis: [nframes x nloc x 3]
        predict_force_nloc = vector_all[:, 0, :].reshape(nframes, nloc, 3)

        return energy_out, predict_force_nloc