Files
2020-10-14 08:55:07 +08:00

46 lines
1.4 KiB
Python

import numpy as np
import tensorflow as tf
class BroadcastMixin(object):
    """Helpers for handlers whose nodes carry ``broadcast``/``axis`` attributes.

    This is a mixin: ``limited_broadcast`` calls
    ``cls.make_tensor_from_onnx_node``, which is not defined here and must be
    supplied by the class this is mixed into.
    """

    @classmethod
    def explicit_broadcast(cls, inputs, axis=None, tensor_dict=None):
        """Reshape the second operand so it lines up with the first at ``axis``.

        Args:
            inputs: Sequence whose first two items are ``x`` and ``y``. Each
                item is either a ``tf.Tensor`` or a key that is looked up in
                ``tensor_dict``.
            axis: Dimension of ``x`` at which ``y``'s first dimension is
                placed. Negative values count from the end of ``x``. ``None``
                means no explicit alignment is requested. The value is not
                range-checked here.
            tensor_dict: Mapping used to resolve non-tensor items of
                ``inputs``; only needed when such items are passed.

        Returns:
            ``y`` unchanged when it holds exactly one element, when ``axis``
            is ``None``, or when ``y``'s dimensions already end at ``x``'s
            last dimension. Otherwise ``y`` with size-1 dimensions inserted at
            every position of ``x`` that ``y`` does not occupy.

        Raises:
            ValueError: If ``x`` or ``y`` is not a ``tf.Tensor``. This check
                runs after the single-element shortcut, so a single-element
                ``y`` is returned without it.
        """
        x, y = (
            item if isinstance(item, tf.Tensor) else tensor_dict[item]
            for item in (inputs[0], inputs[1])
        )

        # TensorShape.num_elements() is None when any dimension is unknown,
        # so a partially known shape falls through to the alignment logic
        # instead of failing inside a NumPy product over None entries.
        if tf.TensorShape(y.shape).num_elements() == 1:
            return y

        if not isinstance(x, tf.Tensor) or not isinstance(y, tf.Tensor):
            raise ValueError(
                "Targets for explicit broadcasting need to be Tensor.")

        if axis is None:
            return y

        x_rank = len(x.get_shape())
        y_rank = len(y.get_shape())
        if axis < 0:
            axis += x_rank

        # y already occupies the trailing dimensions of x; nothing to insert.
        if axis + y_rank == x_rank:
            return y

        # Walk x's dimensions in ascending order so each inserted axis lands
        # at its final position.
        y_dims = range(axis, axis + y_rank)
        for dim in range(x_rank):
            if dim not in y_dims:
                y = tf.expand_dims(y, dim)
        return y

    @classmethod
    def limited_broadcast(cls, node, **kwargs):
        """Build the node's output, honouring its ``broadcast`` attribute.

        Args:
            node: Object exposing ``inputs`` (keys into the tensor dict) and
                ``attrs`` (a mapping read with ``.get``).
            **kwargs: Must contain ``"tensor_dict"``; all keyword arguments
                are forwarded to ``cls.make_tensor_from_onnx_node``.

        Returns:
            A one-item list holding the result of
            ``cls.make_tensor_from_onnx_node``. When ``node.attrs["broadcast"]``
            equals 1, the second input is first passed through
            ``explicit_broadcast`` (using ``node.attrs.get("axis")``) and the
            pair is handed over explicitly via ``inputs=[x, y]``.
        """
        tensor_dict = kwargs["tensor_dict"]
        x = tensor_dict[node.inputs[0]]
        y = tensor_dict[node.inputs[1]]

        if node.attrs.get("broadcast") != 1:
            return [cls.make_tensor_from_onnx_node(node, **kwargs)]

        y = cls.explicit_broadcast([x, y], node.attrs.get("axis", None))
        return [cls.make_tensor_from_onnx_node(node, inputs=[x, y], **kwargs)]