重命名 pt2tf 为 pt2pb

This commit is contained in:
zhutian
2020-10-14 08:55:07 +08:00
committed by Gitee
parent 324ab60a5d
commit 90ae190559
407 changed files with 0 additions and 0 deletions
@@ -0,0 +1,80 @@
import tensorflow as tf
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
@onnx_op("TopK")
@tf_func(tf.nn.top_k)
class TopK(BackendHandler):
@classmethod
def version_1(cls, node, **kwargs):
x = kwargs["tensor_dict"][node.inputs[0]]
x_rank = len(x.get_shape())
axes = list(range(x_rank))
axis = node.attrs.get("axis", -1)
axis = axis if axis >= 0 else axis + x_rank
if axis != x_rank - 1:
pre_perm = [a for a in axes if a != axis] + [axis]
post_perm = axes[:axis] + [x_rank - 1] + axes[axis:x_rank - 1]
x = tf.transpose(x, perm=pre_perm)
values, indices = tf.nn.top_k(x, k=node.attrs["k"])
values = tf.transpose(values, perm=post_perm)
return [values, tf.cast(indices, dtype=tf.int64)]
values, indices = tf.nn.top_k(x, k=node.attrs["k"])
return [values, tf.cast(indices, dtype=tf.int64)]
@classmethod
def version_10(cls, node, **kwargs):
x = kwargs["tensor_dict"][node.inputs[0]]
x_rank = len(x.get_shape())
axes = list(range(x_rank))
axis = node.attrs.get("axis", -1)
axis = axis if axis >= 0 else axis + x_rank
k = kwargs["tensor_dict"][node.inputs[1]][0]
k = tf.cast(k, dtype=tf.int32)
if axis != x_rank - 1:
pre_perm = [a for a in axes if a != axis] + [axis]
post_perm = axes[:axis] + [x_rank - 1] + axes[axis:x_rank - 1]
x = tf.transpose(x, perm=pre_perm)
values, indices = tf.nn.top_k(x, k)
values = tf.transpose(values, perm=post_perm)
return [values, tf.cast(indices, dtype=tf.int64)]
values, indices = tf.nn.top_k(x, k)
return [values, tf.cast(indices, dtype=tf.int64)]
@classmethod
def version_11(cls, node, **kwargs):
x = kwargs["tensor_dict"][node.inputs[0]]
x_rank = len(x.get_shape())
axes = list(range(x_rank))
axis = node.attrs.get("axis", -1)
axis = axis if axis >= 0 else axis + x_rank
largest = node.attrs.get("largest", 1)
sort = node.attrs.get("sorted", 1)
sort = False if sort == 0 else True
k = kwargs["tensor_dict"][node.inputs[1]][0]
k = tf.cast(k, dtype=tf.int32)
if largest == 0:
x = tf.negative(x)
if axis != x_rank - 1:
pre_perm = [a for a in axes if a != axis] + [axis]
post_perm = axes[:axis] + [x_rank - 1] + axes[axis:x_rank - 1]
x = tf.transpose(x, perm=pre_perm)
values, indices = tf.nn.top_k(x, k, sort)
values = tf.transpose(values, perm=post_perm)
else :
values, indices = tf.nn.top_k(x, k, sort)
if largest == 0:
values = tf.negative(values)
return [values, tf.cast(indices, dtype=tf.int64)]