59 lines
2.6 KiB
Python
59 lines
2.6 KiB
Python
import tensorflow as tf
|
|
|
|
from onnx_tf.handlers.backend_handler import BackendHandler
|
|
from onnx_tf.handlers.handler import onnx_op
|
|
from .gather_and_scatter_mixin import GatherAndScatterMixin
|
|
|
|
|
|
@onnx_op("GatherElements")
|
|
class GatherElements(GatherAndScatterMixin, BackendHandler):
|
|
|
|
@classmethod
|
|
def version_11(cls, node, **kwargs):
|
|
# GatherElements takes two inputs data and indices of the same rank r >= 1 and an optional attribute axis that identifies
|
|
# an axis of data (by default, the outer-most axis, that is axis 0). It is an indexing operation that produces its output by
|
|
# indexing into the input data tensor at index positions determined by elements of the indices tensor. Its output shape is the
|
|
# same as the shape of indices and consists of one value (gathered from the data) for each element in indices.
|
|
|
|
axis = node.attrs.get("axis", 0)
|
|
data = kwargs["tensor_dict"][node.inputs[0]]
|
|
indices = kwargs["tensor_dict"][node.inputs[1]]
|
|
|
|
# poocess negative axis
|
|
axis = axis if axis >= 0 else tf.add(tf.rank(data), axis)
|
|
|
|
# check are there any indices are out of bounds
|
|
result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
|
|
msg = 'GatherElements indices are out of bounds,'\
|
|
' please double check the indices and retry.'
|
|
with tf.control_dependencies(
|
|
[tf.compat.v1.assert_equal(result, True, message=msg)]):
|
|
# process negative indices
|
|
indices = cls.process_neg_idx_along_axis(data, axis, indices)
|
|
|
|
# adapted from reference implementation in onnx/onnx/backend/test/case/node/gatherelements.py
|
|
if axis == 0:
|
|
axis_perm = tf.range(tf.rank(data))
|
|
data_swaped = data
|
|
index_swaped = indices
|
|
else:
|
|
axis_perm = tf.tensor_scatter_nd_update(tf.range(tf.rank(data)),
|
|
tf.constant([[0], [axis]]),
|
|
tf.constant([axis, 0]))
|
|
data_swaped = tf.transpose(data, perm=axis_perm)
|
|
index_swaped = tf.transpose(indices, perm=axis_perm)
|
|
|
|
idx_tensors_per_axis = tf.meshgrid(*list(
|
|
map(lambda x: tf.range(x, dtype=index_swaped.dtype),
|
|
index_swaped.shape.as_list())),
|
|
indexing='ij')
|
|
idx_tensors_per_axis[0] = index_swaped
|
|
dim_expanded_idx_tensors_per_axis = list(
|
|
map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
|
|
index_expanded = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)
|
|
|
|
gathered = tf.gather_nd(data_swaped, index_expanded)
|
|
y = tf.transpose(gathered, perm=axis_perm)
|
|
|
|
return [y]
|