deepmd.dpmodel.array_api

Utilities for the array API.

Attributes

Array

Functions

xp_swapaxes(→ Array)

xp_take_along_axis(→ Array)

xp_take_first_n(→ Array)

Take the first n elements along dim.

xp_scatter_sum(→ Array)

Reduces all values from the src tensor to the indices specified in the index tensor.

xp_add_at(→ Array)

Add values at the specified indices of x, either in place or, for JAX, by returning a new array.

xp_sigmoid(→ Array)

Compute the sigmoid function.

xp_setitem_at(→ Array)

Set items at boolean mask indices.

xp_bincount(→ Array)

Counts the number of occurrences of each value in x.

Module Contents

deepmd.dpmodel.array_api.Array
deepmd.dpmodel.array_api.xp_swapaxes(a: Array, axis1: int, axis2: int) → Array
deepmd.dpmodel.array_api.xp_take_along_axis(arr: Array, indices: Array, axis: int) → Array
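Neither xp_swapaxes nor xp_take_along_axis is documented on this page. Their names and signatures suggest they mirror numpy.swapaxes and numpy.take_along_axis across backends, but that is an assumption. The NumPy-only sketch below illustrates that assumed behavior and does not call deepmd.

```python
import numpy as np

# Assumed semantics: xp_swapaxes(a, axis1, axis2) behaves like np.swapaxes.
a = np.arange(6).reshape(2, 3)
swapped = np.swapaxes(a, 0, 1)  # shape (2, 3) -> (3, 2)

# Assumed semantics: xp_take_along_axis(arr, indices, axis) behaves like
# np.take_along_axis. `indices` has the same number of dimensions as `arr`
# and picks entries along `axis` independently for every other position.
arr = np.array([[10, 30, 20],
                [60, 40, 50]])
order = np.argsort(arr, axis=1)                    # per-row sort order
sorted_rows = np.take_along_axis(arr, order, axis=1)
# sorted_rows holds [[10, 20, 30], [40, 50, 60]]
```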
deepmd.dpmodel.array_api.xp_take_first_n(arr: Array, dim: int, n: int) → Array

Take the first n elements along dim.

For PyTorch tensors, this uses torch.index_select so that torch.export does not emit a contiguity guard, which would otherwise break the nall == nloc case (no periodic boundary conditions). For NumPy and JAX arrays, it uses regular slicing.
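The two code paths should return the same values. The sketch below is NumPy-only. np.take with an index range stands in for torch.index_select, a substitution made so the example runs without PyTorch. Slicing represents the NumPy / JAX path. Both helper names are hypothetical.

```python
import numpy as np

def take_first_n_sketch(arr, dim, n):
    """Hypothetical slicing path: keep only the first n entries along `dim`."""
    index = [slice(None)] * arr.ndim
    index[dim] = slice(0, n)
    return arr[tuple(index)]

def take_first_n_gather(arr, dim, n):
    """Hypothetical gather path, analogous to torch.index_select(arr, dim, arange(n))."""
    return np.take(arr, np.arange(n), axis=dim)

coords = np.arange(24, dtype=np.float64).reshape(2, 4, 3)  # (nframes, nall, 3)
nloc = 2
local = take_first_n_sketch(coords, 1, nloc)               # shape (2, 2, 3)
assert np.array_equal(local, take_first_n_gather(coords, 1, nloc))

# nall == nloc (no PBC): n equals the full length, so nothing is dropped.
full = take_first_n_sketch(coords, 1, coords.shape[1])
```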

deepmd.dpmodel.array_api.xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) → Array

Reduces all values from the src tensor to the indices specified in the index tensor.

This function is similar to PyTorch’s scatter_add and JAX’s jax.lax.scatter_add. It adds values from src to input at the positions specified by index along the given dimension.
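To make the indexing rule concrete, here is a NumPy sketch of PyTorch-style scatter_add semantics. For dim=1 on 2-D arrays, it computes out[i][index[i][j]] += src[i][j]. It assumes index has the same shape as src, and scatter_sum_sketch is a hypothetical name rather than the module's implementation. np.add.at is used because it is unbuffered, so repeated indices accumulate instead of overwriting each other.

```python
import numpy as np

def scatter_sum_sketch(inp, dim, index, src):
    """Hypothetical scatter-sum: add src into a copy of inp at `index` along `dim`."""
    out = inp.copy()
    # Coordinates of every element of src; swap the `dim` coordinate for `index`.
    coords = list(np.indices(src.shape))
    coords[dim] = index
    # Unbuffered add: duplicate target positions accumulate.
    np.add.at(out, tuple(coords), src)
    return out

# 1-D: positions 0 and 0 both receive contributions.
r1 = scatter_sum_sketch(np.zeros(3), 0, np.array([0, 1, 0, 2]),
                        np.array([1.0, 2.0, 3.0, 4.0]))          # [4., 2., 4.]

# 2-D along dim=1: each row scatters into its own columns.
r2 = scatter_sum_sketch(np.zeros((2, 3)), 1,
                        np.array([[0, 0, 2], [1, 1, 1]]),
                        np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
```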

deepmd.dpmodel.array_api.xp_add_at(x: Array, indices: Array, values: Array) → Array

Add values at the specified indices of x, either in place or, for JAX, by returning a new array.
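This page does not say how repeated indices are handled. The NumPy sketch below assumes they accumulate, as they do with np.add.at and with JAX's non-mutating x.at[indices].add(values), which returns a new array. It also shows why plain fancy-index assignment is not a substitute. Always use the return value, because the JAX path cannot modify x in place.

```python
import numpy as np

x = np.zeros(4)
indices = np.array([0, 2, 2])
values = np.array([1.0, 5.0, 7.0])

# Unbuffered and in place: both updates to index 2 are applied (5 + 7).
np.add.at(x, indices, values)
# x is now [1., 0., 12., 0.]

# Buffered fancy indexing: index 2 receives only one of its two updates.
y = np.zeros(4)
y[indices] += values
```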

deepmd.dpmodel.array_api.xp_sigmoid(x: Array) → Array

Compute the sigmoid function.

JAX and PyTorch provide optimized sigmoid implementations, which this function uses for arrays from those backends. See jax-ml/jax#15617.
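The built-in versions referred to above are jax.nn.sigmoid and torch.sigmoid. This page does not document what the function does for other backends. The sketch below is a hypothetical NumPy fallback. It rewrites 1 / (1 + exp(-x)) as exp(-log(1 + exp(-x))), so that np.logaddexp avoids overflowing exp(-x) for large negative inputs.

```python
import numpy as np

def sigmoid_sketch(x):
    """Hypothetical sigmoid: 1 / (1 + exp(-x)) computed as exp(-logaddexp(0, -x))."""
    # np.logaddexp(0, -x) evaluates log(1 + exp(-x)) without forming exp(-x) directly.
    return np.exp(-np.logaddexp(0.0, -x))

x = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
s = sigmoid_sketch(x)
# s[2] is 0.5; s[0] underflows to 0.0 and s[4] is 1.0, with no overflow warning.
```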

deepmd.dpmodel.array_api.xp_setitem_at(x: Array, mask: Array, values: Array) → Array

Set items at boolean mask indices.

For JAX and PyTorch arrays, returns a new array (non-mutating). For NumPy arrays, modifies x in place and returns the same array.

Parameters:
x : Array
    The array to modify.
mask : Array
    Boolean mask indicating the positions to set.
values : Array
    Values to set at the masked positions.

Returns:
Array
    The modified array (a new array for JAX/PyTorch, the same array for NumPy).
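Because the function mutates in one case and copies in the other, callers should always use its return value. The sketch below shows both contracts with NumPy alone. setitem_at_sketch and its copy flag are hypothetical: copy=True imitates the JAX/PyTorch behavior and copy=False the NumPy behavior.

```python
import numpy as np

def setitem_at_sketch(x, mask, values, copy=False):
    """Hypothetical illustration: copy=True mimics JAX/PyTorch, copy=False mimics NumPy."""
    out = x.copy() if copy else x
    out[mask] = values  # values: a scalar or one entry per True element of mask
    return out

# NumPy-style: x itself is modified and returned.
x = np.array([1.0, 2.0, 3.0, 4.0])
r = setitem_at_sketch(x, x > 2, 0.0)

# JAX/PyTorch-style: the input is left untouched; only the result changes.
x2 = np.array([1.0, 2.0, 3.0, 4.0])
r2 = setitem_at_sketch(x2, x2 > 2, np.array([9.0, 8.0]), copy=True)
```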

deepmd.dpmodel.array_api.xp_bincount(x: Array, weights: Array | None = None, minlength: int = 0) → Array

Counts the number of occurrences of each value in x.
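The signature matches numpy.bincount. Assuming the semantics match as well (non-negative integer x, optional per-element weights, output length at least minlength), a bincount can be expressed as an unbuffered scatter-add into a zero array. That formulation only needs primitives like those elsewhere in this module. bincount_sketch is hypothetical, and the test checks it against numpy.bincount.

```python
import numpy as np

def bincount_sketch(x, weights=None, minlength=0):
    """Hypothetical bincount via scatter-add; x must hold non-negative integers."""
    length = max(int(x.max()) + 1 if x.size else 0, minlength)
    dtype = np.int64 if weights is None else np.float64
    out = np.zeros(length, dtype=dtype)
    # Each occurrence of value v adds 1 (or its weight) to bin v.
    np.add.at(out, x, 1 if weights is None else weights)
    return out

x = np.array([0, 1, 1, 3])
counts = bincount_sketch(x)                                   # [1, 2, 0, 1]
padded = bincount_sketch(x, minlength=6)                      # [1, 2, 0, 1, 0, 0]
weighted = bincount_sketch(x, weights=np.array([0.5, 1.0, 2.0, 4.0]))  # [0.5, 3.0, 0.0, 4.0]
```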