deepmd.jax.train.validation

Full validation support for the JAX trainer.

Classes

JAXFullValidator

Run full validation for a single-task JAX energy model.

Module Contents

class deepmd.jax.train.validation.JAXFullValidator(*, validating_params: dict[str, Any], validation_data: Any, model: deepmd.jax.model.base_model.BaseModel, state_store: dict[str, Any], num_steps: int, rank: int, restart_training: bool, checkpoint_dir: Any = None)

Bases: deepmd.dpmodel.train.validation.FullValidatorBase

Run full validation for a single-task JAX energy model.

validation_data
model
auto_batch_size
evaluate_all_systems() -> dict[str, float]

Evaluate every validation system and aggregate metrics.
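Aggregating metrics across systems of different sizes needs a weighting rule. The sketch below rests on some assumptions. Each system reports `(value, weight)` pairs, where the weight is the number of samples behind the value. RMSE-type metrics (a hypothetical `rmse` key prefix) are combined in the squared domain, and every other metric uses a weighted arithmetic mean. It illustrates the general technique only. It is not taken from `JAXFullValidator` or `FullValidatorBase`, and `aggregate_metrics` is a hypothetical name.

```python
import math
from collections.abc import Iterable


def aggregate_metrics(
    per_system: Iterable[dict[str, tuple[float, float]]],
) -> dict[str, float]:
    """Combine per-system (value, weight) pairs into one value per metric.

    Hypothetical convention: keys starting with "rmse" are combined as a
    weighted root-mean-square; all other keys as a weighted arithmetic mean.
    """
    sums: dict[str, float] = {}
    weights: dict[str, float] = {}
    for metrics in per_system:
        for name, (value, weight) in metrics.items():
            # Square RMSE values so sqrt(sum(w * e^2) / sum(w)) is exact.
            contrib = value * value if name.startswith("rmse") else value
            sums[name] = sums.get(name, 0.0) + weight * contrib
            weights[name] = weights.get(name, 0.0) + weight
    result: dict[str, float] = {}
    for name, total in sums.items():
        if weights[name] <= 0:
            continue  # no samples contributed to this metric
        mean = total / weights[name]
        result[name] = math.sqrt(mean) if name.startswith("rmse") else mean
    return result
```

Averaging RMSE values directly would underweight large systems and give a biased result. Weighting the squared errors reproduces the RMSE that a single pooled evaluation would have produced.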

propagate_error(error_message: str | None) -> str | None

Broadcast rank-0 full-validation failures to all JAX processes.
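Collective broadcasts generally require every process to send an array of the same shape, so a variable-length error string has to be packed into a fixed-size buffer first. The sketch below shows one way to do that. The buffer size, the flag-plus-length header, and all function names are hypothetical. The `broadcast` callable is injected so the logic can run in one process. In a multi-process JAX job, `jax.experimental.multihost_utils.broadcast_one_to_all` is a plausible candidate, but we have not confirmed that `JAXFullValidator` uses it.

```python
from __future__ import annotations

from collections.abc import Callable

import numpy as np

HEADER = 5  # byte 0: error flag; bytes 1..4: payload length (little-endian)
MAX_MSG_BYTES = 4096  # hypothetical cap; every rank must use the same size


def encode_error(message: str | None) -> np.ndarray:
    """Pack an optional message into a fixed-size uint8 buffer."""
    buf = np.zeros(HEADER + MAX_MSG_BYTES, dtype=np.uint8)
    if message is None:
        return buf  # flag stays 0: no failure
    payload = message.encode("utf-8")[:MAX_MSG_BYTES]
    buf[0] = 1
    buf[1:HEADER] = np.frombuffer(len(payload).to_bytes(4, "little"), dtype=np.uint8)
    buf[HEADER : HEADER + len(payload)] = np.frombuffer(payload, dtype=np.uint8)
    return buf


def decode_error(buf: np.ndarray) -> str | None:
    """Inverse of encode_error; truncated UTF-8 tails are dropped."""
    buf = np.asarray(buf, dtype=np.uint8)
    if buf[0] == 0:
        return None
    length = int.from_bytes(buf[1:HEADER].tobytes(), "little")
    return buf[HEADER : HEADER + length].tobytes().decode("utf-8", errors="ignore")


def propagate_error_sketch(
    error_message: str | None,
    rank: int,
    broadcast: Callable[[np.ndarray], np.ndarray],
) -> str | None:
    # Only rank 0's message is sent; other ranks contribute an empty buffer
    # and receive rank 0's copy from the broadcast.
    local = encode_error(error_message if rank == 0 else None)
    return decode_error(broadcast(local))
```

Keeping the buffer shape fixed on every rank means no rank has to learn the message length first. Without that, an extra collective would be needed.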

_iter_validation_data_systems() -> Any

Yield DeepmdData-like validation systems.
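"DeepmdData-like" suggests duck typing over the validation container. The generator below is a minimal sketch under a hypothetical assumption. The container either exposes a `data_systems` sequence, is itself a list or tuple of systems, or is a single system. The attribute name and the `iter_systems` helper are assumptions for illustration, not the documented behavior of this method.

```python
from collections.abc import Iterator
from typing import Any


def iter_systems(validation_data: Any) -> Iterator[Any]:
    """Yield per-system objects from a container (hypothetical duck typing)."""
    systems = getattr(validation_data, "data_systems", None)
    if systems is None:
        if isinstance(validation_data, (list, tuple)):
            systems = validation_data
        else:
            systems = [validation_data]  # treat the object as one system
    yield from systems
```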

_evaluate_system(data_system: Any) -> dict[str, tuple[float, float]]

Evaluate one validation system.
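The return type `dict[str, tuple[float, float]]` suggests that each metric is paired with a second number. The sketch below assumes that number is the sample count, which a caller can use to weight results across systems. It computes per-atom energy errors and per-component force errors with NumPy. The metric names, the per-atom normalization, and the `system_errors` helper are hypothetical choices, not taken from the source.

```python
import numpy as np


def system_errors(
    pred_energy: np.ndarray,  # (nframes,)
    label_energy: np.ndarray,  # (nframes,)
    pred_force: np.ndarray,  # (nframes, natoms * 3)
    label_force: np.ndarray,  # (nframes, natoms * 3)
    natoms: int,
) -> dict[str, tuple[float, float]]:
    """Return hypothetical (error, sample_count) pairs for one system."""
    # Energy error normalized per atom; one sample per frame.
    de = (pred_energy - label_energy) / natoms
    # Force error over every Cartesian component of every atom.
    df = pred_force - label_force
    return {
        "rmse_e_per_atom": (float(np.sqrt(np.mean(de**2))), float(de.size)),
        "mae_e_per_atom": (float(np.mean(np.abs(de))), float(de.size)),
        "rmse_f": (float(np.sqrt(np.mean(df**2))), float(df.size)),
        "mae_f": (float(np.mean(np.abs(df))), float(df.size)),
    }
```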

_predict_outputs(*, coord: numpy.ndarray, atom_types: numpy.ndarray, box: numpy.ndarray | None, fparam: numpy.ndarray | None, aparam: numpy.ndarray | None, include_virial: bool, natoms: int, nframes: int) -> dict[str, numpy.ndarray]

Predict energy, force, and virial for the full validation batch.
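Much of a batched prediction wrapper is shape bookkeeping: flat per-frame arrays are reshaped for the model, and the outputs are flattened back into a label-compatible layout. The sketch below shows that pattern. `model_fn` is a hypothetical stand-in, not the `BaseModel` API. It is assumed to take `(coord, atype, box)` shaped `(nframes, natoms, 3)`, `(nframes, natoms)`, and `(nframes, 3, 3)`, and to return `energy`, `force`, and `virial` arrays. `fparam`/`aparam` are omitted for brevity.

```python
from __future__ import annotations

from collections.abc import Callable

import numpy as np

ModelFn = Callable[..., dict[str, np.ndarray]]  # hypothetical model signature


def predict_outputs_sketch(
    model_fn: ModelFn,
    *,
    coord: np.ndarray,
    atom_types: np.ndarray,
    box: np.ndarray | None,
    include_virial: bool,
    natoms: int,
    nframes: int,
) -> dict[str, np.ndarray]:
    # Flat (nframes, natoms * 3) coordinates -> (nframes, natoms, 3).
    coord3 = np.reshape(coord, (nframes, natoms, 3))
    # Accept per-system (natoms,) or per-frame (nframes, natoms) types.
    atype = np.broadcast_to(np.reshape(atom_types, (-1, natoms)), (nframes, natoms))
    # Flat (nframes, 9) cells -> (nframes, 3, 3); None for non-periodic data.
    box3 = None if box is None else np.reshape(box, (nframes, 3, 3))
    out = model_fn(coord3, atype, box3)
    result = {
        "energy": np.reshape(out["energy"], (nframes, 1)),
        "force": np.reshape(out["force"], (nframes, natoms * 3)),
    }
    if include_virial:
        # Only request/reshape the virial when the labels contain one.
        result["virial"] = np.reshape(out["virial"], (nframes, 9))
    return result
```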