Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/handlers/backend/scatter_elements.py
T
2020-10-14 08:55:07 +08:00

65 lines
2.7 KiB
Python

import tensorflow as tf
from onnx_tf.common.tf_helper import tf_shape
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from .gather_and_scatter_mixin import GatherAndScatterMixin
@onnx_op("ScatterElements")
class ScatterElements(GatherAndScatterMixin, BackendHandler):
@classmethod
def version_11(cls, node, **kwargs):
axis = node.attrs.get("axis", 0)
data = kwargs["tensor_dict"][node.inputs[0]]
indices = kwargs["tensor_dict"][node.inputs[1]]
updates = kwargs["tensor_dict"][node.inputs[2]]
# poocess negative axis
axis = axis if axis >= 0 else tf.add(tf.rank(data), axis)
# check are there any indices are out of bounds
result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
msg = 'ScatterElements indices are out of bounds, please double check the indices and retry.'
with tf.control_dependencies([tf.compat.v1.assert_equal(result, True, message=msg)]):
# process negative indices
indices = cls.process_neg_idx_along_axis(data, axis, indices)
# Calculate shape of the tensorflow version of indices tensor.
sparsified_dense_idx_shape = tf_shape(updates)
# Move on to convert ONNX indices to tensorflow indices in 2 steps:
#
# Step 1:
# What would the index tensors look like if updates are all
# dense? In other words, produce a coordinate tensor for updates:
#
# coordinate[i, j, k ...] = [i, j, k ...]
# where the shape of "coordinate" tensor is same as that of updates.
#
# Step 2:
# But the coordinate tensor needs some correction because coord
# vector at position axis is wrong (since we assumed update is dense,
# but it is not at the axis specified).
# So we update coordinate vector tensor elements at psotion=axis with
# the sparse coordinate indices.
idx_tensors_per_axis = tf.meshgrid(*list(
map(lambda x: tf.range(x, dtype=tf.dtypes.int64),
sparsified_dense_idx_shape)),
indexing='ij')
idx_tensors_per_axis[axis] = indices
dim_expanded_idx_tensors_per_axis = list(
map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
coordinate = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)
# Now the coordinate tensor is in the shape
# [updates.shape, updates.rank]
# we need it to flattened into the shape:
# [product(updates.shape), updates.rank]
indices = tf.reshape(coordinate, [-1, tf.rank(data)])
updates = tf.reshape(updates, [-1])
return [tf.tensor_scatter_nd_update(data, indices, updates)]