Source code for deepmd.dpmodel.output_def

# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from enum import (
    IntEnum,
)
from typing import (
    Dict,
    List,
    Optional,
    Tuple,
)


def check_shape(
    shape: List[int],
    def_shape: List[int],
):
    """Check if the shape satisfies the defined shape.

    Parameters
    ----------
    shape
        The actual trailing shape of a variable (any sequence of ints).
    def_shape
        The defined shape (any sequence of ints). A trailing ``-1`` acts as a
        wildcard: in that case only the leading dimensions are compared.

    Raises
    ------
    ValueError
        If the number of dimensions differs or a dimension does not match.
    """
    # normalize so lists and tuples compare equal
    shape = list(shape)
    def_shape = list(def_shape)
    if len(shape) != len(def_shape):
        raise ValueError(f"{shape} length not matching def {def_shape}")
    # an empty definition (scalar) has no trailing wildcard to inspect
    if def_shape and def_shape[-1] == -1:
        if shape[:-1] != def_shape[:-1]:
            raise ValueError(f"{shape[:-1]} shape not matching def {def_shape[:-1]}")
    elif shape != def_shape:
        raise ValueError(f"{shape} shape not matching def {def_shape}")
def check_var(var, var_def):
    """Validate the shape of an output array against its definition.

    An atomic variable is laid out as ``[nf, nloc, *var_def.shape]`` and a
    non-atomic one as ``[nf, *var_def.shape]``. The number of axes is checked
    first, then the trailing axes are compared with ``var_def.shape`` by
    :func:`check_shape`.

    Raises
    ------
    ValueError
        If the number of axes or the axis sizes do not match.
    """
    n_lead = 2 if var_def.atomic else 1
    trailing = var.shape[n_lead:]
    if len(var.shape) != n_lead + len(var_def.shape):
        raise ValueError(f"{trailing} length not matching def {var_def.shape}")
    check_shape(list(trailing), var_def.shape)
def model_check_output(cls):
    """Check if the output of the Model is consistent with the definition.

    Class decorator. The decorated class is assumed to provide:

    1. ``output_def()`` giving the output definition; it is called once at
       construction and cached as ``self.md``.
    2. ``__call__`` defining the forward path and returning a mapping.

    For each key in ``self.md.keys_outp()`` the returned value is checked. Its
    reduced value is also checked when the definition is ``reduciable``. Its
    coordinate derivative is checked when ``r_differentiable``, and its cell
    derivative when ``c_differentiable``.
    """

    @functools.wraps(cls, updated=())
    class _Wrapped(cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.md = self.output_def()

        def __call__(self, *args, **kwargs):
            outputs = cls.__call__(self, *args, **kwargs)
            definition = self.md
            for name in definition.keys_outp():
                var_def = definition[name]
                check_var(outputs[name], var_def)
                if var_def.reduciable:
                    redu_name = get_reduce_name(name)
                    check_var(outputs[redu_name], definition[redu_name])
                if var_def.r_differentiable:
                    derv_r_name, derv_c_name = get_deriv_name(name)
                    check_var(outputs[derv_r_name], definition[derv_r_name])
                if var_def.c_differentiable:
                    # derv_c_name is bound in the r_differentiable branch above
                    assert var_def.r_differentiable
                    check_var(outputs[derv_c_name], definition[derv_c_name])
            return outputs

    return _Wrapped
def fitting_check_output(cls):
    """Check if the output of the Fitting is consistent with the definition.

    Class decorator. The decorated class is assumed to provide:

    1. ``output_def()`` giving the output definition; it is called once at
       construction and cached as ``self.md``.
    2. ``__call__`` defining the forward path of the fitting and returning a
       mapping.

    Every key listed by ``self.md.keys()`` is looked up in the returned
    mapping and validated with :func:`check_var`.
    """

    @functools.wraps(cls, updated=())
    class _Wrapped(cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.md = self.output_def()

        def __call__(self, *args, **kwargs):
            result = cls.__call__(self, *args, **kwargs)
            output_def = self.md
            for var_name in output_def.keys():
                var_def = output_def[var_name]
                check_var(result[var_name], var_def)
            return result

    return _Wrapped
class OutputVariableOperation(IntEnum):
    """Defines the operation of the output variable.

    Each operation occupies a distinct bit, so several operations applied to
    one variable can be combined with bitwise OR into a single category value.
    """

    _NONE = 0
    """No operation."""
    REDU = 1 << 0
    """Reduce the output variable."""
    DERV_R = 1 << 1
    """Derivative w.r.t. coordinates."""
    DERV_C = 1 << 2
    """Derivative w.r.t. cell."""
    _SEC_DERV_R = 1 << 3
    """Second derivative w.r.t. coordinates."""
    MAG = 1 << 4
    """Magnetic output."""
class OutputVariableCategory(IntEnum):
    """Defines the category of the output variable.

    Each category value is the bitwise OR of the
    :class:`OutputVariableOperation` values that describe how the variable
    was derived from a plain output.
    """

    OUT = OutputVariableOperation._NONE
    """Output variable. (e.g. atom energy)"""
    REDU = OutputVariableOperation.REDU
    """Reduced output variable. (e.g. system energy)"""
    DERV_R = OutputVariableOperation.DERV_R
    """Negative derivative w.r.t. coordinates. (e.g. force)"""
    DERV_C = OutputVariableOperation.DERV_C
    """Atomic component of the virial, see PRB 104, 224202 (2021)"""
    DERV_C_REDU = OutputVariableOperation.REDU | OutputVariableOperation.DERV_C
    """Virial, the transposed negative gradient with cell tensor times cell tensor, see eq 40 JCP 159, 054801 (2023)."""
    DERV_R_DERV_R = OutputVariableOperation._SEC_DERV_R | OutputVariableOperation.DERV_R
    """Hessian matrix, the second derivative w.r.t. coordinates."""
    DERV_R_MAG = OutputVariableOperation.MAG | OutputVariableOperation.DERV_R
    """Magnetic part of negative derivative w.r.t. coordinates. (e.g. magnetic force)"""
    DERV_C_MAG = OutputVariableOperation.MAG | OutputVariableOperation.DERV_C
    """Magnetic part of atomic component of the virial."""
class OutputVariableDef:
    """Defines the shape and other properties of one output variable.

    It is assumed that the fitting network outputs variables for each local
    atom. This class defines one output variable, including its name, shape,
    reducibility and differentiability.

    Parameters
    ----------
    name
        Name of the output variable. Notice that the xxxx_redu,
        xxxx_derv_c, xxxx_derv_r are reserved names that should
        not be used to define variables.
    shape
        The shape of the variable. e.g. energy should be [1],
        dipole should be [3], polarizability should be [3,3].
    reduciable
        If the variable is reduced. Requires ``atomic``.
    r_differentiable
        If the variable is differentiated with respect to coordinates
        of atoms. Intended for reducible variables (not enforced here).
        Negative derivative w.r.t. coordinates will be calculated. (e.g. force)
    c_differentiable
        If the variable is differentiated with respect to the
        cell tensor (pbc case). Requires ``r_differentiable``.
        Virial, the transposed negative gradient with cell tensor times
        cell tensor, will be calculated, see eq 40 JCP 159, 054801 (2023).
    atomic : bool
        If the variable is defined for each atom.
    category : int
        The category of the output variable.
    r_hessian : bool
        If the hessian is required. Requires both ``reduciable`` and
        ``r_differentiable``.
    magnetic : bool
        If the derivatives of variable have magnetic parts.

    Raises
    ------
    ValueError
        If the flag combination is inconsistent (see the requirements above).
    """

    def __init__(
        self,
        name: str,
        shape: List[int],
        reduciable: bool = False,
        r_differentiable: bool = False,
        c_differentiable: bool = False,
        atomic: bool = True,
        category: int = OutputVariableCategory.OUT.value,
        r_hessian: bool = False,
        magnetic: bool = False,
    ):
        self.name = name
        self.shape = list(shape)
        # jit doesn't support math.prod(self.shape)
        self.output_size = 1
        for dim in self.shape:
            self.output_size *= dim
        self.atomic = atomic
        self.reduciable = reduciable
        self.r_differentiable = r_differentiable
        self.c_differentiable = c_differentiable
        if c_differentiable and not r_differentiable:
            raise ValueError("c differentiable requires r_differentiable")
        if reduciable and not atomic:
            raise ValueError("a reduciable variable should be atomic")
        self.category = category
        self.r_hessian = r_hessian
        self.magnetic = magnetic
        if r_hessian and not reduciable:
            raise ValueError("only reduciable variable can calculate hessian")
        if r_hessian and not r_differentiable:
            raise ValueError("only r_differentiable variable can calculate hessian")

    @property
    def size(self):
        """Number of scalar components per atom/frame (product of ``shape``)."""
        return self.output_size
class FittingOutputDef:
    """Defines the shapes and other properties of the fitting network outputs.

    It is assumed that the fitting network outputs variables for each local
    atom. This class collects all the outputs, keyed by their names; when two
    definitions share a name, the later one wins.

    Parameters
    ----------
    var_defs
        List of output variable definitions.
    """

    def __init__(
        self,
        var_defs: List[OutputVariableDef],
    ):
        self.var_defs = {}
        for vdef in var_defs:
            self.var_defs[vdef.name] = vdef

    def __getitem__(
        self,
        key: str,
    ) -> OutputVariableDef:
        """Return the definition registered under ``key``."""
        return self.var_defs[key]

    def get_data(self) -> Dict[str, OutputVariableDef]:
        """Return the name -> definition mapping."""
        return self.var_defs

    def keys(self):
        """Return the names of all defined outputs."""
        return self.var_defs.keys()
class ModelOutputDef:
    """Defines the shapes and other properties of the model outputs.

    The model reduces and differentiates fitting outputs if applicable.
    If a variable is named by foo, then the reduced variable is called
    foo_redu, the derivative w.r.t. coordinates is called foo_derv_r,
    the derivative w.r.t. cell is called foo_derv_c, and the second
    derivative w.r.t. coordinates is called foo_derv_r_derv_r.

    Parameters
    ----------
    fit_defs
        Definition for the fitting net output
    """

    def __init__(
        self,
        fit_defs: FittingOutputDef,
    ):
        self.def_outp = fit_defs
        self.def_redu = do_reduce(self.def_outp.get_data())
        self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp.get_data())
        # differentiating the derv_r outputs once more yields the hessian defs
        self.def_hess_r, _ = do_derivative(self.def_derv_r)
        self.def_derv_c_redu = do_reduce(self.def_derv_c)
        self.def_mask = do_mask(self.def_outp.get_data())
        self.var_defs: Dict[str, OutputVariableDef] = {}
        # merge every group into one lookup table; the list order fixes the
        # iteration order of keys()
        for ii in [
            self.def_outp.get_data(),
            self.def_redu,
            self.def_derv_c,
            self.def_derv_r,
            self.def_derv_c_redu,
            self.def_hess_r,
            self.def_mask,
        ]:
            self.var_defs.update(ii)

    def __getitem__(
        self,
        key: str,
    ) -> OutputVariableDef:
        """Return the definition of the model output ``key``."""
        return self.var_defs[key]

    def get_data(
        self,
        key: Optional[str] = None,
    ) -> Dict[str, OutputVariableDef]:
        """Return the mapping from every model output name to its definition.

        Parameters
        ----------
        key
            Unused. It is accepted only so that existing callers that pass a
            key keep working. It is optional, so ``get_data()`` can be called
            with no arguments, matching ``FittingOutputDef.get_data()``.
        """
        return self.var_defs

    def keys(self):
        """Return the names of all model outputs."""
        return self.var_defs.keys()

    def keys_outp(self):
        """Return the names of the fitting outputs."""
        return self.def_outp.keys()

    def keys_redu(self):
        """Return the names of the reduced outputs."""
        return self.def_redu.keys()

    def keys_derv_r(self):
        """Return the names of the derivatives w.r.t. coordinates."""
        return self.def_derv_r.keys()

    def keys_hess_r(self):
        """Return the names of the second derivatives w.r.t. coordinates."""
        return self.def_hess_r.keys()

    def keys_derv_c(self):
        """Return the names of the derivatives w.r.t. cell."""
        return self.def_derv_c.keys()

    def keys_derv_c_redu(self):
        """Return the names of the reduced derivatives w.r.t. cell."""
        return self.def_derv_c_redu.keys()
def get_reduce_name(name: str) -> str:
    """Return the key of the reduced counterpart of output ``name``."""
    suffix = "_redu"
    return name + suffix
def get_deriv_name(name: str) -> Tuple[str, str]:
    """Return the keys of the coordinate and cell derivatives of ``name``."""
    coord_key, cell_key = (name + sfx for sfx in ("_derv_r", "_derv_c"))
    return coord_key, cell_key
def get_deriv_name_mag(name: str) -> Tuple[str, str]:
    """Return the keys of the magnetic coordinate and cell derivatives of ``name``."""
    coord_key, cell_key = (name + sfx for sfx in ("_derv_r_mag", "_derv_c_mag"))
    return coord_key, cell_key
def get_hessian_name(name: str) -> str:
    """Return the key of the second coordinate derivative (hessian) of ``name``."""
    derv_r = "_derv_r"
    return name + derv_r + derv_r
def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int:
    """Apply an operation to the category of a variable definition.

    ``REDU`` and ``DERV_C`` may each be applied at most once. ``DERV_R`` may be
    applied twice; the second application is recorded as ``_SEC_DERV_R``.

    Parameters
    ----------
    var_def : OutputVariableDef
        The variable definition.
    op : OutputVariableOperation
        The operation to be applied.

    Returns
    -------
    int
        The new category of the variable definition.

    Raises
    ------
    ValueError
        If the operation is not supported, or has already been applied the
        maximum number of times.
    """
    ops = OutputVariableOperation
    if op not in (ops.REDU, ops.DERV_C, ops.DERV_R):
        raise ValueError(f"operation {op} not supported")
    if op == ops.DERV_R:
        if check_operation_applied(var_def, ops.DERV_R):
            # first derivative already present: record a second derivative
            op = ops._SEC_DERV_R
            if check_operation_applied(var_def, op):
                raise ValueError(f"operation {op} has been applied twice")
    elif check_operation_applied(var_def, op):
        raise ValueError(f"operation {op} has been applied")
    return var_def.category | op.value
def check_operation_applied(
    var_def: OutputVariableDef, op: OutputVariableOperation
) -> bool:
    """Check if an operation has been applied to a variable definition.

    Parameters
    ----------
    var_def : OutputVariableDef
        The variable definition.
    op : OutputVariableOperation
        The operation to be checked.

    Returns
    -------
    bool
        True if every bit of ``op`` is set in ``var_def.category``.
    """
    bits = op.value
    return (var_def.category & bits) == bits
def do_reduce(
    def_outp_data: Dict[str, OutputVariableDef],
) -> Dict[str, OutputVariableDef]:
    """Build the definitions of the reduced counterparts of reducible outputs.

    Each ``reduciable`` entry ``foo`` yields a non-atomic, non-differentiable
    entry ``foo_redu`` with the same shape and the ``REDU`` bit added to its
    category. Non-reducible entries are skipped.
    """
    def_redu: Dict[str, OutputVariableDef] = {}
    for name, var_def in def_outp_data.items():
        if not var_def.reduciable:
            continue
        redu_name = get_reduce_name(name)
        def_redu[redu_name] = OutputVariableDef(
            redu_name,
            var_def.shape,
            reduciable=False,
            r_differentiable=False,
            c_differentiable=False,
            atomic=False,
            category=apply_operation(var_def, OutputVariableOperation.REDU),
        )
    return def_redu
def do_mask(
    def_outp_data: Dict[str, OutputVariableDef],
) -> Dict[str, OutputVariableDef]:
    """Build the definitions of the atomic mask outputs.

    A ``mask`` entry (shape ``[1]``) is always produced. A ``mask_mag`` entry
    (shape ``[1]``) is added when at least one output definition is magnetic.
    """
    def_mask: Dict[str, OutputVariableDef] = {}
    # for deep eval when has atomic mask
    def_mask["mask"] = OutputVariableDef(
        name="mask",
        shape=[1],
        reduciable=False,
        r_differentiable=False,
        c_differentiable=False,
    )
    # for deep eval when has atomic mask for magnetic atoms; a single entry is
    # enough however many magnetic outputs exist
    if any(vv.magnetic for vv in def_outp_data.values()):
        def_mask["mask_mag"] = OutputVariableDef(
            name="mask_mag",
            shape=[1],
            reduciable=False,
            r_differentiable=False,
            c_differentiable=False,
        )
    return def_mask
def do_derivative(
    def_outp_data: Dict[str, OutputVariableDef],
) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]:
    """Build definitions of the coordinate and cell derivatives of outputs.

    For every ``r_differentiable`` entry ``foo``, an atomic entry
    ``foo_derv_r`` with shape ``foo.shape + [3]`` goes into the first dict.
    For every ``c_differentiable`` entry, an atomic, reducible entry
    ``foo_derv_c`` with shape ``foo.shape + [9]`` goes into the second dict.
    Magnetic entries also get ``foo_derv_r_mag`` and ``foo_derv_c_mag``,
    placed in the same dict as their non-magnetic counterpart.

    Parameters
    ----------
    def_outp_data
        Output variable definitions keyed by name.

    Returns
    -------
    def_derv_r
        Definitions of the derivatives w.r.t. coordinates.
    def_derv_c
        Definitions of the derivatives w.r.t. cell.
    """
    def_derv_r: Dict[str, OutputVariableDef] = {}
    def_derv_c: Dict[str, OutputVariableDef] = {}
    for kk, vv in def_outp_data.items():
        rkr, rkc = get_deriv_name(kk)
        rkrm, rkcm = get_deriv_name_mag(kk)
        if vv.r_differentiable:
            # only the first derivative of a plain output (category OUT) is
            # itself r-differentiable, and only when the hessian is requested
            hess_next = vv.r_hessian and vv.category == OutputVariableCategory.OUT.value
            def_derv_r[rkr] = OutputVariableDef(
                rkr,
                vv.shape + [3],  # noqa: RUF005
                reduciable=False,
                r_differentiable=hess_next,
                c_differentiable=False,
                atomic=True,
                category=apply_operation(vv, OutputVariableOperation.DERV_R),
            )
            if vv.magnetic:
                def_derv_r[rkrm] = OutputVariableDef(
                    rkrm,
                    vv.shape + [3],  # noqa: RUF005
                    reduciable=False,
                    r_differentiable=hess_next,
                    c_differentiable=False,
                    atomic=True,
                    category=apply_operation(vv, OutputVariableOperation.DERV_R),
                    magnetic=True,
                )
        if vv.c_differentiable:
            assert vv.r_differentiable
            def_derv_c[rkc] = OutputVariableDef(
                rkc,
                vv.shape + [9],  # noqa: RUF005
                reduciable=True,
                r_differentiable=False,
                c_differentiable=False,
                atomic=True,
                category=apply_operation(vv, OutputVariableOperation.DERV_C),
            )
            if vv.magnetic:
                # The magnetic cell derivative is a cell derivative, so it
                # goes into def_derv_c. Previously it went into def_derv_r,
                # which hid it from the cell-derivative group and skipped
                # reducing it when that group is passed to do_reduce.
                def_derv_c[rkcm] = OutputVariableDef(
                    rkcm,
                    vv.shape + [9],  # noqa: RUF005
                    reduciable=True,
                    r_differentiable=False,
                    c_differentiable=False,
                    atomic=True,
                    category=apply_operation(vv, OutputVariableOperation.DERV_C),
                    magnetic=True,
                )
    return def_derv_r, def_derv_c