add pt2tf tool
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
from .gather_and_scatter_mixin import GatherAndScatterMixin
|
||||
|
||||
|
||||
@onnx_op("GatherElements")
|
||||
class GatherElements(GatherAndScatterMixin, BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def version_11(cls, node, **kwargs):
|
||||
# GatherElements takes two inputs data and indices of the same rank r >= 1 and an optional attribute axis that identifies
|
||||
# an axis of data (by default, the outer-most axis, that is axis 0). It is an indexing operation that produces its output by
|
||||
# indexing into the input data tensor at index positions determined by elements of the indices tensor. Its output shape is the
|
||||
# same as the shape of indices and consists of one value (gathered from the data) for each element in indices.
|
||||
|
||||
axis = node.attrs.get("axis", 0)
|
||||
data = kwargs["tensor_dict"][node.inputs[0]]
|
||||
indices = kwargs["tensor_dict"][node.inputs[1]]
|
||||
|
||||
# poocess negative axis
|
||||
axis = axis if axis >= 0 else tf.add(tf.rank(data), axis)
|
||||
|
||||
# check are there any indices are out of bounds
|
||||
result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
|
||||
msg = 'GatherElements indices are out of bounds,'\
|
||||
' please double check the indices and retry.'
|
||||
with tf.control_dependencies(
|
||||
[tf.compat.v1.assert_equal(result, True, message=msg)]):
|
||||
# process negative indices
|
||||
indices = cls.process_neg_idx_along_axis(data, axis, indices)
|
||||
|
||||
# adapted from reference implementation in onnx/onnx/backend/test/case/node/gatherelements.py
|
||||
if axis == 0:
|
||||
axis_perm = tf.range(tf.rank(data))
|
||||
data_swaped = data
|
||||
index_swaped = indices
|
||||
else:
|
||||
axis_perm = tf.tensor_scatter_nd_update(tf.range(tf.rank(data)),
|
||||
tf.constant([[0], [axis]]),
|
||||
tf.constant([axis, 0]))
|
||||
data_swaped = tf.transpose(data, perm=axis_perm)
|
||||
index_swaped = tf.transpose(indices, perm=axis_perm)
|
||||
|
||||
idx_tensors_per_axis = tf.meshgrid(*list(
|
||||
map(lambda x: tf.range(x, dtype=index_swaped.dtype),
|
||||
index_swaped.shape.as_list())),
|
||||
indexing='ij')
|
||||
idx_tensors_per_axis[0] = index_swaped
|
||||
dim_expanded_idx_tensors_per_axis = list(
|
||||
map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
|
||||
index_expanded = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)
|
||||
|
||||
gathered = tf.gather_nd(data_swaped, index_expanded)
|
||||
y = tf.transpose(gathered, perm=axis_perm)
|
||||
|
||||
return [y]
|
||||
Reference in New Issue
Block a user