# Source code for deepmd.pt.loss.dos

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
)

import torch

from deepmd.pt.loss.loss import (
    TaskLoss,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.utils.data import (
    DataRequirementItem,
)


class DOSLoss(TaskLoss):
    """Loss on DOS labels (global ``dos`` and per-atom ``atom_dos``).

    Four optional terms are combined, each with its own prefactor schedule:
    global DOS, global cumulative DOS ("cdf"), atomic DOS ("ados") and
    atomic cumulative DOS ("acdf"). The cumulative variants apply
    ``torch.cumsum`` along the last (``numb_dos``) axis before comparing.
    """

    def __init__(
        self,
        starter_learning_rate: float,
        numb_dos: int,
        start_pref_dos: float = 1.00,
        limit_pref_dos: float = 1.00,
        start_pref_cdf: float = 1000,
        limit_pref_cdf: float = 1.00,
        start_pref_ados: float = 0.0,
        limit_pref_ados: float = 0.0,
        start_pref_acdf: float = 0.0,
        limit_pref_acdf: float = 0.0,
        inference: bool = False,
        **kwargs,
    ):
        r"""Construct a loss for the global and atomic DOS.

        Each prefactor is interpolated linearly between its ``start_*`` and
        ``limit_*`` value as ``limit + (start - limit) * lr / starter_lr``.

        Parameters
        ----------
        starter_learning_rate : float
            The learning rate at the start of training; used to normalize the
            current learning rate when interpolating prefactors.
        numb_dos : int
            Number of DOS components (size of the last axis of the labels).
        start_pref_dos, limit_pref_dos : float
            Start / limit prefactor of the global DOS loss.
        start_pref_cdf, limit_pref_cdf : float
            Start / limit prefactor of the loss on the cumulative sum of the
            global DOS.
        start_pref_ados, limit_pref_ados : float
            Start / limit prefactor of the atomic DOS loss.
        start_pref_acdf, limit_pref_acdf : float
            Start / limit prefactor of the loss on the cumulative sum of the
            atomic DOS.
        inference : bool
            If true, all loss terms found in the output are computed,
            ignoring the pre-factors.
        **kwargs
            Other keyword arguments (ignored).

        Raises
        ------
        AssertionError
            If any prefactor is negative, or if every term is disabled.
        """
        super().__init__()
        self.starter_learning_rate = starter_learning_rate
        self.numb_dos = numb_dos
        self.inference = inference

        self.start_pref_dos = start_pref_dos
        self.limit_pref_dos = limit_pref_dos
        self.start_pref_cdf = start_pref_cdf
        self.limit_pref_cdf = limit_pref_cdf

        self.start_pref_ados = start_pref_ados
        self.limit_pref_ados = limit_pref_ados
        self.start_pref_acdf = start_pref_acdf
        self.limit_pref_acdf = limit_pref_acdf

        assert (
            self.start_pref_dos >= 0.0
            and self.limit_pref_dos >= 0.0
            and self.start_pref_cdf >= 0.0
            and self.limit_pref_cdf >= 0.0
            and self.start_pref_ados >= 0.0
            and self.limit_pref_ados >= 0.0
            and self.start_pref_acdf >= 0.0
            and self.limit_pref_acdf >= 0.0
        ), "Can not assign negative weight to `pref` and `pref_atomic`"

        # A term is active only if both its start and limit prefactor are
        # non-zero; in inference mode every term is evaluated.
        self.has_dos = (start_pref_dos != 0.0 and limit_pref_dos != 0.0) or inference
        self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference
        self.has_ados = (
            start_pref_ados != 0.0 and limit_pref_ados != 0.0
        ) or inference
        self.has_acdf = (
            start_pref_acdf != 0.0 and limit_pref_acdf != 0.0
        ) or inference

        # The message is a plain string (the original wrapped it in an
        # AssertionError instance, producing a nested exception).
        assert (
            self.has_dos or self.has_cdf or self.has_ados or self.has_acdf
        ), "Can not assign zero weight both to `pref` and `pref_atomic`"

    def _atomic_l2_loss(self, model_pred, label, natoms, cumulative: bool):
        """Return the mean squared per-atom error of ``atom_dos`` (or its cumsum)."""
        pred = model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos])
        target = label["atom_dos"].reshape([-1, natoms, self.numb_dos])
        if cumulative:
            pred = torch.cumsum(pred, dim=-1)
            target = torch.cumsum(target, dim=-1)
        diff = (pred - target).reshape([-1, self.numb_dos])
        if "mask" in model_pred:
            # Keep only rows selected by the mask (presumably the real,
            # non-padded atoms — confirm against the model output).
            diff = diff[model_pred["mask"].reshape([-1]).bool()]
        return torch.mean(torch.square(diff))

    def _global_l2_loss(self, model_pred, label, natoms, cumulative: bool):
        """Return ``(l2, atom_num)`` for the global ``dos`` (or its cumsum)."""
        pred = model_pred["dos"].reshape([-1, self.numb_dos])
        target = label["dos"].reshape([-1, self.numb_dos])
        if cumulative:
            pred = torch.cumsum(pred, dim=-1)
            target = torch.cumsum(target, dim=-1)
        diff = pred - target
        if "mask" in model_pred:
            atom_num = model_pred["mask"].sum(-1, keepdim=True)
            # Per-frame squared errors weighted by the frame's masked atom
            # count, normalized by the total count, then averaged over DOS bins.
            l2_loss = torch.mean(
                torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
            )
            atom_num = torch.mean(atom_num.float())
        else:
            atom_num = natoms
            l2_loss = torch.mean(torch.square(diff))
        return l2_loss, atom_num

    def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False):
        """Return loss on local and global DOS.

        Parameters
        ----------
        input_dict : dict[str, torch.Tensor]
            Model inputs.
        model : torch.nn.Module
            Model to be used to output the predictions.
        label : dict[str, torch.Tensor]
            Labels.
        natoms : int
            The local atom number.
        learning_rate : float
            Current learning rate, used to interpolate the prefactors.
        mae : bool
            Unused; kept for interface compatibility.

        Returns
        -------
        model_pred: dict[str, torch.Tensor]
            Model predictions.
        loss: torch.Tensor
            Loss for model to minimize.
        more_loss: dict[str, torch.Tensor]
            Other losses for display.
        """
        model_pred = model(**input_dict)

        # Linear schedule: start_pref at learning_rate == starter_learning_rate,
        # approaching limit_pref as learning_rate -> 0.
        coef = learning_rate / self.starter_learning_rate
        pref_dos = self.limit_pref_dos + (self.start_pref_dos - self.limit_pref_dos) * coef
        pref_cdf = self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef
        pref_ados = (
            self.limit_pref_ados + (self.start_pref_ados - self.limit_pref_ados) * coef
        )
        pref_acdf = (
            self.limit_pref_acdf + (self.start_pref_acdf - self.limit_pref_acdf) * coef
        )

        loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
        more_loss = {}

        has_atomic = "atom_dos" in model_pred and "atom_dos" in label
        has_global = "dos" in model_pred and "dos" in label

        if self.has_ados and has_atomic:
            find_local = label.get("find_atom_dos", 0.0)
            l2_local_loss_dos = self._atomic_l2_loss(
                model_pred, label, natoms, cumulative=False
            )
            if not self.inference:
                more_loss["l2_local_dos_loss"] = self.display_if_exist(
                    l2_local_loss_dos.detach(), find_local
                )
            # find_* scales the term to zero when the label is absent.
            loss += pref_ados * find_local * l2_local_loss_dos
            rmse_local_dos = l2_local_loss_dos.sqrt()
            more_loss["rmse_local_dos"] = self.display_if_exist(
                rmse_local_dos.detach(), find_local
            )

        if self.has_acdf and has_atomic:
            find_local = label.get("find_atom_dos", 0.0)
            l2_local_loss_cdf = self._atomic_l2_loss(
                model_pred, label, natoms, cumulative=True
            )
            if not self.inference:
                more_loss["l2_local_cdf_loss"] = self.display_if_exist(
                    l2_local_loss_cdf.detach(), find_local
                )
            loss += pref_acdf * find_local * l2_local_loss_cdf
            rmse_local_cdf = l2_local_loss_cdf.sqrt()
            more_loss["rmse_local_cdf"] = self.display_if_exist(
                rmse_local_cdf.detach(), find_local
            )

        if self.has_dos and has_global:
            find_global = label.get("find_dos", 0.0)
            l2_global_loss_dos, atom_num = self._global_l2_loss(
                model_pred, label, natoms, cumulative=False
            )
            if not self.inference:
                more_loss["l2_global_dos_loss"] = self.display_if_exist(
                    l2_global_loss_dos.detach(), find_global
                )
            loss += pref_dos * find_global * l2_global_loss_dos
            # Reported RMSE is normalized by the (mean) number of atoms.
            rmse_global_dos = l2_global_loss_dos.sqrt() / atom_num
            more_loss["rmse_global_dos"] = self.display_if_exist(
                rmse_global_dos.detach(), find_global
            )

        if self.has_cdf and has_global:
            find_global = label.get("find_dos", 0.0)
            l2_global_loss_cdf, atom_num = self._global_l2_loss(
                model_pred, label, natoms, cumulative=True
            )
            if not self.inference:
                more_loss["l2_global_cdf_loss"] = self.display_if_exist(
                    l2_global_loss_cdf.detach(), find_global
                )
            loss += pref_cdf * find_global * l2_global_loss_cdf
            rmse_global_cdf = l2_global_loss_cdf.sqrt() / atom_num
            more_loss["rmse_global_cdf"] = self.display_if_exist(
                rmse_global_cdf.detach(), find_global
            )

        return model_pred, loss, more_loss

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation.

        ``atom_dos`` is requested when an atomic term is active and ``dos``
        when a global term is active; both are optional (``must=False``).
        """
        label_requirement = []
        if self.has_ados or self.has_acdf:
            label_requirement.append(
                DataRequirementItem(
                    "atom_dos",
                    ndof=self.numb_dos,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_dos or self.has_cdf:
            label_requirement.append(
                DataRequirementItem(
                    "dos",
                    ndof=self.numb_dos,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        return label_requirement