Source code for deepmd.pt.optimizer.KFWrapper

# SPDX-License-Identifier: LGPL-3.0-or-later
import math

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim.optimizer import (
    Optimizer,
)


class KFOptimizerWrapper:
    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        atoms_selected: int,
        atoms_per_group: int,
        is_distributed: bool = False,
    ) -> None:
        self.model = model
        self.optimizer = optimizer
        self.atoms_selected = atoms_selected  # 24
        self.atoms_per_group = atoms_per_group  # 6
        self.is_distributed = is_distributed
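    # A minimal construction sketch (an editorial example, not part of the
    # original source). It assumes LKFOptimizer, the layer-wise Kalman filter
    # optimizer shipped alongside this wrapper in deepmd.pt.optimizer; the
    # positional hyperparameters (kalman_lambda, kalman_nue) and their values
    # are illustrative only:
    #
    #     from deepmd.pt.optimizer import LKFOptimizer
    #
    #     kf_opt = LKFOptimizer(model.parameters(), 0.98, 0.99870)
    #     wrapper = KFOptimizerWrapper(
    #         model, kf_opt, atoms_selected=24, atoms_per_group=6
    #     )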
    def update_energy(
        self, inputs: dict, Etot_label: torch.Tensor, update_prefactor: float = 1
    ) -> torch.Tensor:
        model_pred, _, _ = self.model(**inputs, inference_only=True)
        Etot_predict = model_pred["energy"]
        natoms_sum = int(inputs["atype"].shape[-1])
        self.optimizer.set_grad_prefactor(natoms_sum)

        self.optimizer.zero_grad()
        bs = Etot_label.shape[0]
        error = Etot_label - Etot_predict
        error = error / natoms_sum
        mask = error < 0

        error = error * update_prefactor
        error[mask] = -1 * error[mask]
        error = error.mean()

        if self.is_distributed:
            dist.all_reduce(error)
            error /= dist.get_world_size()

        Etot_predict = update_prefactor * Etot_predict
        Etot_predict[mask] = -Etot_predict[mask]

        Etot_predict.sum().backward()
        error = error * math.sqrt(bs)

        self.optimizer.step(error)
        return Etot_predict
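    # Editorial gloss on the sign handling above (not part of the original
    # source): the Kalman step consumes a non-negative scalar error, so
    # negative per-sample errors are folded to their absolute value and the
    # corresponding predictions are negated before backward(). Flipping both
    # factors together preserves the sign of the per-sample product
    # error * d(prediction)/d(theta); e.g. for label 1.0 and prediction 1.5,
    # error = -0.5 becomes +0.5 while backpropagating -Etot_predict flips the
    # gradient direction to compensate.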
    def update_force(
        self, inputs: dict, Force_label: torch.Tensor, update_prefactor: float = 1
    ) -> tuple[torch.Tensor, torch.Tensor]:
        natoms_sum = int(inputs["atype"].shape[-1])
        bs = Force_label.shape[0]
        self.optimizer.set_grad_prefactor(natoms_sum * self.atoms_per_group * 3)

        index = self.__sample(self.atoms_selected, self.atoms_per_group, natoms_sum)

        for i in range(index.shape[0]):
            self.optimizer.zero_grad()
            model_pred, _, _ = self.model(**inputs, inference_only=True)
            Etot_predict = model_pred["energy"]
            force_predict = model_pred["force"]
            error_tmp = Force_label[:, index[i]] - force_predict[:, index[i]]
            error_tmp = update_prefactor * error_tmp
            mask = error_tmp < 0
            error_tmp[mask] = -1 * error_tmp[mask]
            error = error_tmp.mean() / natoms_sum

            if self.is_distributed:
                dist.all_reduce(error)
                error /= dist.get_world_size()

            tmp_force_predict = force_predict[:, index[i]] * update_prefactor
            tmp_force_predict[mask] = -tmp_force_predict[mask]

            # Work around a PyTorch autograd bug; see
            # https://github.com/pytorch/pytorch/issues/43259
            (tmp_force_predict.sum() + Etot_predict.sum() * 0).backward()
            error = error * math.sqrt(bs)
            self.optimizer.step(error)
        return Etot_predict, force_predict
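    # A sketch of how the energy and force updates are typically chained in
    # one training step (the batch dict and its key names are hypothetical;
    # only "atype" is required by the code above):
    #
    #     inputs = {"coord": coord, "atype": atype, "box": box}
    #     wrapper.update_energy(inputs, batch["energy"])
    #     wrapper.update_force(inputs, batch["force"])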
    def update_denoise_coord(
        self,
        inputs: dict,
        clean_coord: torch.Tensor,
        update_prefactor: float = 1,
        mask_loss_coord: bool = True,
        coord_mask: torch.Tensor = None,
    ) -> dict:
        natoms_sum = int(inputs["atype"].shape[-1])
        bs = clean_coord.shape[0]
        self.optimizer.set_grad_prefactor(natoms_sum * self.atoms_per_group * 3)

        index = self.__sample(self.atoms_selected, self.atoms_per_group, natoms_sum)

        for i in range(index.shape[0]):
            self.optimizer.zero_grad()
            model_pred, _, _ = self.model(**inputs, inference_only=True)
            updated_coord = model_pred["updated_coord"]
            error_tmp = clean_coord[:, index[i]] - updated_coord[:, index[i]]
            error_tmp = update_prefactor * error_tmp
            if mask_loss_coord:
                error_tmp[~coord_mask[:, index[i]]] = 0
            mask = error_tmp < 0
            error_tmp[mask] = -1 * error_tmp[mask]
            error = error_tmp.mean() / natoms_sum

            if self.is_distributed:
                dist.all_reduce(error)
                error /= dist.get_world_size()

            tmp_coord_predict = updated_coord[:, index[i]] * update_prefactor
            tmp_coord_predict[mask] = -tmp_coord_predict[mask]

            # Work around a PyTorch autograd bug; see
            # https://github.com/pytorch/pytorch/issues/43259
            (tmp_coord_predict.sum() + updated_coord.sum() * 0).backward()
            error = error * math.sqrt(bs)
            self.optimizer.step(error)
        return model_pred
    def __sample(
        self, atoms_selected: int, atoms_per_group: int, natoms: int
    ) -> np.ndarray:
        if atoms_selected % atoms_per_group:
            raise ValueError("atoms_selected must be divisible by atoms_per_group")
        index = range(natoms)
        rng = np.random.default_rng()
        res = rng.choice(index, atoms_selected).reshape(-1, atoms_per_group)
        return res
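    # For example, atoms_selected=24, atoms_per_group=6, natoms=192 yields a
    # (4, 6) array of atom indices drawn from range(192); rng.choice samples
    # with replacement by default, so an atom may appear in more than one group.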
# with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
#     ...  # the code you want to profile
# print(prof.key_averages().table(sort_by="self_cpu_time_total"))
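# A runnable sketch of the same measurement with the newer torch.profiler API
# (the wrapped call is illustrative; drop ProfilerActivity.CUDA on CPU-only hosts):
#
#     from torch.profiler import ProfilerActivity, profile
#
#     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
#         wrapper.update_energy(inputs, batch["energy"])
#     print(prof.key_averages().table(sort_by="self_cpu_time_total"))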