Source code for deepmd.tf.op._map_flt_nvnmd_grad

#!/usr/bin/env python3

# SPDX-License-Identifier: LGPL-3.0-or-later
from tensorflow.python.framework import (
    ops,
)

from deepmd.tf.env import (
    op_module,
    tf,
)


@ops.RegisterGradient("MapFltNvnmd")
[docs] def _MapFltNvnmdGrad(op, grad): x = op.inputs[0] table = op.inputs[1] table_grad = op.inputs[2] table_info = op.inputs[3] # shape shx = x.shape.as_list() shw = table.shape.as_list() shx = [None if (d == -1) else d for d in shx] shw = [None if (d == -1) else d for d in shw] N = shx[0] D = shx[1] M = shw[1] // 4 # dydx = op_module.map_flt_nvnmd(x, table_grad, tf.zeros_like(table_grad), table_info) dydx = tf.ensure_shape(dydx, [N, D, M]) # calculate dx = op_module.dotmul_flt_nvnmd(dydx, grad) dx = tf.reshape(dx, [-1, D]) # other grad d_table = None d_table_grad = None d_table_info = None return [dx, d_table, d_table_grad, d_table_info]