deepmd.pt_expt.utils.serialization#
Attributes#
Functions#
| Neutralise shape-guard assertion nodes in a spin model's exported graph. |
| Convert numpy arrays in a model dict to JSON-serializable lists. |
| Convert JSON-serialized numpy arrays back to np.ndarray. |
| |
| Return True if the model needs a “with-comm” AOTI artifact compiled. |
| Build trivial-but-valid comm tensors for tracing the with-comm variant. |
| Create sample inputs for tracing forward_lower. |
| Build dynamic shape specifications for torch.export. |
| Collect metadata from the model for C++ inference. |
| Serialize a .pte or .pt2 model file to a dictionary. |
| Serialize a .pte model file to a dictionary. |
| Serialize a .pt2 model file to a dictionary. |
| Deserialize a dictionary to a .pte or .pt2 model file. |
| Common logic: build model, trace, export. |
| Deserialize a dictionary to a .pte model file. |
| Deserialize a dictionary to a .pt2 model file (AOTInductor). |
Module Contents#
- deepmd.pt_expt.utils.serialization._strip_shape_assertions(graph_module: torch.nn.Module) None[source]#
Neutralise shape-guard assertion nodes in a spin model’s exported graph.
torch.export inserts aten._assert_scalar nodes for symbolic shape relationships discovered during tracing. For the spin model, the atom-doubling logic creates slice patterns that depend on (nall - nloc), producing guards like Ne(nall, nloc). These guards are spurious: the model computes correct results even when nall == nloc (NoPBC, no ghost atoms).
This function is only called for spin models (guarded by if is_spin in _trace_and_export). The assertion messages use opaque symbolic variable names (e.g. Ne(s22, s96)) rather than human-readable names, so filtering by message content is not reliable. Since prefer_deferred_runtime_asserts_over_guards=True converts all shape guards into these deferred assertions, and the only shape relationships in the spin model involve nall/nloc, neutralising all of them is safe in this context.
We replace each assertion’s condition with True rather than erasing the node; erasing nodes can disturb the FX graph structure and produce NaN gradients on some Python/torch versions.
- deepmd.pt_expt.utils.serialization._numpy_to_json_serializable(model_obj: dict) dict[source]#
Convert numpy arrays in a model dict to JSON-serializable lists.
- deepmd.pt_expt.utils.serialization._json_to_numpy(model_obj: dict) dict[source]#
Convert JSON-serialized numpy arrays back to np.ndarray.
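A round trip through these two conversions might look like the sketch below. The tagged-dict encoding and the names to_jsonable and from_jsonable are assumptions made for illustration; the encoding actually used by the module may differ.

```python
import json

import numpy as np

ARRAY_TAG = "np.ndarray"  # assumed marker value, for illustration only


def to_jsonable(obj):
    """Recursively replace numpy arrays with tagged, JSON-friendly dicts."""
    if isinstance(obj, np.ndarray):
        return {"@class": ARRAY_TAG, "dtype": obj.dtype.name, "value": obj.tolist()}
    if isinstance(obj, dict):
        return {key: to_jsonable(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(val) for val in obj]
    return obj


def from_jsonable(obj):
    """Inverse of to_jsonable: rebuild arrays from tagged dicts."""
    if isinstance(obj, dict):
        if obj.get("@class") == ARRAY_TAG:
            return np.asarray(obj["value"], dtype=obj["dtype"])
        return {key: from_jsonable(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [from_jsonable(val) for val in obj]
    return obj


model = {"model": {"sel": [4, 8], "davg": np.arange(6, dtype=np.float64).reshape(2, 3)}}
text = json.dumps(to_jsonable(model))
restored = from_jsonable(json.loads(text))
```

Storing the dtype next to the nested list lets the reader rebuild an array of the same type instead of relying on numpy's default inference.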
- deepmd.pt_expt.utils.serialization._needs_with_comm_artifact(model: torch.nn.Module) bool[source]#
Return True if the model needs a “with-comm” AOTI artifact compiled.
The with-comm artifact carries the per-layer deepmd_export::border_op calls that exchange node-embedding tensors across MPI ranks. Multi-rank LAMMPS dispatches to it when the descriptor’s message passing extends across rank boundaries (i.e. layers consume neighbour features that live on a different rank). Non-GNN descriptors and GNN descriptors with use_loc_mapping=True keep all per-layer messaging local to each rank’s owned atoms; they need only the regular artifact.
Delegates to descriptor.has_message_passing_across_ranks(), which descriptor classes implement explicitly. Returns False defensively when the model has no single descriptor (linear/zbl/frozen) or when the method is missing or raises.
- deepmd.pt_expt.utils.serialization._TRACE_SENDLIST_KEEPALIVE: list[numpy.ndarray] = [][source]#
- deepmd.pt_expt.utils.serialization._make_comm_sample_inputs(nloc: int, nghost: int, device: torch.device) tuple[torch.Tensor, Ellipsis][source]#
Build trivial-but-valid comm tensors for tracing the with-comm variant.
Phase 0 finding: tracing with nswap == 0 causes the dim to specialize, so we must use nswap >= 1. We use nswap == 1 with a single self-send swap whose sendlist points to nghost local atoms. The actual indices don’t matter for the trace, only the validity of the pointer does; border_op is opaque to torch.export via the deepmd_export::border_op wrapper.
Returns (send_list, send_proc, recv_proc, send_num, recv_num, communicator, nlocal_ts, nghost_ts): 8 tensors, matching the canonical positional order of forward_common_lower_exportable_with_comm.
- deepmd.pt_expt.utils.serialization._make_sample_inputs(model: torch.nn.Module, nframes: int = 1, nloc: int = 7, has_spin: bool = False) tuple[torch.Tensor, Ellipsis][source]#
Create sample inputs for tracing forward_lower.
- Parameters:
- model
torch.nn.Module The pt_expt model (must have get_rcut, get_sel, get_type_map, etc.).
- nframes
int Number of frames.
- nloc
int Number of local atoms.
- has_spin
bool If True, create an extended spin tensor and return 7 tensors.
- Returns:
tuple (ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin) or (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam, charge_spin) when has_spin.
- deepmd.pt_expt.utils.serialization._build_dynamic_shapes(*sample_inputs: torch.Tensor | None, has_spin: bool = False, with_comm_dict: bool = False, model_nnei: int = 1) tuple[source]#
Build dynamic shape specifications for torch.export.
Marks nframes, nloc, nall and nnei as dynamic dimensions so the exported program handles arbitrary frame, atom and neighbor counts.
When with_comm_dict is True, 8 additional comm tensors are appended to the returned tuple, matching the positional order of forward_common_lower_exportable_with_comm. nswap is the only dynamic dim among them; the rest are scalar or fixed-size.
Returns a tuple (not dict) to match positional args of the make_fx traced module, whose arg names may have suffixes like _1.
- Parameters:
- *sample_inputs
torch.Tensor | None Sample inputs: 6 tensors (non-spin) or 7 (spin), optionally followed by 8 comm tensors when with_comm_dict.
- has_spin
bool Whether the inputs include an extended_spin tensor.
- with_comm_dict
bool Whether the inputs include the 8 comm tensors.
- model_nnei
int The model’s sum(sel). Used as the min for the dynamic nnei dim.
- deepmd.pt_expt.utils.serialization._collect_metadata(model: torch.nn.Module, is_spin: bool = False) dict[source]#
Collect metadata from the model for C++ inference.
This metadata is stored as metadata.json in both .pt2 and .pte archives. Training config is stored separately in model_def_script.json. C++ reads flat JSON fields because compiling model API methods as AOTInductor entry points is impractical (~12 s per trivial function) and string outputs (get_type_map) cannot be expressed as tensor I/O.
The fitting_output_defs list is also included so that ModelOutputDef can be reconstructed without loading the full model.
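As a rough sketch of what "flat JSON fields" means here, the snippet below reads a few getters from a stand-in model and writes plain JSON values. The key names (rcut, sel, type_map, is_spin), the ToyModel class and the function collect_flat_metadata are all hypothetical; the real field set is defined by _collect_metadata.

```python
import json


class ToyModel:
    """Stand-in exposing getters of the kind the metadata is read from."""

    def get_rcut(self) -> float:
        return 6.0

    def get_sel(self) -> list:
        return [46, 92]

    def get_type_map(self) -> list:
        return ["O", "H"]


def collect_flat_metadata(model, is_spin: bool = False) -> dict:
    """Gather plain JSON values; the key names here are hypothetical."""
    return {
        "rcut": float(model.get_rcut()),
        "sel": [int(n) for n in model.get_sel()],
        "type_map": list(model.get_type_map()),
        "is_spin": bool(is_spin),
    }


metadata_json = json.dumps(collect_flat_metadata(ToyModel()), indent=2)
```

A reader in any language can parse such a file with an ordinary JSON library, so strings like the type map need no tensor encoding.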
- deepmd.pt_expt.utils.serialization.serialize_from_file(model_file: str) dict[source]#
Serialize a .pte or .pt2 model file to a dictionary.
Reads the model dict stored in the model archive.
- deepmd.pt_expt.utils.serialization._serialize_from_file_pte(model_file: str) dict[source]#
Serialize a .pte model file to a dictionary.
- deepmd.pt_expt.utils.serialization._serialize_from_file_pt2(model_file: str) dict[source]#
Serialize a .pt2 model file to a dictionary.
Reads the model dict stored in the model/extra/ directory of the .pt2 ZIP archive.
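Because a .pt2 file is a ZIP archive, a JSON entry under model/extra/ can be read with the standard library alone. The sketch below builds a tiny in-memory archive so that it runs without a real model; the entry name model.json and the helper read_extra_json are assumptions for illustration.

```python
import io
import json
import zipfile


def read_extra_json(archive, name: str) -> dict:
    """Load model/extra/<name> from a .pt2-style ZIP archive."""
    with zipfile.ZipFile(archive) as zf:
        with zf.open(f"model/extra/{name}") as fh:
            return json.load(fh)


# Build a tiny in-memory archive standing in for a real .pt2 file.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("model/extra/model.json", json.dumps({"model": {"type": "toy"}}))
buffer.seek(0)

model_dict = read_extra_json(buffer, "model.json")
```

The same helper accepts a file path in place of the in-memory buffer, since zipfile.ZipFile takes either.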
- deepmd.pt_expt.utils.serialization.deserialize_to_file(model_file: str, data: dict, model_json_override: dict | None = None, do_atomic_virial: bool = False) None[source]#
Deserialize a dictionary to a .pte or .pt2 model file.
Builds a pt_expt model from the dict, traces it via make_fx, exports with dynamic shapes, and saves.
- Parameters:
- model_file
str The model file to be saved (.pte or .pt2).
- data
dict The dictionary to be deserialized (same format as dpmodel’s serialize output, with “model” and optionally “model_def_script” keys). If data["model_def_script"] is present, it is embedded in the output so that --use-pretrain-script can extract descriptor/fitting params at finetune time.
- model_json_override
dict or None If provided, this dict is stored in model.json instead of data. Used by dp compress to store the compressed model dict while tracing the uncompressed model (make_fx cannot trace custom ops).
- do_atomic_virial
bool If True, export with per-atom virial correction (3 extra backward passes, ~2.5x slower). Default False for best performance.
- deepmd.pt_expt.utils.serialization._trace_and_export(data: dict, model_json_override: dict | None = None, with_comm_dict: bool = False, do_atomic_virial: bool = False) tuple[source]#
Common logic: build model, trace, export.
- Parameters:
- data
Serialized model dict (with “model” and optionally “model_def_script” keys).
- model_json_override
Optional alternate dict to embed as model.json (used by dp compress to store the compressed model dict while tracing the uncompressed one).
- with_comm_dict
If True, trace forward_common_lower_exportable_with_comm instead of the regular variant. The resulting exported program accepts 8 additional positional comm tensors (send_list, send_proc, recv_proc, send_num, recv_num, communicator, nlocal, nghost) used by the pt_expt Repflow/Repformer override to drive MPI ghost-atom exchange. Only valid for models that need cross-rank ghost-feature exchange (see _needs_with_comm_artifact).
- do_atomic_virial
If True, the traced graph computes per-atom virial (extra autograd.grad backward passes); off by default to keep .pt2 inference fast. Mirrors PR #5407 in upstream master.
- Returns:
tuple (exported, metadata, data_for_json, output_keys).
- deepmd.pt_expt.utils.serialization._deserialize_to_file_pte(model_file: str, data: dict, model_json_override: dict | None = None, do_atomic_virial: bool = False) None[source]#
Deserialize a dictionary to a .pte model file.
- deepmd.pt_expt.utils.serialization._deserialize_to_file_pt2(model_file: str, data: dict, model_json_override: dict | None = None, do_atomic_virial: bool = False) None[source]#
Deserialize a dictionary to a .pt2 model file (AOTInductor).
Uses torch._inductor.aoti_compile_and_package to compile the exported program into a .pt2 package (ZIP archive with compiled shared libraries), then embeds metadata into the archive.
For models whose descriptor reports has_message_passing_across_ranks() == True (DPA2, DPA3 with use_loc_mapping=False, or hybrids wrapping such children), compiles a SECOND with-comm artifact and packs it alongside the regular one. The with-comm variant accepts comm-dict tensors as additional positional inputs and drives MPI ghost-atom exchange via deepmd_export::border_op. The C++ DeepPotPTExpt loader picks the artifact based on the LAMMPS rank count at runtime.
- Layout inside the .pt2 ZIP (PyTorch 2.11 strict layout):
regular → artifact at model/ (AOTInductor’s own layout)
with-comm → model/extra/forward_lower_with_comm.pt2 (nested ZIP)
metadata → model/extra/metadata.json with has_comm_artifact flag.
The C++ reader matches by /-delimited suffix so the legacy root-level extra/ layout still loads.
Old .pt2 files (pre-this-change) lack has_comm_artifact so the C++ loader must default to False when the field is missing.
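The two compatibility rules above, suffix matching on /-delimited paths and a False default for a missing flag, can be illustrated in Python. The actual reader is C++; the functions find_by_suffix, has_comm_artifact and make_archive below are hypothetical and only mirror the described behaviour.

```python
import io
import json
import zipfile


def find_by_suffix(names, suffix: str):
    """Return the first entry equal to suffix or ending with '/' + suffix."""
    for name in names:
        if name == suffix or name.endswith("/" + suffix):
            return name
    return None


def has_comm_artifact(archive) -> bool:
    """Read the flag from metadata.json, defaulting to False when absent."""
    with zipfile.ZipFile(archive) as zf:
        entry = find_by_suffix(zf.namelist(), "extra/metadata.json")
        if entry is None:
            return False
        metadata = json.loads(zf.read(entry))
    return bool(metadata.get("has_comm_artifact", False))


def make_archive(path: str, metadata: dict) -> io.BytesIO:
    """Build an in-memory ZIP holding one JSON entry at the given path."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as zf:
        zf.writestr(path, json.dumps(metadata))
    buffer.seek(0)
    return buffer


new_layout = make_archive("model/extra/metadata.json", {"has_comm_artifact": True})
legacy_layout = make_archive("extra/metadata.json", {"rcut": 6.0})
```

Matching on a whole path component (the leading "/") avoids false hits on entries such as notextra/metadata.json, while still accepting both the nested and the root-level layout.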