deepmd.dpmodel.array_api
Utilities for the array API.
Attributes
Functions

| Function | Description |
|---|---|
| xp_take_first_n(arr, dim, n) | Take the first n elements along dim. |
| xp_scatter_sum(input, dim, index, src) | Reduces all values from the src tensor to the indices specified in the index tensor. |
| xp_add_at(x, indices, values) | Adds values to the specified indices of x in place or returns a new x (for JAX). |
| xp_sigmoid(x) | Compute the sigmoid function. |
| | Set items at boolean mask indices. |
| | Counts the number of occurrences of each value in x. |
Module Contents
- deepmd.dpmodel.array_api.xp_take_first_n(arr: Array, dim: int, n: int) -> Array
Take the first n elements along dim.
For torch tensors, uses `torch.index_select` so that `torch.export` does not emit a contiguity guard that would prevent the `nall == nloc` (no-PBC) case from working. For NumPy and JAX arrays, uses regular slicing.
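A minimal NumPy-only sketch of the two strategies described above: plain slicing (the NumPy / JAX path) and gathering with an explicit index array, which is the pattern `torch.index_select` follows. Both helper names are hypothetical and this is not the deepmd implementation; the `torch.export` guard behaviour itself is not reproduced, since the sketch avoids a PyTorch dependency.

```python
import numpy as np


def take_first_n_slice(arr: np.ndarray, dim: int, n: int) -> np.ndarray:
    """Hypothetical helper: first n elements along `dim` via basic slicing."""
    # Full slice on every axis except `dim`, which is cut to [0, n).
    slicer = [slice(None)] * arr.ndim
    slicer[dim] = slice(0, n)
    return arr[tuple(slicer)]


def take_first_n_gather(arr: np.ndarray, dim: int, n: int) -> np.ndarray:
    """Hypothetical helper: same result via an explicit index array.

    This mirrors the torch.index_select(arr, dim, torch.arange(n)) pattern.
    """
    return np.take(arr, np.arange(n), axis=dim)


arr = np.arange(24).reshape(2, 3, 4)
first_two = take_first_n_slice(arr, dim=1, n=2)
print(first_two.shape)  # (2, 2, 4)
print(np.array_equal(first_two, take_first_n_gather(arr, dim=1, n=2)))  # True
```

One practical difference: in NumPy, basic slicing returns a view that shares memory with `arr`, while `np.take` returns a fresh copy.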
- deepmd.dpmodel.array_api.xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array
Reduces all values from the src tensor to the indices specified in the index tensor.
This function is similar to PyTorch’s `scatter_add` and JAX’s `jax.lax.scatter_add`: it adds the values of `src` to `input` at the positions given by `index` along dimension `dim`.
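The semantics can be illustrated with a short NumPy sketch built on `np.add.at`. The helper name is hypothetical, and it assumes `index` and `src` share the same shape (PyTorch's `scatter_add` is more permissive); it illustrates the behaviour described above rather than reproducing the deepmd code.

```python
import numpy as np


def scatter_sum_sketch(input: np.ndarray, dim: int, index: np.ndarray,
                       src: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy sketch of scatter-add semantics along `dim`.

    Assumes `index` and `src` have the same shape. For dim=1 on 2-D arrays
    this computes out[i, index[i, j]] += src[i, j].
    """
    out = input.copy()  # leave the caller's array untouched
    # Coordinates of every element of `index`, one integer array per axis.
    coords = list(np.indices(index.shape))
    # Along `dim`, send each element to the position given by `index`.
    coords[dim] = index
    # np.add.at is unbuffered, so repeated indices accumulate.
    np.add.at(out, tuple(coords), src)
    return out


out = scatter_sum_sketch(
    np.zeros((2, 3)),
    dim=1,
    index=np.array([[0, 0, 2], [1, 1, 1]]),
    src=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
)
print(out)  # rows: [3, 0, 3] and [0, 15, 0]
```

The parameter is named `input` only to match the documented signature; it shadows the built-in inside the function body.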
- deepmd.dpmodel.array_api.xp_add_at(x: Array, indices: Array, values: Array) -> Array
Adds values to x at the specified indices, either in place or, for JAX (whose arrays are immutable), by returning a new x.
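The two behaviours can be contrasted in NumPy alone. This sketch uses hypothetical helper names; the functional variant only imitates the JAX style, where the real equivalent is `x.at[indices].add(values)`. It is an illustration, not the deepmd implementation.

```python
import numpy as np


def add_at_inplace(x: np.ndarray, indices: np.ndarray,
                   values: np.ndarray) -> np.ndarray:
    """Hypothetical helper: mutate x, accumulating over repeated indices."""
    np.add.at(x, indices, values)
    return x


def add_at_functional(x: np.ndarray, indices: np.ndarray,
                      values: np.ndarray) -> np.ndarray:
    """Hypothetical helper: leave x untouched and return an updated copy.

    This is the style JAX requires, where the equivalent is
    x.at[indices].add(values).
    """
    out = x.copy()
    np.add.at(out, indices, values)
    return out


x = np.zeros(4)
idx = np.array([0, 0, 3])
vals = np.array([1.0, 2.0, 5.0])

y = add_at_functional(x, idx, vals)  # x is still all zeros here
add_at_inplace(x, idx, vals)         # x is now [3, 0, 0, 5]
```

The unbuffered `np.add.at` matters because the seemingly equivalent `x[idx] += vals` does not accumulate the two contributions to index 0 in NumPy.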
- deepmd.dpmodel.array_api.xp_sigmoid(x: Array) -> Array
Compute the sigmoid function.
JAX and PyTorch provide optimized sigmoid implementations; see jax-ml/jax#15617.
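For a backend without a dedicated sigmoid (plain NumPy here, as opposed to `jax.nn.sigmoid` or `torch.sigmoid`), a generic formula is needed. The naive `1 / (1 + exp(-x))` emits an overflow warning for large negative `x`. The sketch below is one common numerically stable alternative; the function name is hypothetical, and this is not necessarily how `xp_sigmoid` computes its fallback.

```python
import numpy as np


def stable_sigmoid(x: np.ndarray) -> np.ndarray:
    """Hypothetical generic fallback: numerically stable logistic sigmoid.

    Uses exp(-|x|), which never overflows, and picks the algebraically
    equivalent branch for each sign of x.
    """
    x = np.asarray(x, dtype=float)
    e = np.exp(-np.abs(x))  # in (0, 1], safe for any finite x
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x))
    return np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))


print(stable_sigmoid(np.array([-1000.0, 0.0, 1000.0])))  # no overflow warning
```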