add pt2tf tool

This commit is contained in:
zxros10
2020-09-23 09:09:49 +08:00
parent 7f7b7df65d
commit 18aefa4dd0
407 changed files with 16211 additions and 0 deletions
@@ -0,0 +1,25 @@
import copy
from .broadcast_mixin import BroadcastMixin
class BasicMathMixin(BroadcastMixin):
pass
class ArithmeticMixin(BroadcastMixin):
pass
class ReductionMixin(BroadcastMixin):
@classmethod
def _common(cls, node, **kwargs):
attrs = copy.deepcopy(node.attrs)
axis = attrs.pop("axes", None)
if isinstance(axis, (list, tuple)) and len(axis) == 1:
axis = axis[0]
attrs["axis"] = axis
# https://github.com/onnx/onnx/issues/585
attrs["keepdims"] = attrs.pop("keepdims", 1) == 1
return [cls.make_tensor_from_onnx_node(node, attrs=attrs, **kwargs)]