Source code for deepmd.tf.op._mul_flt_nvnmd_grad

#!/usr/bin/env python3

# SPDX-License-Identifier: LGPL-3.0-or-later
from tensorflow.python.framework import (
    ops,
)

from deepmd.tf.env import (
    op_module,
    tf,
)


@ops.RegisterGradient("MulFltNvnmd")
def _MulFltNvnmdGrad(op, grad):
    """Gradient function registered for the ``MulFltNvnmd`` op.

    Parameters
    ----------
    op
        The forward operation; its first two inputs (``x`` and ``w``) are read.
    grad
        The incoming gradient, forwarded to the NVNMD ops below.

    Returns
    -------
    list
        ``[dx, dw]``, each passed through ``tf.ensure_shape`` with the static
        shape of the corresponding forward input.
    """
    lhs, rhs = op.inputs[0], op.inputs[1]

    # Static shapes of both inputs; any -1 entry is treated as unknown (None).
    lhs_dims = [None if dim == -1 else dim for dim in lhs.shape.as_list()]
    rhs_dims = [None if dim == -1 else dim for dim in rhs.shape.as_list()]

    # When the trailing dimensions differ (per the original note: x ends in 1
    # and w ends in M), the gradient w.r.t. x uses the dot-mul op instead.
    lhs_grad_op = (
        op_module.mul_flt_nvnmd
        if lhs_dims[-1] == rhs_dims[-1]
        else op_module.dotmul_flt_nvnmd
    )
    dx = lhs_grad_op(rhs, grad)
    # The gradient w.r.t. w is computed the same way in either case.
    dw = op_module.mul_flt_nvnmd(lhs, grad)

    # Attach the known static shapes to the results.
    return [tf.ensure_shape(dx, lhs_dims), tf.ensure_shape(dw, rhs_dims)]