# Source code for deepmd.dpmodel.utils.serialization

# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from datetime import (
    datetime,
    timezone,
)
from typing import (
    Callable,
)

import h5py

try:
    from deepmd._version import version as __version__
except ImportError:
    # Fall back to a placeholder when deepmd._version cannot be imported
    # (presumably it is generated at build/install time — TODO confirm).
    # The value is recorded as the "version" field by save_dp_model.
    __version__ = "unknown"
def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False):
    """Traverse a model dict and call callback on each variable.

    Dicts and lists are rebuilt rather than modified in place, so the
    ``model_obj`` passed in is left untouched. Any value nested (at any
    depth) under an ``"@variables"`` key is considered a variable, except
    ``None``, which is passed through without calling ``callback``.

    Parameters
    ----------
    model_obj : object
        The model object to traverse.
    callback : callable
        The callback function to call on each variable.
    is_variable : bool, optional
        Whether the current node is a variable.

    Returns
    -------
    object
        The model object after traversing (new containers for dicts/lists).
    """
    if isinstance(model_obj, dict):
        # Build a new dict instead of assigning into model_obj: callers such
        # as save_dp_model only make a shallow copy, so in-place writes would
        # leak into the caller's nested dicts.
        return {
            kk: traverse_model_dict(
                vv, callback, is_variable=is_variable or kk == "@variables"
            )
            for kk, vv in model_obj.items()
        }
    if isinstance(model_obj, list):
        return [
            traverse_model_dict(vv, callback, is_variable=is_variable)
            for vv in model_obj
        ]
    if model_obj is None:
        # None is never handed to the callback, even under "@variables".
        return model_obj
    if is_variable:
        return callback(model_obj)
    return model_obj
class Counter:
    """Callable that returns consecutive integers, beginning at zero.

    Each call advances the stored ``count`` by one and returns the new value.

    Examples
    --------
    >>> counter = Counter()
    >>> counter()
    0
    >>> counter()
    1
    """

    def __init__(self):
        # Start one below zero so that the very first call yields 0.
        self.count = -1

    def __call__(self):
        next_value = self.count + 1
        self.count = next_value
        return next_value
def save_dp_model(filename: str, model_dict: dict) -> None:
    """Save a DP model to a file in the native format.

    Each variable in ``model_dict`` (as selected by ``traverse_model_dict``)
    is written to its own HDF5 dataset named ``variable_0000``,
    ``variable_0001``, ... and replaced in the serialized dict by that
    dataset's ``name``. The resulting dict, merged after the
    software/version/time metadata, is stored as compact JSON in the file's
    ``"json"`` attribute.

    Parameters
    ----------
    filename : str
        The filename to save to.
    model_dict : dict
        The model dict to save.
    """
    model_dict = model_dict.copy()
    variable_counter = Counter()
    with h5py.File(filename, "w") as f:
        model_dict = traverse_model_dict(
            model_dict,
            lambda x: f.create_dataset(
                f"variable_{variable_counter():04d}", data=x
            ).name,
        )
        save_dict = {
            "software": "deepmd-kit",
            "version": __version__,
            # Use UTC+0 time. ``datetime.utcnow()`` is deprecated since
            # Python 3.12; dropping tzinfo keeps the stored string in the same
            # naive (offset-free) format that ``utcnow()`` produced.
            "time": str(datetime.now(timezone.utc).replace(tzinfo=None)),
            # Spread last, so keys in model_dict override the metadata above.
            **model_dict,
        }
        f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
def load_dp_model(filename: str) -> dict:
    """Load a DP model from a file in the native format.

    The model dict is parsed from the JSON stored in the file's ``"json"``
    attribute; every variable entry (a dataset path written by
    ``save_dp_model``) is then replaced by the contents read from that
    dataset.

    Parameters
    ----------
    filename : str
        The filename to load from.

    Returns
    -------
    dict
        The loaded model dict, including meta information.
    """
    with h5py.File(filename, "r") as h5file:

        def _read_dataset(path):
            # Read the whole dataset and return a defensive copy of the value.
            return h5file[path][()].copy()

        serialized = h5file.attrs["json"]
        loaded = traverse_model_dict(json.loads(serialized), _read_dataset)
    return loaded