Source code for deepmd.utils.env_mat_stat

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from abc import (
    ABC,
    abstractmethod,
)
from collections import (
    defaultdict,
)
from typing import (
    Dict,
    Iterator,
    List,
    Optional,
)

import numpy as np

from deepmd.utils.path import (
    DPPath,
)

# Module-level logger, named after this module so that records can be
# filtered/configured through the standard logging hierarchy.
log = logging.getLogger(__name__)
class StatItem:
    """A class to store the statistics of the environment matrix.

    The item keeps three running accumulators (element count, sum and sum
    of squares). Two items can be merged with ``+``, and the average and
    standard deviation are derived from the accumulators on demand.

    Parameters
    ----------
    number : int
        The total size of given array.
    sum : float
        The sum value of the matrix.
    squared_sum : float
        The sum squared value of the matrix.
    """

    def __init__(self, number: int = 0, sum: float = 0, squared_sum: float = 0) -> None:
        self.number = number
        self.sum = sum
        self.squared_sum = squared_sum

    def __add__(self, other: "StatItem") -> "StatItem":
        """Merge two items by adding their accumulators field by field.

        Parameters
        ----------
        other : StatItem
            The item to merge with this one.

        Returns
        -------
        StatItem
            A new item; neither operand is modified.
        """
        return StatItem(
            number=self.number + other.number,
            sum=self.sum + other.sum,
            squared_sum=self.squared_sum + other.squared_sum,
        )

    def compute_avg(self, default: float = 0) -> float:
        """Compute the average of the environment matrix.

        Parameters
        ----------
        default : float, optional
            The default value of the average, by default 0.
            Returned when no element has been accumulated.

        Returns
        -------
        float
            The average of the environment matrix.
        """
        if self.number == 0:
            return default
        return self.sum / self.number

    def compute_std(self, default: float = 1e-1, protection: float = 1e-2) -> float:
        """Compute the standard deviation of the environment matrix.

        The standard deviation is evaluated as
        ``sqrt(squared_sum / number - (sum / number) ** 2)``.

        Parameters
        ----------
        default : float, optional
            The default value of the standard deviation, by default 1e-1.
            Returned when no element has been accumulated.
        protection : float, optional
            The protection value for the standard deviation, by default 1e-2.
            Any result whose magnitude is below this value is replaced by it.

        Returns
        -------
        float
            The standard deviation of the environment matrix.
        """
        if self.number == 0:
            return default
        mean = self.sum / self.number
        variance = self.squared_sum / self.number - np.multiply(mean, mean)
        # Floating-point round-off can make E[x^2] - E[x]^2 slightly negative
        # for (nearly) constant data. sqrt of that would be NaN, and NaN fails
        # the `< protection` comparison below, so it would be returned as-is.
        # Clamp at zero so such cases fall back to `protection` instead.
        val = np.sqrt(np.maximum(variance, 0.0))
        if np.abs(val) < protection:
            val = protection
        return val
class EnvMatStat(ABC):
    """A base class to store and calculate the statistics of the environment matrix.

    The statistics are held in ``self.stats``, a ``defaultdict`` that maps a
    string key to a :class:`StatItem`. Subclasses must implement
    :meth:`iter`, which yields partial statistics that
    :meth:`compute_stats` accumulates per key.
    """

    def __init__(self) -> None:
        super().__init__()
        # Missing keys start from an empty StatItem so that `+=` works in
        # compute_stats without an explicit existence check.
        self.stats = defaultdict(StatItem)

    def compute_stats(self, data: List[Dict[str, np.ndarray]]) -> None:
        """Compute the statistics of the environment matrix.

        Parameters
        ----------
        data : List[Dict[str, np.ndarray]]
            The environment matrix.

        Raises
        ------
        ValueError
            If the statistics have already been computed or loaded.
        """
        if len(self.stats) > 0:
            raise ValueError("The statistics has already been computed.")
        for iter_stats in self.iter(data):
            for kk in iter_stats:
                self.stats[kk] += iter_stats[kk]

    @abstractmethod
    def iter(self, data: List[Dict[str, np.ndarray]]) -> Iterator[Dict[str, StatItem]]:
        """Get the iterator of the environment matrix.

        Parameters
        ----------
        data : List[Dict[str, np.ndarray]]
            The environment matrix.

        Yields
        ------
        Dict[str, StatItem]
            The statistics of the environment matrix.
        """

    def save_stats(self, path: DPPath) -> None:
        """Save the statistics of the environment matrix.

        Each entry of ``self.stats`` is written to ``path / key`` as an array
        ``[number, sum, squared_sum]``.

        Parameters
        ----------
        path : DPPath
            The path to save the statistics of the environment matrix.

        Raises
        ------
        ValueError
            If no statistics are available to save.
        """
        if len(self.stats) == 0:
            raise ValueError("The statistics hasn't been computed.")
        # The target directory does not depend on the key, so create it once
        # up front rather than on every loop iteration.
        path.mkdir(parents=True, exist_ok=True)
        for kk, vv in self.stats.items():
            (path / kk).save_numpy(np.array([vv.number, vv.sum, vv.squared_sum]))

    def load_stats(self, path: DPPath) -> None:
        """Load the statistics of the environment matrix.

        Every entry matched by ``path.glob("*")`` is read as an array
        ``[number, sum, squared_sum]`` and stored under the entry's name.

        Parameters
        ----------
        path : DPPath
            The path to load the statistics of the environment matrix.

        Raises
        ------
        ValueError
            If the statistics have already been computed or loaded.
        """
        if len(self.stats) > 0:
            raise ValueError("The statistics has already been computed.")
        for kk in path.glob("*"):
            arr = kk.load_numpy()
            self.stats[kk.name] = StatItem(
                number=arr[0],
                sum=arr[1],
                squared_sum=arr[2],
            )

    def load_or_compute_stats(
        self, data: List[Dict[str, np.ndarray]], path: Optional[DPPath] = None
    ) -> None:
        """Load the statistics of the environment matrix if it exists, otherwise compute and save it.

        Parameters
        ----------
        data : List[Dict[str, np.ndarray]]
            The environment matrix.
        path : DPPath, optional
            The path to load the statistics of the environment matrix.
            If it is an existing directory the statistics are loaded from
            it; otherwise they are computed from ``data`` and, when a path
            is given, saved there. By default None (compute only).
        """
        if path is not None and path.is_dir():
            self.load_stats(path)
            log.info(f"Load stats from {path}.")
        else:
            self.compute_stats(data)
            if path is not None:
                self.save_stats(path)
                log.info(f"Save stats to {path}.")

    def get_avg(self, default: float = 0) -> Dict[str, float]:
        """Get the average of the environment matrix.

        Parameters
        ----------
        default : float, optional
            The default value of the average, by default 0.

        Returns
        -------
        Dict[str, float]
            The average of the environment matrix.
        """
        return {kk: vv.compute_avg(default=default) for kk, vv in self.stats.items()}

    def get_std(
        self, default: float = 1e-1, protection: float = 1e-2
    ) -> Dict[str, float]:
        """Get the standard deviation of the environment matrix.

        Parameters
        ----------
        default : float, optional
            The default value of the standard deviation, by default 1e-1.
        protection : float, optional
            The protection value for the standard deviation, by default 1e-2.

        Returns
        -------
        Dict[str, float]
            The standard deviation of the environment matrix.
        """
        return {
            kk: vv.compute_std(default=default, protection=protection)
            for kk, vv in self.stats.items()
        }