deepmd.jax.train.validation
Full validation support for the JAX trainer.
Classes
JAXFullValidator | Run full validation for a single-task JAX energy model.
Module Contents
- class deepmd.jax.train.validation.JAXFullValidator(*, validating_params: dict[str, Any], validation_data: Any, model: deepmd.jax.model.base_model.BaseModel, state_store: dict[str, Any], num_steps: int, rank: int, restart_training: bool, checkpoint_dir: Any = None)
Bases: deepmd.dpmodel.train.validation.FullValidatorBase
Run full validation for a single-task JAX energy model.
- evaluate_all_systems() → dict[str, float]
Evaluate every validation system and aggregate metrics.
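The docstring does not say how per-system results are combined. The following is a minimal sketch. It assumes each system yields a hypothetical `(mean squared error, weight)` pair per metric, which would fit the `dict[str, tuple[float, float]]` return type of `_evaluate_system`, and it assumes the global metric is a weight-averaged RMSE. The names `aggregate_systems` and the `rmse_` prefix are illustrative and are not part of the deepmd API.

```python
import math
from typing import Callable, Dict, Iterable, Tuple

# Hypothetical per-system result: metric name -> (mean squared error, weight).
SystemMetrics = Dict[str, Tuple[float, float]]


def aggregate_systems(
    systems: Iterable[object],
    evaluate_system: Callable[[object], SystemMetrics],
) -> Dict[str, float]:
    """Evaluate each system and pool per-system MSEs into one RMSE per metric."""
    weighted_sq: Dict[str, float] = {}  # metric -> sum of mse * weight
    total_weight: Dict[str, float] = {}  # metric -> sum of weights
    for system in systems:
        for name, (mse, weight) in evaluate_system(system).items():
            weighted_sq[name] = weighted_sq.get(name, 0.0) + mse * weight
            total_weight[name] = total_weight.get(name, 0.0) + weight
    # Skip metrics with zero total weight to avoid dividing by zero.
    return {
        f"rmse_{name}": math.sqrt(weighted_sq[name] / total_weight[name])
        for name in weighted_sq
        if total_weight[name] > 0
    }


# Two fake systems: (1.0 * 3 + 13.0 * 1) / 4 = 4.0, so the pooled RMSE is 2.0.
fake_results = {"sys_a": {"e": (1.0, 3.0)}, "sys_b": {"e": (13.0, 1.0)}}
print(aggregate_systems(fake_results, fake_results.__getitem__))  # prints {'rmse_e': 2.0}
```

Pooling squared errors by weight keeps a small system from counting as much as a large one. A plain average of per-system RMSEs would not have that property.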
- propagate_error(error_message: str | None) → str | None
Broadcast rank-0 full-validation failures to all JAX processes.
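Collective operations move arrays, not Python strings. One way to share a rank-0 error message is to pack it into a fixed-size byte buffer. Every process contributes a buffer of that size, and only rank 0's copy survives the broadcast. The following is a minimal sketch under that assumption. It uses a hypothetical `broadcast` callable that returns rank 0's buffer on every process. In a multi-host JAX job, `jax.experimental.multihost_utils.broadcast_one_to_all` can fill that role, but this page does not say whether `JAXFullValidator` uses it. The names `MAX_ERROR_BYTES`, `encode_error`, `decode_error`, and `propagate_error_sketch` are illustrative.

```python
from __future__ import annotations

from typing import Callable

import numpy as np

MAX_ERROR_BYTES = 1024  # hypothetical cap on the broadcast message size
HEADER = 5  # 1 flag byte + 4-byte little-endian payload length


def encode_error(error_message: str | None) -> np.ndarray:
    """Pack an optional message into a fixed-size uint8 buffer."""
    buf = np.zeros(HEADER + MAX_ERROR_BYTES, dtype=np.uint8)
    if error_message is None:
        return buf  # flag byte 0 means "no error"
    payload = error_message.encode("utf-8")[:MAX_ERROR_BYTES]
    buf[0] = 1
    buf[1:HEADER] = np.frombuffer(len(payload).to_bytes(4, "little"), dtype=np.uint8)
    buf[HEADER : HEADER + len(payload)] = np.frombuffer(payload, dtype=np.uint8)
    return buf


def decode_error(buf: np.ndarray) -> str | None:
    """Inverse of encode_error; a multi-byte character cut by truncation is replaced."""
    buf = np.asarray(buf, dtype=np.uint8)
    if buf[0] == 0:
        return None
    length = int.from_bytes(buf[1:HEADER].tobytes(), "little")
    return buf[HEADER : HEADER + length].tobytes().decode("utf-8", errors="replace")


def propagate_error_sketch(
    error_message: str | None,
    rank: int,
    broadcast: Callable[[np.ndarray], np.ndarray],
) -> str | None:
    # Every rank passes a buffer of the same shape; only rank 0's content is kept.
    local = encode_error(error_message if rank == 0 else None)
    return decode_error(broadcast(local))


# Single-process simulation: rank 1 "receives" whatever rank 0 encoded.
rank0_buf = encode_error("NaN in validation energy")
print(propagate_error_sketch(None, rank=1, broadcast=lambda _: rank0_buf))
```

The fixed buffer size matters because collectives generally require every participant to send an array of identical shape and dtype, even when it has nothing to report.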
- _evaluate_system(data_system: Any) → dict[str, tuple[float, float]]
Evaluate one validation system.
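The following is a minimal sketch of what one system's evaluation could return. It assumes the two floats in each tuple are a mean squared error and the number of samples that error was averaged over, and it assumes energies are compared per atom. The metric keys `e_per_atom` and `f`, the function name, and the prediction/label array layouts are illustrative assumptions, not taken from deepmd.

```python
import numpy as np


def evaluate_system_sketch(
    pred: dict[str, np.ndarray],
    label: dict[str, np.ndarray],
    natoms: int,
) -> dict[str, tuple[float, float]]:
    """Return metric -> (mean squared error, number of samples) for one system."""
    nframes = pred["energy"].shape[0]
    # Normalize the energy error by atom count so systems of different size are comparable.
    e_diff = (pred["energy"].reshape(nframes) - label["energy"].reshape(nframes)) / natoms
    # Compare every force component of every atom in every frame.
    f_diff = pred["force"].reshape(-1) - label["force"].reshape(-1)
    return {
        "e_per_atom": (float(np.mean(e_diff**2)), float(nframes)),
        "f": (float(np.mean(f_diff**2)), float(f_diff.size)),
    }


pred = {"energy": np.array([[2.0], [4.0]]), "force": np.zeros((2, 2, 3))}
label = {"energy": np.zeros((2, 1)), "force": np.ones((2, 2, 3))}
print(evaluate_system_sketch(pred, label, natoms=2))
# prints {'e_per_atom': (2.5, 2.0), 'f': (1.0, 12.0)}
```

Returning the sample count next to each MSE lets a caller pool results across systems without re-reading the data.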
- _predict_outputs(*, coord: numpy.ndarray, atom_types: numpy.ndarray, box: numpy.ndarray | None, fparam: numpy.ndarray | None, aparam: numpy.ndarray | None, include_virial: bool, natoms: int, nframes: int) → dict[str, numpy.ndarray]
Predict energy, force, and virial for the full validation batch.
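The keyword-only `natoms`, `nframes`, and `include_virial` parameters suggest that outputs are shaped per frame and that the virial is optional. The following is a minimal sketch under those assumptions. It uses a hypothetical `model_fn(coord, atype, box)` that returns flat arrays keyed `energy`, `force`, and `virial`, and it assumes output shapes of `(nframes, 1)`, `(nframes, natoms, 3)`, and `(nframes, 9)`. `fparam` and `aparam` are omitted for brevity. `predict_outputs_sketch` and `toy_model` are illustrative names.

```python
from typing import Callable, Optional

import numpy as np

ModelFn = Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], dict]


def predict_outputs_sketch(
    model_fn: ModelFn,
    *,
    coord: np.ndarray,
    atom_types: np.ndarray,
    box: Optional[np.ndarray],
    include_virial: bool,
    natoms: int,
    nframes: int,
) -> dict[str, np.ndarray]:
    # Flatten inputs to one row per frame; box=None stands for a non-periodic system.
    raw = model_fn(
        coord.reshape(nframes, natoms * 3),
        atom_types.reshape(nframes, natoms),
        None if box is None else box.reshape(nframes, 9),
    )
    out = {
        "energy": np.asarray(raw["energy"]).reshape(nframes, 1),
        "force": np.asarray(raw["force"]).reshape(nframes, natoms, 3),
    }
    if include_virial:  # only reshape and return the virial when it is requested
        out["virial"] = np.asarray(raw["virial"]).reshape(nframes, 9)
    return out


def toy_model(coord, atype, box):
    """Stand-in model: energy = sum of coordinates, so force = -dE/dx = -1 everywhere."""
    return {
        "energy": coord.sum(axis=1),
        "force": -np.ones_like(coord),
        "virial": np.zeros(coord.shape[0] * 9),
    }


outs = predict_outputs_sketch(
    toy_model,
    coord=np.arange(12.0),
    atom_types=np.zeros(4, dtype=int),
    box=None,
    include_virial=False,
    natoms=2,
    nframes=2,
)
print(outs["energy"].ravel(), outs["force"].shape, "virial" in outs)
# prints [15. 51.] (2, 2, 3) False
```

Gating on `include_virial` presumably lets systems without virial labels skip that output. Whether the real method also avoids *computing* the virial in that case is not stated here.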