add pt2tf tool
This commit is contained in:
@@ -0,0 +1,86 @@
|
||||
from onnx_tf.common import IS_PYTHON3
|
||||
|
||||
|
||||
def convert_tf(attr):
|
||||
return __convert_tf_attr_value(attr)
|
||||
|
||||
|
||||
def convert_onnx(attr):
|
||||
return __convert_onnx_attribute_proto(attr)
|
||||
|
||||
|
||||
def __convert_tf_attr_value(attr):
|
||||
""" convert Tensorflow AttrValue object to Python object
|
||||
"""
|
||||
if attr.HasField('list'):
|
||||
return __convert_tf_list_value(attr.list)
|
||||
if attr.HasField('s'):
|
||||
return attr.s
|
||||
elif attr.HasField('i'):
|
||||
return attr.i
|
||||
elif attr.HasField('f'):
|
||||
return attr.f
|
||||
elif attr.HasField('b'):
|
||||
return attr.b
|
||||
elif attr.HasField('type'):
|
||||
return attr.type
|
||||
elif attr.HasField('shape'):
|
||||
return attr.type
|
||||
elif attr.HasField('tensor'):
|
||||
return attr.tensor
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
|
||||
|
||||
|
||||
def __convert_tf_list_value(list_value):
|
||||
""" convert Tensorflow ListValue object to Python object
|
||||
"""
|
||||
if list_value.s:
|
||||
return list_value.s
|
||||
elif list_value.i:
|
||||
return list_value.i
|
||||
elif list_value.f:
|
||||
return list_value.f
|
||||
elif list_value.b:
|
||||
return list_value.b
|
||||
elif list_value.tensor:
|
||||
return list_value.tensor
|
||||
elif list_value.type:
|
||||
return list_value.type
|
||||
elif list_value.shape:
|
||||
return list_value.shape
|
||||
elif list_value.func:
|
||||
return list_value.func
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(list_value))
|
||||
|
||||
|
||||
def __convert_onnx_attribute_proto(attr_proto):
|
||||
"""
|
||||
Convert an ONNX AttributeProto into an appropriate Python object
|
||||
for the type.
|
||||
NB: Tensor attribute gets returned as the straight proto.
|
||||
"""
|
||||
if attr_proto.HasField('f'):
|
||||
return attr_proto.f
|
||||
elif attr_proto.HasField('i'):
|
||||
return attr_proto.i
|
||||
elif attr_proto.HasField('s'):
|
||||
return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
|
||||
elif attr_proto.HasField('t'):
|
||||
return attr_proto.t # this is a proto!
|
||||
elif attr_proto.HasField('g'):
|
||||
return attr_proto.g
|
||||
elif attr_proto.floats:
|
||||
return list(attr_proto.floats)
|
||||
elif attr_proto.ints:
|
||||
return list(attr_proto.ints)
|
||||
elif attr_proto.strings:
|
||||
str_list = list(attr_proto.strings)
|
||||
if IS_PYTHON3:
|
||||
str_list = list(map(lambda x: str(x, 'utf-8'), str_list))
|
||||
return str_list
|
||||
elif attr_proto.HasField('sparse_tensor'):
|
||||
return attr_proto.sparse_tensor
|
||||
else:
|
||||
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
|
||||
Reference in New Issue
Block a user