add pt2tf tool
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class BroadcastMixin(object):
|
||||
|
||||
@classmethod
|
||||
def explicit_broadcast(cls, inputs, axis=None, tensor_dict=None):
|
||||
x = inputs[0] if isinstance(inputs[0],
|
||||
tf.Tensor) else tensor_dict[inputs[0]]
|
||||
y = inputs[1] if isinstance(inputs[1],
|
||||
tf.Tensor) else tensor_dict[inputs[1]]
|
||||
|
||||
if np.prod(y.shape) == 1:
|
||||
return y
|
||||
|
||||
if not isinstance(x, tf.Tensor) or not isinstance(y, tf.Tensor):
|
||||
raise ValueError("Targets for explicit broadcasting need to be Tensor.")
|
||||
|
||||
if axis is None:
|
||||
return y
|
||||
|
||||
total_num_dim = len(x.get_shape())
|
||||
if axis < 0:
|
||||
axis += total_num_dim
|
||||
|
||||
if axis + len(y.get_shape()) == total_num_dim:
|
||||
return y
|
||||
|
||||
dims = [axis + i for i in range(len(y.get_shape()))]
|
||||
new_y = y
|
||||
for i in range(total_num_dim):
|
||||
if i not in dims:
|
||||
new_y = tf.expand_dims(new_y, i)
|
||||
return new_y
|
||||
|
||||
@classmethod
|
||||
def limited_broadcast(cls, node, **kwargs):
|
||||
tensor_dict = kwargs["tensor_dict"]
|
||||
x = tensor_dict[node.inputs[0]]
|
||||
y = tensor_dict[node.inputs[1]]
|
||||
if node.attrs.get("broadcast") == 1:
|
||||
y = cls.explicit_broadcast([x, y], node.attrs.get("axis", None))
|
||||
return [cls.make_tensor_from_onnx_node(node, inputs=[x, y], **kwargs)]
|
||||
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
|
||||
Reference in New Issue
Block a user