Source code for deepmd.jax.model.dp_zbl_model
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)
from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP
from deepmd.jax.atomic_model.linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
)
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.env import (
jax,
jnp,
)
from deepmd.jax.model.base_model import (
BaseModel,
forward_common_atomic,
)
@BaseModel.register("zbl")
@flax_module
class DPZBLModel(DPZBLModelDP):
    """JAX backend of the DP+ZBL model.

    Subclasses ``DPZBLModel`` from ``deepmd.dpmodel.model.dp_zbl_model`` and
    is registered with ``BaseModel`` under the key ``"zbl"``. Two hooks are
    specialised here: assigning ``atomic_model`` rebuilds the value through
    the JAX ``DPZBLLinearEnergyAtomicModel``, and ``forward_common_atomic``
    delegates to the helper of the same name in
    ``deepmd.jax.model.base_model``.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, rebuilding ``atomic_model`` via the JAX class.

        Parameters
        ----------
        name : str
            Name of the attribute being assigned.
        value : Any
            Value to assign. When ``name`` is ``"atomic_model"`` the value
            must provide ``serialize()``; its serialized form is passed to
            ``DPZBLLinearEnergyAtomicModel.deserialize`` and the result is
            stored instead of ``value`` itself.
        """
        if name != "atomic_model":
            # Any attribute other than the atomic model is stored as given.
            return super().__setattr__(name, value)
        # Serialize/deserialize round-trip: the stored object is produced by
        # the JAX DPZBLLinearEnergyAtomicModel.deserialize, independent of
        # which implementation created ``value``.
        jax_atomic_model = DPZBLLinearEnergyAtomicModel.deserialize(
            value.serialize()
        )
        return super().__setattr__(name, jax_atomic_model)

    def forward_common_atomic(
        self,
        extended_coord: jnp.ndarray,
        extended_atype: jnp.ndarray,
        nlist: jnp.ndarray,
        mapping: jnp.ndarray | None = None,
        fparam: jnp.ndarray | None = None,
        aparam: jnp.ndarray | None = None,
        do_atomic_virial: bool = False,
        extended_coord_corr: jnp.ndarray | None = None,
        comm_dict: dict | None = None,
        charge_spin: jnp.ndarray | None = None,
    ) -> dict[str, jnp.ndarray]:
        """Run the shared JAX atomic forward pass for this model.

        All arguments except ``comm_dict`` are handed unchanged to
        ``forward_common_atomic`` from ``deepmd.jax.model.base_model``. Array
        shapes and dtypes follow that helper's conventions, which are not
        visible in this module.

        Parameters
        ----------
        extended_coord, extended_atype, nlist
            Passed positionally to the shared helper.
        mapping, fparam, aparam, extended_coord_corr, charge_spin
            Optional inputs, passed by keyword to the shared helper.
        do_atomic_virial : bool
            Passed by keyword to the shared helper.
        comm_dict : dict or None
            Accepted but discarded; the JAX path has no MPI ghost exchange.

        Returns
        -------
        dict[str, jnp.ndarray]
            Whatever the shared helper returns.
        """
        del comm_dict  # JAX path has no MPI ghost exchange
        passthrough = {
            "mapping": mapping,
            "fparam": fparam,
            "aparam": aparam,
            "do_atomic_virial": do_atomic_virial,
            "extended_coord_corr": extended_coord_corr,
            "charge_spin": charge_spin,
        }
        # Resolves to the module-level helper imported from base_model, not
        # to this method (method names are not in function scope).
        return forward_common_atomic(
            self, extended_coord, extended_atype, nlist, **passthrough
        )