重命名 pt2tf 为 pt2pb

This commit is contained in:
zhutian
2020-10-14 08:55:07 +08:00
committed by Gitee
parent 324ab60a5d
commit 90ae190559
407 changed files with 0 additions and 0 deletions
@@ -0,0 +1,37 @@
from tensorflow.python.framework.tensor_util import MakeNdarray
from onnx_tf.common import data_type
# Keyed by old attribute names.
__tf_attr_translator = {
"_output_shapes": lambda x: list(map(lambda shape: get_tf_shape_as_list(shape.dim), x.list.shape)),
"shape": lambda x: get_tf_shape_as_list(x.shape.dim),
"T": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
"dtype": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
"component_types": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
"value": lambda x: MakeNdarray(x.tensor),
"seed2": lambda x: float(x.i),
"seed": lambda x: float(x.i),
"keep_dims": lambda x: int(x.b),
"squeeze_dims": lambda x: list(x.list.i),
}
__onnx_attr_translator = {
"axis": lambda x: int(x),
"axes": lambda x: [int(a) for a in x],
"dtype": lambda x: data_type.onnx2tf(x),
"keepdims": lambda x: bool(x),
"to": lambda x: data_type.onnx2tf(x),
}
def translate_tf(key, val):
return __tf_attr_translator.get(key, lambda x: x)(val)
def translate_onnx(key, val):
return __onnx_attr_translator.get(key, lambda x: x)(val)
def get_tf_shape_as_list(tf_shape_dim):
return list(map(lambda x: x.size, list(tf_shape_dim)))