Source code for deepmd.pt.loss.tensor

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
)

import torch

from deepmd.pt.loss.loss import (
    TaskLoss,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.utils.data import (
    DataRequirementItem,
)


class TensorLoss(TaskLoss):
    def __init__(
        self,
        tensor_name: str,
        tensor_size: int,
        label_name: str,
        pref_atomic: float = 0.0,
        pref: float = 0.0,
        inference=False,
        **kwargs,
    ):
        r"""Construct a loss for local and global tensors.

        Parameters
        ----------
        tensor_name : str
            The name of the tensor in the model predictions to compute the loss.
        tensor_size : int
            The size (dimension) of the tensor.
        label_name : str
            The name of the tensor in the labels to compute the loss.
        pref_atomic : float
            The prefactor of the atomic (local) tensor loss. Must be greater
            than or equal to 0.
        pref : float
            The prefactor of the global tensor loss. Must be greater than or
            equal to 0.
        inference : bool
            If true, all losses found in the model output are reported,
            ignoring the prefactors.
        **kwargs
            Other keyword arguments.
        """
        super().__init__()
        self.tensor_name = tensor_name
        self.tensor_size = tensor_size
        self.label_name = label_name
        self.local_weight = pref_atomic
        self.global_weight = pref
        self.inference = inference

        assert (
            self.local_weight >= 0.0 and self.global_weight >= 0.0
        ), "Cannot assign a negative weight to `pref` or `pref_atomic`"
        self.has_local_weight = self.local_weight > 0.0 or inference
        self.has_global_weight = self.global_weight > 0.0 or inference
        assert self.has_local_weight or self.has_global_weight, (
            "Cannot assign zero weight to both `pref` and `pref_atomic`"
        )
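
    # A minimal construction sketch (illustrative, assumed values for a
    # 3-component dipole head; any fitting tensor works the same way):
    #
    #     TensorLoss(
    #         tensor_name="dipole",
    #         tensor_size=3,
    #         label_name="dipole",
    #         pref_atomic=1.0,
    #         pref=0.0,
    #     )
    #
    # With pref=0.0 the global branch in forward() is skipped entirely
    # unless inference=True.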

    def forward(
        self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
    ):
        """Return loss on local and global tensors.

        Parameters
        ----------
        input_dict : dict[str, torch.Tensor]
            Model inputs.
        model : torch.nn.Module
            Model to be used to output the predictions.
        label : dict[str, torch.Tensor]
            Labels.
        natoms : int
            The local atom number.
        learning_rate : float
            The current learning rate. Not used by this loss.
        mae : bool
            Whether to compute MAE. Not used by this loss.

        Returns
        -------
        model_pred : dict[str, torch.Tensor]
            Model predictions.
        loss : torch.Tensor
            Loss for model to minimize.
        more_loss : dict[str, torch.Tensor]
            Other losses for display.
        """
        model_pred = model(**input_dict)
        del learning_rate, mae
        loss = torch.zeros(
            1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
        )[0]
        more_loss = {}
        if (
            self.has_local_weight
            and self.tensor_name in model_pred
            and "atom_" + self.label_name in label
        ):
            find_local = label.get("find_" + "atom_" + self.label_name, 0.0)
            local_weight = self.local_weight * find_local
            local_tensor_pred = model_pred[self.tensor_name].reshape(
                [-1, natoms, self.tensor_size]
            )
            local_tensor_label = label["atom_" + self.label_name].reshape(
                [-1, natoms, self.tensor_size]
            )
            diff = (local_tensor_pred - local_tensor_label).reshape(
                [-1, self.tensor_size]
            )
            if "mask" in model_pred:
                # Keep only real atoms when the model provides a padding mask.
                diff = diff[model_pred["mask"].reshape([-1]).bool()]
            l2_local_loss = torch.mean(torch.square(diff))
            if not self.inference:
                more_loss[f"l2_local_{self.tensor_name}_loss"] = self.display_if_exist(
                    l2_local_loss.detach(), find_local
                )
            loss += local_weight * l2_local_loss
            rmse_local = l2_local_loss.sqrt()
            more_loss[f"rmse_local_{self.tensor_name}"] = self.display_if_exist(
                rmse_local.detach(), find_local
            )
        if (
            self.has_global_weight
            and "global_" + self.tensor_name in model_pred
            and self.label_name in label
        ):
            find_global = label.get("find_" + self.label_name, 0.0)
            global_weight = self.global_weight * find_global
            global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
                [-1, self.tensor_size]
            )
            global_tensor_label = label[self.label_name].reshape(
                [-1, self.tensor_size]
            )
            diff = global_tensor_pred - global_tensor_label
            if "mask" in model_pred:
                # Weight each frame's squared error by its real-atom count
                # when padding is present.
                atom_num = model_pred["mask"].sum(-1, keepdim=True)
                l2_global_loss = torch.mean(
                    torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
                )
                atom_num = torch.mean(atom_num.float())
            else:
                atom_num = natoms
                l2_global_loss = torch.mean(torch.square(diff))
            if not self.inference:
                more_loss[f"l2_global_{self.tensor_name}_loss"] = self.display_if_exist(
                    l2_global_loss.detach(), find_global
                )
            loss += global_weight * l2_global_loss
            rmse_global = l2_global_loss.sqrt() / atom_num
            more_loss[f"rmse_global_{self.tensor_name}"] = self.display_if_exist(
                rmse_global.detach(), find_global
            )
        return model_pred, loss, more_loss
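
    # A sketch of the arithmetic above, in this file's notation:
    #   l2_local  = mean over unmasked atoms and components of
    #               (local_tensor_pred - local_tensor_label)^2
    #   l2_global = mean over frames and components of
    #               (global_tensor_pred - global_tensor_label)^2,
    #               with each frame weighted by its real-atom count when a
    #               padding mask is present
    # The total loss is pref_atomic * find_local * l2_local
    # + pref * find_global * l2_global, where the find_* flags mark whether
    # the corresponding label exists in the data. The reported global RMSE
    # is additionally normalized by the (mean) atom count.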

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = []
        if self.has_local_weight:
            label_requirement.append(
                DataRequirementItem(
                    "atomic_" + self.label_name,
                    ndof=self.tensor_size,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_global_weight:
            label_requirement.append(
                DataRequirementItem(
                    self.label_name,
                    ndof=self.tensor_size,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        return label_requirement
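

if __name__ == "__main__":
    # Usage sketch with assumed names: a 3-component dipole head trained on
    # both atomic and global labels. Values are illustrative only.
    loss_fn = TensorLoss(
        tensor_name="dipole",
        tensor_size=3,
        label_name="dipole",
        pref_atomic=1.0,
        pref=0.1,
    )
    # Each DataRequirementItem tells the data pipeline which label array to
    # load, whether it is per-atom or per-frame, and its width (ndof).
    for item in loss_fn.label_requirement:
        print(item.key, item.ndof, item.atomic)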