Source code for deepmd.pt.loss.ener_spin

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
)

import torch
import torch.nn.functional as F

from deepmd.pt.loss.loss import (
    TaskLoss,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.env import (
    GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.utils.data import (
    DataRequirementItem,
)


class EnergySpinLoss(TaskLoss):
    def __init__(
        self,
        starter_learning_rate=1.0,
        start_pref_e=0.0,
        limit_pref_e=0.0,
        start_pref_fr=0.0,
        limit_pref_fr=0.0,
        start_pref_fm=0.0,
        limit_pref_fm=0.0,
        start_pref_v=0.0,
        limit_pref_v=0.0,
        start_pref_ae: float = 0.0,
        limit_pref_ae: float = 0.0,
        start_pref_pf: float = 0.0,
        limit_pref_pf: float = 0.0,
        use_l1_all: bool = False,
        inference=False,
        **kwargs,
    ):
        """Construct a layer to compute loss on energy, real force, magnetic force and virial."""
        super().__init__()
        self.starter_learning_rate = starter_learning_rate
        self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
        self.has_fr = (start_pref_fr != 0.0 and limit_pref_fr != 0.0) or inference
        self.has_fm = (start_pref_fm != 0.0 and limit_pref_fm != 0.0) or inference
        # TODO: EnergySpinLoss needs support for virial, atomic energy and atomic pref
        self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
        self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
        self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
        self.start_pref_e = start_pref_e
        self.limit_pref_e = limit_pref_e
        self.start_pref_fr = start_pref_fr
        self.limit_pref_fr = limit_pref_fr
        self.start_pref_fm = start_pref_fm
        self.limit_pref_fm = limit_pref_fm
        self.start_pref_v = start_pref_v
        self.limit_pref_v = limit_pref_v
        self.use_l1_all = use_l1_all
        self.inference = inference
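
    # The loss prefactors follow the learning-rate schedule: with
    # coef = learning_rate / starter_learning_rate, each term is weighted by
    #   pref = limit_pref + (start_pref - limit_pref) * coef,
    # starting at start_pref (coef = 1) and decaying toward limit_pref as the
    # learning rate decays (see ``forward`` below).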

    def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
        """Return energy loss with magnetic labels.

        Parameters
        ----------
        input_dict : dict[str, torch.Tensor]
            Model inputs.
        model : torch.nn.Module
            Model to be used to output the predictions.
        label : dict[str, torch.Tensor]
            Labels.
        natoms : int
            The local atom number.
        learning_rate : float
            The current learning rate, used to interpolate the loss prefactors.
        mae : bool
            Whether to also record MAE metrics for display.

        Returns
        -------
        model_pred : dict[str, torch.Tensor]
            Model predictions.
        loss : torch.Tensor
            Loss for model to minimize.
        more_loss : dict[str, torch.Tensor]
            Other losses for display.
        """
        model_pred = model(**input_dict)
        coef = learning_rate / self.starter_learning_rate
        pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
        pref_fr = self.limit_pref_fr + (self.start_pref_fr - self.limit_pref_fr) * coef
        pref_fm = self.limit_pref_fm + (self.start_pref_fm - self.limit_pref_fm) * coef
        pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef
        loss = torch.tensor(
            0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
        )
        more_loss = {}
        # more_loss['log_keys'] = []  # shown during on-the-fly validation
        # more_loss['test_keys'] = []  # shown when running dp test
        atom_norm = 1.0 / natoms
        if self.has_e and "energy" in model_pred and "energy" in label:
            find_energy = label.get("find_energy", 0.0)
            pref_e = pref_e * find_energy
            if not self.use_l1_all:
                l2_ener_loss = torch.mean(
                    torch.square(model_pred["energy"] - label["energy"])
                )
                if not self.inference:
                    more_loss["l2_ener_loss"] = self.display_if_exist(
                        l2_ener_loss.detach(), find_energy
                    )
                loss += atom_norm * (pref_e * l2_ener_loss)
                rmse_e = l2_ener_loss.sqrt() * atom_norm
                more_loss["rmse_e"] = self.display_if_exist(
                    rmse_e.detach(), find_energy
                )
                # more_loss['log_keys'].append('rmse_e')
            else:  # use l1 loss over all atoms
                l1_ener_loss = F.l1_loss(
                    model_pred["energy"].reshape(-1),
                    label["energy"].reshape(-1),
                    reduction="sum",
                )
                loss += pref_e * l1_ener_loss
                more_loss["mae_e"] = self.display_if_exist(
                    F.l1_loss(
                        model_pred["energy"].reshape(-1),
                        label["energy"].reshape(-1),
                        reduction="mean",
                    ).detach(),
                    find_energy,
                )
                # more_loss['log_keys'].append('rmse_e')
            if mae:
                mae_e = (
                    torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
                    * atom_norm
                )
                more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
                mae_e_all = torch.mean(
                    torch.abs(model_pred["energy"] - label["energy"])
                )
                more_loss["mae_e_all"] = self.display_if_exist(
                    mae_e_all.detach(), find_energy
                )
        if self.has_fr and "force" in model_pred and "force" in label:
            find_force_r = label.get("find_force", 0.0)
            pref_fr = pref_fr * find_force_r
            if not self.use_l1_all:
                diff_fr = label["force"] - model_pred["force"]
                l2_force_real_loss = torch.mean(torch.square(diff_fr))
                if not self.inference:
                    more_loss["l2_force_r_loss"] = self.display_if_exist(
                        l2_force_real_loss.detach(), find_force_r
                    )
                loss += (pref_fr * l2_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                rmse_fr = l2_force_real_loss.sqrt()
                more_loss["rmse_fr"] = self.display_if_exist(
                    rmse_fr.detach(), find_force_r
                )
                if mae:
                    mae_fr = torch.mean(torch.abs(diff_fr))
                    more_loss["mae_fr"] = self.display_if_exist(
                        mae_fr.detach(), find_force_r
                    )
            else:
                l1_force_real_loss = F.l1_loss(
                    label["force"], model_pred["force"], reduction="none"
                )
                more_loss["mae_fr"] = self.display_if_exist(
                    l1_force_real_loss.mean().detach(), find_force_r
                )
                l1_force_real_loss = l1_force_real_loss.sum(-1).mean(-1).sum()
                loss += (pref_fr * l1_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
        if self.has_fm and "force_mag" in model_pred and "force_mag" in label:
            find_force_m = label.get("find_force_mag", 0.0)
            pref_fm = pref_fm * find_force_m
            nframes = model_pred["force_mag"].shape[0]
            # keep only the magnetic forces on real spin atoms
            atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3])
            label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3)
            model_pred_force_mag = model_pred["force_mag"][atomic_mask].view(
                nframes, -1, 3
            )
            if not self.use_l1_all:
                diff_fm = label_force_mag - model_pred_force_mag
                l2_force_mag_loss = torch.mean(torch.square(diff_fm))
                if not self.inference:
                    more_loss["l2_force_m_loss"] = self.display_if_exist(
                        l2_force_mag_loss.detach(), find_force_m
                    )
                loss += (pref_fm * l2_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                rmse_fm = l2_force_mag_loss.sqrt()
                more_loss["rmse_fm"] = self.display_if_exist(
                    rmse_fm.detach(), find_force_m
                )
                if mae:
                    mae_fm = torch.mean(torch.abs(diff_fm))
                    more_loss["mae_fm"] = self.display_if_exist(
                        mae_fm.detach(), find_force_m
                    )
            else:
                l1_force_mag_loss = F.l1_loss(
                    label_force_mag, model_pred_force_mag, reduction="none"
                )
                more_loss["mae_fm"] = self.display_if_exist(
                    l1_force_mag_loss.mean().detach(), find_force_m
                )
                l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum()
                loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
        if not self.inference:
            more_loss["rmse"] = torch.sqrt(loss.detach())
        return model_pred, loss, more_loss

    @property
    def label_requirement(self) -> List[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = []
        if self.has_e:
            label_requirement.append(
                DataRequirementItem(
                    "energy",
                    ndof=1,
                    atomic=False,
                    must=False,
                    high_prec=True,
                )
            )
        if self.has_fr:
            label_requirement.append(
                DataRequirementItem(
                    "force",
                    ndof=3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_fm:
            label_requirement.append(
                DataRequirementItem(
                    "force_mag",
                    ndof=3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_v:
            label_requirement.append(
                DataRequirementItem(
                    "virial",
                    ndof=9,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_ae:
            label_requirement.append(
                DataRequirementItem(
                    "atom_ener",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_pf:
            label_requirement.append(
                DataRequirementItem(
                    "atom_pref",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                    repeat=3,
                )
            )
        return label_requirement
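

# A minimal usage sketch (illustrative, not part of the module): construct the
# loss with hypothetical prefactor values and inspect which data labels it
# will request. Assumes ``DataRequirementItem`` exposes ``key``, ``ndof`` and
# ``atomic`` attributes, as in ``deepmd.utils.data``.
if __name__ == "__main__":
    loss_fn = EnergySpinLoss(
        starter_learning_rate=1.0e-3,
        start_pref_e=0.02,
        limit_pref_e=1.0,
        start_pref_fr=1000.0,
        limit_pref_fr=1.0,
        start_pref_fm=10000.0,
        limit_pref_fm=10.0,
    )
    # expected: energy, force and force_mag items (virial/atomic terms are off)
    for item in loss_fn.label_requirement:
        print(item.key, "ndof:", item.ndof, "atomic:", item.atomic)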