Source code for deepmd.pt.model.task.type_predict

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Optional,
)

import torch

from deepmd.pt.model.network.network import (
    MaskLMHead,
)
from deepmd.pt.model.task import (
    Fitting,
)


class TypePredictNet(Fitting):
    """Fitting net that predicts type logits from input features.

    All computation is delegated to a single :class:`MaskLMHead`
    stored as ``self.lm_head``.
    """

    def __init__(
        self,
        feature_dim,
        ntypes,
        activation_function="gelu",
        **kwargs,
    ):
        """Construct a type predict net.

        Parameters
        ----------
        feature_dim
            Input dim: size of the last axis of the input features. Passed to
            the head as ``embed_dim``.
        ntypes
            Number of types to predict. Passed to the head as ``output_dim``.
        activation_function
            Activation function forwarded to the head as ``activation_fn``.
            Defaults to ``"gelu"``.
        **kwargs
            Extra keyword arguments are accepted but ignored.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.ntypes = ntypes
        # NOTE(review): weight=None presumably makes MaskLMHead allocate its
        # own output projection rather than sharing an existing weight -- confirm.
        self.lm_head = MaskLMHead(
            embed_dim=self.feature_dim,
            output_dim=ntypes,
            activation_fn=activation_function,
            weight=None,
        )

    def forward(self, features, masked_tokens: Optional[torch.Tensor] = None):
        """Calculate the predicted logits.

        Parameters
        ----------
        features
            Input features with shape [nframes, nloc, feature_dim].
        masked_tokens
            Optional input masked tokens with shape [nframes, nloc], forwarded
            unchanged to ``self.lm_head``.

        Returns
        -------
        logits
            Predicted logits with shape [nframes, nloc, ntypes] when
            ``masked_tokens`` is None. When a mask is given, the output shape
            is determined by ``MaskLMHead`` (it may keep only the masked
            positions -- confirm there).
        """
        # [nframes, nloc, ntypes] in the unmasked case
        logits = self.lm_head(features, masked_tokens=masked_tokens)
        return logits