Source code for deepmd.tf.op._matmul_fitnet_nvnmd_grad

#!/usr/bin/env python3

# SPDX-License-Identifier: LGPL-3.0-or-later
from tensorflow.python.framework import (
    ops,
)

from deepmd.tf.env import (
    op_module,
    tf,
)


@ops.RegisterGradient("MatmulFitnetNvnmd")
def _MatmulFitnetNvnmdGrad(op, grad):
    """Compute input gradients for the ``MatmulFitnetNvnmd`` op.

    Parameters
    ----------
    op : tf.Operation
        The forward ``MatmulFitnetNvnmd`` operation. ``op.inputs[0]`` is the
        left operand ``x`` and ``op.inputs[1]`` is the right operand ``w``.
        Its ``nbitx``, ``nbitw`` and ``normw`` attributes are read and reused
        when computing ``dx``.
    grad : tf.Tensor
        Upstream gradient with respect to the op's output.

    Returns
    -------
    list of tf.Tensor
        ``[dx, dw]``: gradients with respect to ``x`` and ``w``, in the same
        order as the op's inputs.
    """
    x = op.inputs[0]
    w = op.inputs[1]
    # Forward-op attributes, passed unchanged to the backward custom op.
    nbitx = op.get_attr("nbitx")
    nbitw = op.get_attr("nbitw")
    normw = op.get_attr("normw")
    # dx = grad @ w^T, computed with the same custom op and attributes as the
    # forward pass (presumably so backprop uses the same NVNMD
    # quantization; verify against the op kernel).
    # NOTE(review): tf.transpose without `perm` reverses all axes; this
    # assumes x and w are rank-2 matrices. TODO confirm.
    dx = op_module.matmul_fitnet_nvnmd(grad, tf.transpose(w), nbitx, nbitw, normw)
    # dw = x^T @ grad uses a plain tf.matmul, not the custom op.
    dw = tf.matmul(tf.transpose(x), grad)
    return [dx, dw]