# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
    ABC,
    abstractmethod,
)
from typing import (
    List,
    Union,
)

import torch

from deepmd.utils.data import (
    DataRequirementItem,
)
class TaskLoss(torch.nn.Module, ABC):
    """Abstract base class for loss modules.

    Concrete losses must implement the ``label_requirement`` property and
    are expected to override ``forward``; the base ``forward`` only raises
    ``NotImplementedError``.
    """

    def __init__(self, **kwargs):
        """Construct loss.

        Parameters
        ----------
        **kwargs
            Accepted and ignored by this base class.
        """
        super().__init__()

    def forward(self, input_dict, model, label, natoms, learning_rate):
        """Return loss.

        The base class provides no implementation: none of the arguments
        are used here, and their meaning is defined by the overriding
        subclass.

        Raises
        ------
        NotImplementedError
            Always, unless the method is overridden.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def label_requirement(self) -> List[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""

    @staticmethod
    def display_if_exist(
        loss: torch.Tensor, find_property: float
    ) -> Union[torch.Tensor, float]:
        """Display NaN if labeled property is not found.

        Parameters
        ----------
        loss : torch.Tensor
            the loss tensor
        find_property : float
            whether the property is found; interpreted through ``bool()``,
            so any non-zero value counts as found

        Returns
        -------
        torch.Tensor or float
            ``loss`` itself (the same object, not a copy) when the property
            is found, otherwise ``torch.nan``, which is a plain Python
            float rather than a tensor.
        """
        if bool(find_property):
            return loss
        # Not labeled: hand back the float NaN constant so the entry is
        # reported as missing instead of as a misleading number.
        return torch.nan