# Source code for deepmd.tf.op._prod_force_se_r_grad
#!/usr/bin/env python3
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Gradients for prod force."""
from tensorflow.python.framework import (
ops,
)
from deepmd.tf.env import (
op_grads_module,
)
@ops.RegisterGradient("ProdForceSeR")
def _prod_force_se_a_grad_cc(op, grad):
    """Gradient of the ``ProdForceSeR`` op.

    Registered with TensorFlow through ``ops.RegisterGradient``. The backward
    computation is delegated to the compiled ``prod_force_se_r_grad`` kernel
    from ``op_grads_module``. That kernel receives the upstream gradient
    followed by the first four inputs of the forward op.

    NOTE(review): the ``_se_a_`` in this function's name looks like a
    copy-paste leftover from the SeA variant. The name is kept unchanged so
    any existing references to it keep working.

    Parameters
    ----------
    op : tf.Operation
        The forward ``ProdForceSeR`` operation whose gradient is requested.
    grad : tf.Tensor
        Upstream gradient with respect to the op's output, supplied by
        TensorFlow.

    Returns
    -------
    list
        Four entries, one per op input: the computed gradient for
        ``op.inputs[0]``, then ``None`` for ``op.inputs[1]`` through
        ``op.inputs[3]``, which are treated as non-differentiable.
    """
    net_grad = op_grads_module.prod_force_se_r_grad(
        grad, op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[3]
    )
    # Only input 0 receives a gradient; the remaining inputs get None.
    # Input 0 is presumably the network derivative; confirm against the op
    # definition.
    return [net_grad, None, None, None]