Source code for deepmd.pt.optimizer.LKF

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import math

import torch
import torch.distributed as dist
from torch.optim.optimizer import (
    Optimizer,
)


def distribute_indices(total_length, num_workers):
    indices_per_worker = total_length // num_workers
    remainder = total_length % num_workers
    indices = []
    start = 0
    for i in range(num_workers):
        end = start + indices_per_worker + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end
    return indices, remainder
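
# Illustration (added annotation, not part of the original module):
# distribute_indices splits `total_length` block indices as evenly as possible
# across `num_workers`, giving each of the first `remainder` workers one extra
# block.  For example, distribute_indices(7, 3) returns
# ([(0, 3), (3, 5), (5, 7)], 1).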


class LKFOptimizer(Optimizer):
    def __init__(
        self,
        params,
        kalman_lambda=0.98,
        kalman_nue=0.9987,
        block_size=5120,
    ):
        defaults = {"lr": 0.1, "kalman_nue": kalman_nue, "block_size": block_size}
        super().__init__(params, defaults)

        self._params = self.param_groups[0]["params"]

        if len(self.param_groups) != 1 or len(self._params) == 0:
            raise ValueError(
                "LKF doesn't support per-parameter options (parameter groups)"
            )

        # NOTE: LKF has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        self._state = self.state[self._params[0]]
        self._state.setdefault("kalman_lambda", kalman_lambda)

        self.dist_init = dist.is_available() and dist.is_initialized()
        self.rank = dist.get_rank() if self.dist_init else 0
        self.dindex = []
        self.remainder = 0
        self.__init_P()

    def __init_P(self):
        param_nums = []
        param_sum = 0
        block_size = self.__get_blocksize()
        data_type = self._params[0].dtype
        device = self._params[0].device

        for param_group in self.param_groups:
            params = param_group["params"]
            for param in params:
                param_num = param.data.nelement()
                if param_sum + param_num > block_size:
                    if param_sum > 0:
                        param_nums.append(param_sum)
                    param_sum = param_num
                else:
                    param_sum += param_num

        param_nums.append(param_sum)

        P = []
        params_packed_index = []
        logging.info(f"LKF parameter nums: {param_nums}")
        if self.dist_init:
            block_num = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num += math.ceil(param_num / block_size)
                else:
                    block_num += 1
            num_workers = dist.get_world_size()
            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
            index = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        device_id = self.get_device_id(index)
                        index += 1
                        dist_device = torch.device("cuda:" + str(device_id))
                        if i != block_num - 1:
                            params_packed_index.append(block_size)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        block_size,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue
                        else:
                            params_packed_index.append(param_num - block_size * i)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        param_num - block_size * i,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue
                else:
                    device_id = self.get_device_id(index)
                    index += 1
                    params_packed_index.append(param_num)
                    if self.rank == device_id:
                        dist_device = torch.device("cuda:" + str(device_id))
                        P.append(
                            torch.eye(param_num, dtype=data_type, device=dist_device)
                        )
        else:
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        if i != block_num - 1:
                            P.append(
                                torch.eye(
                                    block_size,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(block_size)
                        else:
                            P.append(
                                torch.eye(
                                    param_num - block_size * i,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(param_num - block_size * i)
                else:
                    P.append(torch.eye(param_num, dtype=data_type, device=device))
                    params_packed_index.append(param_num)

        self._state.setdefault("P", P)
        self._state.setdefault("weights_num", len(P))
        self._state.setdefault("params_packed_index", params_packed_index)

    def __get_blocksize(self):
        return self.param_groups[0]["block_size"]

    def __get_nue(self):
        return self.param_groups[0]["kalman_nue"]

    def __split_weights(self, weight):
        block_size = self.__get_blocksize()
        param_num = weight.nelement()
        res = []
        if param_num < block_size:
            res.append(weight)
        else:
            block_num = math.ceil(param_num / block_size)
            for i in range(block_num):
                if i != block_num - 1:
                    res.append(weight[i * block_size : (i + 1) * block_size])
                else:
                    res.append(weight[i * block_size :])
        return res
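
    # Reference (added annotation): for each local block i, __update applies
    # the standard Kalman-filter recursion with forgetting factor lambda:
    #     A      = 1 / sum_i (lambda + H_i^T P_i H_i)   (all-reduced over ranks)
    #     K_i    = P_i H_i
    #     w_i   <- w_i + A * error * K_i
    #     P_i   <- (1 / lambda) * (P_i - A * K_i K_i^T)
    #     lambda <- nue * lambda + 1 - nue
    # The updated flat weights are then scattered back into the parameters.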

    def __update(self, H, error, weights):
        P = self._state.get("P")
        kalman_lambda = self._state.get("kalman_lambda")
        weights_num = self._state.get("weights_num")
        params_packed_index = self._state.get("params_packed_index")

        block_size = self.__get_blocksize()
        kalman_nue = self.__get_nue()

        tmp = 0
        for i in range(weights_num):
            tmp = tmp + (
                kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i])
            )

        if self.dist_init:
            dist.all_reduce(tmp, op=dist.ReduceOp.SUM)

        A = 1 / tmp

        for i in range(weights_num):
            K = torch.matmul(P[i], H[i])

            weights[i] = weights[i] + A * error * K

            P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T))

        if self.dist_init:
            device = torch.device("cuda:" + str(self.rank))
            local_shape = [tensor.shape[0] for tensor in weights]
            shape_list = [
                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
                for _ in range(dist.get_world_size())
            ]
            dist.all_gather_object(shape_list, local_shape)
            weight_tensor = torch.cat(weights)
            world_shape = [sum(inner_list) for inner_list in shape_list]
            weight_list = [None] * len(world_shape)
            for i in range(len(world_shape)):
                weight_list[i] = torch.zeros(
                    world_shape[i], dtype=torch.float64, device=device
                )
            dist.all_gather(weight_list, weight_tensor)
            result = []
            for i in range(dist.get_world_size()):
                result = result + list(torch.split(weight_list[i], shape_list[i]))
            weights = result

        kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue
        self._state.update({"kalman_lambda": kalman_lambda})

        i = 0
        param_sum = 0
        for param_group in self.param_groups:
            params = param_group["params"]
            for param in params:
                param_num = param.nelement()
                weight_tmp = weights[i][param_sum : param_sum + param_num]
                if param_num < block_size:
                    if param.ndim > 1:
                        param.data = weight_tmp.reshape(
                            param.data.T.shape
                        ).T.contiguous()
                    else:
                        param.data = weight_tmp.reshape(param.data.shape)

                    param_sum += param_num

                    if param_sum == params_packed_index[i]:
                        i += 1
                        param_sum = 0
                else:
                    block_num = math.ceil(param_num / block_size)
                    for j in range(block_num):
                        if j == 0:
                            tmp_weight = weights[i]
                        else:
                            tmp_weight = torch.concat([tmp_weight, weights[i]], dim=0)
                        i += 1
                    param.data = tmp_weight.reshape(param.data.T.shape).T.contiguous()

    def set_grad_prefactor(self, grad_prefactor):
        self.grad_prefactor = grad_prefactor

    def step(self, error):
        params_packed_index = self._state.get("params_packed_index")

        weights = []
        H = []
        param_index = 0
        param_sum = 0

        for param in self._params:
            if param.ndim > 1:
                tmp = param.data.T.contiguous().reshape(param.data.nelement(), 1)
                if param.grad is None:
                    tmp_grad = torch.zeros_like(tmp)
                else:
                    tmp_grad = (
                        (param.grad / self.grad_prefactor)
                        .T.contiguous()
                        .reshape(param.grad.nelement(), 1)
                    )
            else:
                tmp = param.data.reshape(param.data.nelement(), 1)
                if param.grad is None:
                    tmp_grad = torch.zeros_like(tmp)
                else:
                    tmp_grad = (param.grad / self.grad_prefactor).reshape(
                        param.grad.nelement(), 1
                    )

            tmp = self.__split_weights(tmp)
            tmp_grad = self.__split_weights(tmp_grad)

            for split_grad, split_weight in zip(tmp_grad, tmp):
                nelement = split_grad.nelement()

                if param_sum == 0:
                    res_grad = split_grad
                    res = split_weight
                else:
                    res_grad = torch.concat((res_grad, split_grad), dim=0)
                    res = torch.concat((res, split_weight), dim=0)

                param_sum += nelement

                if param_sum == params_packed_index[param_index]:
                    param_sum = 0
                    if self.dist_init:
                        device_id = self.get_device_id(param_index)
                        if self.rank == device_id:
                            weights.append(res)
                            H.append(res_grad)
                    else:
                        weights.append(res)
                        H.append(res_grad)
                    param_index += 1

        self.__update(H, error, weights)

    def get_device_id(self, index):
        for i, (start, end) in enumerate(self.dindex):
            if start <= index < end:
                return i
        return None
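

# Usage sketch (added for illustration; not part of the original module).
# The tiny model, data, and grad_prefactor value below are placeholders: in
# deepmd-kit the training loop constructs the optimizer and supplies the
# per-batch scalar `error` passed to `step`.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 1)
    opt = LKFOptimizer(model.parameters(), kalman_lambda=0.98, kalman_nue=0.9987)
    opt.set_grad_prefactor(1.0)  # must be set before the first call to step()
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.sqrt(torch.mean((model(x) - y) ** 2))
    opt.zero_grad()
    loss.backward()
    # `error` is the scalar residual used in the Kalman gain; detach it so the
    # update does not extend the autograd graph.
    opt.step(loss.detach())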