deepmd.jax.jax2tf.serialization

JAX/jax2tf SavedModel export.

The .savedmodel suffix identifies the JAX SavedModel artifact consumed by the JAX C++ inference path. It is intentionally distinct from the TF2 eager .savedmodeltf artifact: the model body exported by this module must pass through jax2tf.convert so that TensorFlow stores XlaCallModule nodes. Do not replace this module with the TF2 SavedModel exporter unless the file suffix and the C++ loader contract are changed together.

Functions

deserialize_to_file(→ None)

Deserialize the dictionary to a JAX/jax2tf SavedModel.

Module Contents

deepmd.jax.jax2tf.serialization.deserialize_to_file(model_file: str, data: dict) → None

Deserialize the dictionary to a JAX/jax2tf SavedModel.