重命名 pt2tf 为 pt2pb
This commit is contained in:
Executable
+143
@@ -0,0 +1,143 @@
|
||||
中文|[EN](README_EN.md)
|
||||
|
||||
# pytorch模型转onnx工具
|
||||
|
||||
## 功能
|
||||
当前ATC工具只支持pb和caffe模型转om模型。如果需要使用pytorch模型转om模型,可以将pytorch模型转为onnx格式,再转为pb。本工具提供pytorch模型转onnx,以及onnx转pb功能。
|
||||
|
||||
## 使用环境
|
||||
1. 安装Ubuntu18.04的服务器或者虚拟机;
|
||||
|
||||
2. 服务器或者虚拟机内存大于等于4G;
|
||||
|
||||
3. 已经安装pip3。如未安装,可以执行如下命令安装:
|
||||
|
||||
```
|
||||
sudo apt-get install python3-pip
|
||||
sudo pip3 install --upgrade pip -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
|
||||
```
|
||||
|
||||
4. 已经安装tensorflow、keras和pytorch
|
||||
|
||||
当前昇腾平台支持tensorflow 1.15,考虑后继pb模型转om,tensorflow版本推荐1.15及之前版本。tensorflow 1.15版本需要源码编译安装;使用pip命令直接安装时可以1.15之前的版本,以1.14为例:
|
||||
|
||||
```
|
||||
sudo pip3 install tensorflow==1.14.0 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
|
||||
```
|
||||
|
||||
对应的keras版本为2.2.5,安装命令:
|
||||
|
||||
```
|
||||
sudo pip3 install keras==2.2.5 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
|
||||
```
|
||||
|
||||
pytorch版本只要适配待转换的pytorch模型即可。pytorch的安装可以参考官网:https://pytorch.org/get-started/locally/
|
||||
|
||||
## 预置条件
|
||||
|
||||
1.pytorch模型文件。pytorch模型保存有两种,一种是保存有权重参数和网络结构,另外一种是指保存权重参数。本工具只支持保存权重参数的模型文件,模型保存接口示例:
|
||||
|
||||
```
|
||||
torch.save(my_resnet.state_dict(),"my_resnet.pth")
|
||||
```
|
||||
|
||||
2.模型实现代码。权重参数模型加载时,需要使用模型创建接口创建模型,作为模型加载的参数,所以需要模型实现代码。
|
||||
|
||||
## 工具获取
|
||||
|
||||
**方法1. 下载压缩包方式获取**
|
||||
|
||||
将 https://gitee.com/ascend/tools 仓中的脚本下载至服务器的任意目录。
|
||||
|
||||
例如存放路径为:$HOME/AscendProjects/tools。
|
||||
|
||||
**方法2. 命令行使用git命令方式获取**
|
||||
|
||||
在命令行中:$HOME/AscendProjects目录下执行以下命令下载代码。
|
||||
|
||||
git clone https://gitee.com/ascend/tools.git
|
||||
|
||||
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 安装工具依赖包
|
||||
|
||||
cd $HOME/AscendProjects/tools/pt2tf/
|
||||
sudo pip3 install -e onnx-tensorflow
|
||||
|
||||
### 2. pth模型文件转onnx
|
||||
1. 将pytorch模型和实现源码拷贝到pt2tf目录下
|
||||
|
||||
2. 使用vim或者文本工具打开pt2onnx.py,修改load_model函数。以resnet50模型为例,修改点如下:
|
||||
|
||||
(1)导入模型实现文件:
|
||||
|
||||
```
|
||||
#修改点1:导入模型代码.
|
||||
#例如:模型实现代码目录为./resnet50,网络实现在resnet.py的class ResNet50类
|
||||
from resnet50.resnet import ResNet50
|
||||
```
|
||||
|
||||
(2) 使用pytorch实例化模型对象
|
||||
|
||||
```
|
||||
#修改点2:创建模型对象
|
||||
model = ResNet50()
|
||||
```
|
||||
|
||||
(3)加载训练好的模型
|
||||
|
||||
```
|
||||
#修改点3:训练好的模型路径
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
```
|
||||
|
||||
综上,完整的load_weight_model代码:
|
||||
|
||||
def load_model(model_path, input_shape):
|
||||
if not os.path.exists(model_path):
|
||||
print("The pytorch model is not exist")
|
||||
return None
|
||||
from resnet50.resnet import ResNet50
|
||||
|
||||
model = ResNet50()
|
||||
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
||||
return model
|
||||
|
||||
3. 执行转换脚本
|
||||
|
||||
```
|
||||
python3 pt2onnx.py --model_path="./resnet50/models/resnet50_best.pth" --input_shape=1 3 224 224
|
||||
```
|
||||
|
||||
参数说明:
|
||||
|
||||
--model_type: 模型类别,默认值为0,表示完备信息模型;1: 仅包含权重参数的模型
|
||||
|
||||
--model_path: pytorch模型存放路径
|
||||
|
||||
--input_shape: 模型输入 shape
|
||||
|
||||
如果执行成功,将在pytorch目录下生成onnx文件,文件名和pytorch模型文件名一致,例如./resnet50/models/resnet50_best.onnx
|
||||
|
||||
### 3.使用 onnx-tf工具将onnx转为 pb
|
||||
|
||||
执行命令
|
||||
|
||||
onnx-tf convert -i ./resnet50/models/resnet50_best.onnx -o ./resnet50/resnet50_best.pb
|
||||
|
||||
参数说明:
|
||||
|
||||
-i:onnx文件路径
|
||||
|
||||
-o: 输出的pb模型文件
|
||||
|
||||
onnx-tf convert的参数说明详见帮助:
|
||||
|
||||
```
|
||||
onnx-tf convert --help
|
||||
```
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
os: linux
|
||||
arch: amd64
|
||||
dist: bionic
|
||||
language: python
|
||||
python: 3.6
|
||||
cache: pip
|
||||
|
||||
# Required
|
||||
# - ONNX 1.6 (Opset 11), TensorFlow 1.15.0, Sep 2019
|
||||
# - ONNX 1.6 (Opset 11), TensorFlow 1.15.3, May 2020
|
||||
# Optional
|
||||
# - ONNX 1.7 (Opset 12), TensorFlow 1.15.3, May 2020
|
||||
|
||||
env:
|
||||
- ONNX_PIP=onnx==1.6.0 TF_PIP=tensorflow==1.15.0
|
||||
- ONNX_PIP=onnx==1.6.0 TF_PIP=tensorflow==1.15.3
|
||||
- ONNX_PIP=onnx==1.7.0 TF_PIP=tensorflow==1.15.3
|
||||
|
||||
jobs:
|
||||
fast_finish: true
|
||||
# Be aware when updating the dependency versions.
|
||||
# Envs below must match *exactly* an env from above,
|
||||
# otherwise an env failure will fail the overall build.
|
||||
allow_failures:
|
||||
- env: ONNX_PIP=onnx==1.7.0 TF_PIP=tensorflow==1.15.3
|
||||
|
||||
before_install: pip install -U setuptools
|
||||
install: pip install $ONNX_PIP $TF_PIP
|
||||
before_script: pip install -e .
|
||||
script: python -m unittest discover test -v
|
||||
@@ -0,0 +1,213 @@
|
||||
ONNX-Tensorflow Converter
|
||||
|
||||
Copyright (c) International Business Machines Corporation 2017, 2018
|
||||
Copyright (c) LeapMind Inc. 2018
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, and
|
||||
distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the
|
||||
copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other
|
||||
entities that control, are controlled by, or are under common control with
|
||||
that entity. For the purposes of this definition, "control" means (i) the
|
||||
power, direct or indirect, to cause the direction or management of such
|
||||
entity, whether by contract or otherwise, or (ii) ownership of
|
||||
fifty percent (50%) or more of the outstanding shares, or (iii) beneficial
|
||||
ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising
|
||||
permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation source,
|
||||
and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation
|
||||
or translation of a Source form, including but not limited to compiled
|
||||
object code, generated documentation, and conversions to
|
||||
other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object
|
||||
form, made available under the License, as indicated by a copyright notice
|
||||
that is included in or attached to the work (an example is provided in the
|
||||
Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form,
|
||||
that is based on (or derived from) the Work and for which the editorial
|
||||
revisions, annotations, elaborations, or other modifications represent,
|
||||
as a whole, an original work of authorship. For the purposes of this
|
||||
License, Derivative Works shall not include works that remain separable
|
||||
from, or merely link (or bind by name) to the interfaces of, the Work and
|
||||
Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original
|
||||
version of the Work and any modifications or additions to that Work or
|
||||
Derivative Works thereof, that is intentionally submitted to Licensor for
|
||||
inclusion in the Work by the copyright owner or by an individual or
|
||||
Legal Entity authorized to submit on behalf of the copyright owner.
|
||||
For the purposes of this definition, "submitted" means any form of
|
||||
electronic, verbal, or written communication sent to the Licensor or its
|
||||
representatives, including but not limited to communication on electronic
|
||||
mailing lists, source code control systems, and issue tracking systems
|
||||
that are managed by, or on behalf of, the Licensor for the purpose of
|
||||
discussing and improving the Work, but excluding communication that is
|
||||
conspicuously marked or otherwise designated in writing by the copyright
|
||||
owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on
|
||||
behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License.
|
||||
|
||||
Subject to the terms and conditions of this License, each Contributor
|
||||
hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable copyright license to reproduce, prepare
|
||||
Derivative Works of, publicly display, publicly perform, sublicense,
|
||||
and distribute the Work and such Derivative Works in
|
||||
Source or Object form.
|
||||
|
||||
3. Grant of Patent License.
|
||||
|
||||
Subject to the terms and conditions of this License, each Contributor
|
||||
hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable (except as stated in this section) patent
|
||||
license to make, have made, use, offer to sell, sell, import, and
|
||||
otherwise transfer the Work, where such license applies only to those
|
||||
patent claims licensable by such Contributor that are necessarily
|
||||
infringed by their Contribution(s) alone or by combination of their
|
||||
Contribution(s) with the Work to which such Contribution(s) was submitted.
|
||||
If You institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
|
||||
Contribution incorporated within the Work constitutes direct or
|
||||
contributory patent infringement, then any patent licenses granted to
|
||||
You under this License for that Work shall terminate as of the date such
|
||||
litigation is filed.
|
||||
|
||||
4. Redistribution.
|
||||
|
||||
You may reproduce and distribute copies of the Work or Derivative Works
|
||||
thereof in any medium, with or without modifications, and in Source or
|
||||
Object form, provided that You meet the following conditions:
|
||||
|
||||
1. You must give any other recipients of the Work or Derivative Works a
|
||||
copy of this License; and
|
||||
|
||||
2. You must cause any modified files to carry prominent notices stating
|
||||
that You changed the files; and
|
||||
|
||||
3. You must retain, in the Source form of any Derivative Works that You
|
||||
distribute, all copyright, patent, trademark, and attribution notices from
|
||||
the Source form of the Work, excluding those notices that do not pertain
|
||||
to any part of the Derivative Works; and
|
||||
|
||||
4. If the Work includes a "NOTICE" text file as part of its distribution,
|
||||
then any Derivative Works that You distribute must include a readable copy
|
||||
of the attribution notices contained within such NOTICE file, excluding
|
||||
those notices that do not pertain to any part of the Derivative Works,
|
||||
in at least one of the following places: within a NOTICE text file
|
||||
distributed as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or, within a
|
||||
display generated by the Derivative Works, if and wherever such
|
||||
third-party notices normally appear. The contents of the NOTICE file are
|
||||
for informational purposes only and do not modify the License.
|
||||
You may add Your own attribution notices within Derivative Works that You
|
||||
distribute, alongside or as an addendum to the NOTICE text from the Work,
|
||||
provided that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may
|
||||
provide additional or different license terms and conditions for use,
|
||||
reproduction, or distribution of Your modifications, or for any such
|
||||
Derivative Works as a whole, provided Your use, reproduction, and
|
||||
distribution of the Work otherwise complies with the conditions
|
||||
stated in this License.
|
||||
|
||||
5. Submission of Contributions.
|
||||
|
||||
Unless You explicitly state otherwise, any Contribution intentionally
|
||||
submitted for inclusion in the Work by You to the Licensor shall be under
|
||||
the terms and conditions of this License, without any additional
|
||||
terms or conditions. Notwithstanding the above, nothing herein shall
|
||||
supersede or modify the terms of any separate license agreement you may
|
||||
have executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks.
|
||||
|
||||
This License does not grant permission to use the trade names, trademarks,
|
||||
service marks, or product names of the Licensor, except as required for
|
||||
reasonable and customary use in describing the origin of the Work and
|
||||
reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty.
|
||||
|
||||
Unless required by applicable law or agreed to in writing, Licensor
|
||||
provides the Work (and each Contributor provides its Contributions)
|
||||
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
either express or implied, including, without limitation, any warranties
|
||||
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS
|
||||
FOR A PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any risks
|
||||
associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability.
|
||||
|
||||
In no event and under no legal theory, whether in tort
|
||||
(including negligence), contract, or otherwise, unless required by
|
||||
applicable law (such as deliberate and grossly negligent acts) or agreed
|
||||
to in writing, shall any Contributor be liable to You for damages,
|
||||
including any direct, indirect, special, incidental, or consequential
|
||||
damages of any character arising as a result of this License or out of
|
||||
the use or inability to use the Work (including but not limited to damages
|
||||
for loss of goodwill, work stoppage, computer failure or malfunction,
|
||||
or any and all other commercial damages or losses), even if such
|
||||
Contributor has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability.
|
||||
|
||||
While redistributing the Work or Derivative Works thereof, You may choose
|
||||
to offer, and charge a fee for, acceptance of support, warranty,
|
||||
indemnity, or other liability obligations and/or rights consistent with
|
||||
this License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf of any
|
||||
other Contributor, and only if You agree to indemnify, defend, and hold
|
||||
each Contributor harmless for any liability incurred by, or claims
|
||||
asserted against, such Contributor by reason of your accepting any such
|
||||
warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work
|
||||
|
||||
To apply the Apache License to your work, attach the following boilerplate
|
||||
notice, with the fields enclosed by brackets "[]" replaced with your own
|
||||
identifying information. (Don't include the brackets!) The text should be
|
||||
enclosed in the appropriate comment syntax for the file format. We also
|
||||
recommend that a file or class name and description of purpose be included
|
||||
on the same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2017 ONNX-TF Authors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
or implied. See the License for the specific language governing
|
||||
permissions and limitations under the License.
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
include LICENSE
|
||||
include VERSION_NUMBER
|
||||
include ONNX_VERSION_NUMBER
|
||||
@@ -0,0 +1 @@
|
||||
1.6.0
|
||||
@@ -0,0 +1,101 @@
|
||||
# Tensorflow Backend for ONNX
|
||||
[](https://travis-ci.org/onnx/onnx-tensorflow)
|
||||
|
||||
## To convert models from ONNX to Tensorflow:
|
||||
|
||||
### Use CLI:
|
||||
|
||||
[Command Line Interface Documentation](https://github.com/onnx/onnx-tensorflow/blob/master/doc/CLI.md)
|
||||
|
||||
From ONNX to Tensorflow: `onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
|
||||
|
||||
### Convert programmatically:
|
||||
|
||||
[From ONNX to Tensorflow](https://github.com/onnx/onnx-tensorflow/blob/master/example/onnx_to_tf.py)
|
||||
|
||||
### Migrating from `onnx-tf` to `tf-onnx`:
|
||||
We have joined force with Microsoft to co-develop ONNX Tensorflow frontend.
|
||||
For current onnx-tf frontend users, please migrate to use tf-onnx (https://github.com/onnx/tensorflow-onnx) where our code had been merged into.
|
||||
|
||||
## ONNX model inference with Tensorflow backend:
|
||||
```
|
||||
import onnx
|
||||
from onnx_tf.backend import prepare
|
||||
|
||||
onnx_model = onnx.load("input_path") # load onnx model
|
||||
output = prepare(onnx_model).run(input) # run the loaded model
|
||||
```
|
||||
|
||||
## More tutorials:
|
||||
[Running an ONNX model using Tensorflow](https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowImport.ipynb)
|
||||
|
||||
## Production Installation:
|
||||
ONNX-TF requires ONNX (Open Neural Network Exchange) as an external dependency, for any issues related to ONNX installation, we refer our users to [ONNX project repository](https://github.com/onnx/onnx) for documentation and help. Notably, please ensure that protoc is available if you plan to install ONNX via pip.
|
||||
|
||||
The specific ONNX release version that we support in the master branch of ONNX-TF can be found [here](https://github.com/onnx/onnx-tensorflow/blob/master/ONNX_VERSION_NUMBER). This information about ONNX version requirement is automatically encoded in `setup.py`, therefore users needn't worry about ONNX version requirement when installing ONNX-TF.
|
||||
|
||||
Because users often have their own preferences for which variant of Tensorflow to install (i.e., a GPU version instead of a CPU version), we do not explicitly require tensorflow in the installation script. It is therefore users' responsibility to ensure that the proper variant of Tensorflow is available to ONNX-TF. Moreover, we require Tensorflow version == 1.15.0.
|
||||
|
||||
To install the latest version of ONNX-TF v1.6.0
|
||||
- Run `git clone https://github.com/onnx/onnx-tensorflow.git && cd onnx-tensorflow`.
|
||||
- Run `git checkout v1.6.0-tf-1.15`.
|
||||
- Run `pip install -e .`.
|
||||
|
||||
## Development:
|
||||
|
||||
### Coverage Status:
|
||||
[ONNX-Tensorflow Op Coverage Status](https://github.com/onnx/onnx-tensorflow/blob/tf-1.x/doc/support_status.md)
|
||||
|
||||
### API:
|
||||
[ONNX-Tensorflow API](https://github.com/onnx/onnx-tensorflow/blob/tf-1.x/doc/API.md)
|
||||
|
||||
### Installation:
|
||||
- Install ONNX master branch from source.
|
||||
- Install Tensorflow 1.15.0. (For Tensorflow 2.x support please refer [here](https://github.com/onnx/onnx-tensorflow/blob/master/README.md/).)
|
||||
- Run `git clone https://github.com/onnx/onnx-tensorflow.git && cd onnx-tensorflow`.
|
||||
- Run `git checkout tf-1.x`.
|
||||
- Run `pip install -e .`.
|
||||
|
||||
### Folder Structure:
|
||||
- __onnx_tf__ main source code file.
|
||||
- __test__ test files.
|
||||
|
||||
### Code Standard:
|
||||
- Format code:
|
||||
```
|
||||
pip install yapf
|
||||
yapf -rip --style="{based_on_style: google, indent_width: 2}" $FilePath$
|
||||
```
|
||||
- Install pylint:
|
||||
```
|
||||
pip install pylint
|
||||
wget -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc
|
||||
```
|
||||
- Check format:
|
||||
```
|
||||
pylint --rcfile=/tmp/pylintrc myfile.py
|
||||
```
|
||||
|
||||
### Documentation Standard:
|
||||
http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
|
||||
|
||||
### To test:
|
||||
To perfom unit tests, run `python -m unittest discover test`.
|
||||
Testing requires significant hardware resources, but nonetheless, we highly recommend that users run through the complete test suite before deploying onnx-tf. The complete test suite typically takes between 15 and 45 minutes to complete, depending on hardware configurations.
|
||||
|
||||
PS. Please ensure your code is backward compatible with older version of ONNX. You can easily test it by running the following [docker container](https://hub.docker.com/r/winnietsang/onnx-tensorflow) with your code. If you don't have Docker installed yet, please follow this link to install [Docker](https://docs.docker.com/install/) on your environment.
|
||||
```
|
||||
sudo docker pull winnietsang/onnx-tensorflow:onnx1.7.0-tf1.15
|
||||
sudo docker run -it --name=YOUR-CONTAINER-NAME winnietsang/onnx-tensorflow:onnx1.7.0-tf1.15 /bin/bash
|
||||
git clone https://github.com/YOUR-USERNAME/onnx-tensorflow.git
|
||||
cd onnx-tensorflow
|
||||
git checkout -b YOUR-BRANCH --track remotes/origin/YOUR-BRANCH
|
||||
pip3 install -e .
|
||||
python3 -m unittest discover test
|
||||
```
|
||||
|
||||
#### Test Help:
|
||||
https://docs.python.org/2/library/unittest.html
|
||||
|
||||
#### Note:
|
||||
Branch tf-1.x is for users who cannot upgrade to Tensorflow 2.x yet. This branch will only support up to ONNX OpSet 12 operators. If any user needs to use operators in OpSet 13 or above, please upgrade Tensoflow to 2.x and use the master branch of this repo. By January 1st, 2021 this branch will switch to maintenance mode only, no new development will be added into this branch from then on.
|
||||
@@ -0,0 +1 @@
|
||||
1.6.0
|
||||
@@ -0,0 +1,10 @@
|
||||
## Released Versions
|
||||
|
||||
These are the supported ONNX versions and TensorFlow versions for each release.
|
||||
|
||||
ONNX-TensorFlow version|ONNX version|TensorFlow version
|
||||
-----------------------|------------|------------------
|
||||
1.2.1|1.1.2 (opset 4)|1.5
|
||||
1.3.0|1.3.0 (opset 8)|1.13.1
|
||||
1.5.0|1.5.0 (opset 10)|1.15.0
|
||||
1.6.0|1.6.0 (opset 11)|1.15.0
|
||||
@@ -0,0 +1,62 @@
|
||||
ONNX-Tensorflow API
|
||||
======
|
||||
|
||||
#### `onnx_tf.backend.prepare`
|
||||
|
||||
<details>
|
||||
<summary>Prepare an ONNX model for Tensorflow Backend.
|
||||
|
||||
</summary>
|
||||
This function converts an ONNX model to an internel representation
|
||||
of the computational graph called TensorflowRep and returns
|
||||
the converted representation.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
_params_:
|
||||
|
||||
`model` : The ONNX model to be converted.
|
||||
|
||||
|
||||
`device` : The device to execute this model on.
|
||||
|
||||
|
||||
`strict` : Whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
|
||||
Changing to False is strongly discouraged.
|
||||
Currently, the strict flag only affects the behavior of MaxPool and AveragePool ops.
|
||||
|
||||
|
||||
`logging_level` : The logging level, default is INFO. Change it to DEBUG
|
||||
to see more conversion details or to WARNING to see less
|
||||
|
||||
|
||||
_returns_:
|
||||
|
||||
A TensorflowRep class object representing the ONNX model
|
||||
|
||||
#### `onnx_tf.backend_rep.TensorflowRep.export_graph`
|
||||
|
||||
<details>
|
||||
<summary>Export backend representation to a Tensorflow proto file.
|
||||
|
||||
</summary>
|
||||
This function obtains the graph proto corresponding to the ONNX
|
||||
model associated with the backend representation and serializes
|
||||
to a protobuf file.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
_params_:
|
||||
|
||||
`path` : The path to the output TF protobuf file.
|
||||
|
||||
|
||||
_returns_:
|
||||
|
||||
none.
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
ONNX-Tensorflow Command Line Interface
|
||||
======
|
||||
|
||||
## Available commands:
|
||||
- convert
|
||||
|
||||
More information: `onnx-tf -h`
|
||||
```
|
||||
usage: onnx-tf [-h] {convert}
|
||||
|
||||
ONNX-Tensorflow Command Line Interface
|
||||
|
||||
positional arguments:
|
||||
{convert} Available commands.
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
```
|
||||
|
||||
## Usage:
|
||||
|
||||
### Convert:
|
||||
|
||||
#### From ONNX to Tensorflow:
|
||||
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
|
||||
|
||||
More information: `onnx-tf convert -h`
|
||||
```
|
||||
usage: onnx-tf [-h] --infile INFILE --outfile OUTFILE [--device DEVICE]
|
||||
[--strict STRICT] [--logging_level LOGGING_LEVEL]
|
||||
|
||||
This is the converter for converting protocol buffer between tf and onnx.
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--infile INFILE, -i INFILE
|
||||
Input file path.
|
||||
--outfile OUTFILE, -o OUTFILE
|
||||
Output file path.
|
||||
|
||||
backend arguments (onnx -> tf):
|
||||
--device DEVICE The device to execute this model on. (from
|
||||
onnx_tf.backend.prepare)
|
||||
--strict STRICT Whether to enforce semantic equivalence between the
|
||||
original model and the converted tensorflow model,
|
||||
defaults to True (yes, enforce semantic equivalence).
|
||||
Changing to False is strongly discouraged. Currently,
|
||||
the strict flag only affects the behavior of MaxPool
|
||||
and AveragePool ops. (from onnx_tf.backend.prepare)
|
||||
--logging_level LOGGING_LEVEL
|
||||
The logging level, default is INFO. Change it to DEBUG
|
||||
to see more conversion details or to WARNING to see
|
||||
less (from onnx_tf.backend.prepare)
|
||||
```
|
||||
@@ -0,0 +1,22 @@
|
||||
ONNX-Tensorflow Command Line Interface
|
||||
======
|
||||
|
||||
## Available commands:
|
||||
- convert
|
||||
|
||||
More information: `onnx-tf -h`
|
||||
```
|
||||
{onnx-tf -h}
|
||||
```
|
||||
|
||||
## Usage:
|
||||
|
||||
### Convert:
|
||||
|
||||
#### From ONNX to Tensorflow:
|
||||
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
|
||||
|
||||
More information: `onnx-tf convert -h`
|
||||
```
|
||||
{onnx-tf convert -h}
|
||||
```
|
||||
@@ -0,0 +1,22 @@
|
||||
How to implement new op
|
||||
======
|
||||
|
||||
When you get `{} op is not implemented`, you can follow next steps to implement.
|
||||
Customize op can also be implemented in similar way.
|
||||
|
||||
### Backend
|
||||
|
||||
1. Verify the latest master version of ONNX is installed on your environment
|
||||
2. Find specification from [onnx/Operators](https://github.com/onnx/onnx/blob/master/docs/Operators.md).
|
||||
3. Implement the handler. All inputs and attrs could get from step 2.
|
||||
```
|
||||
- add handler to /onnx_tf/handlers/backend/
|
||||
- in the new handler define a classmethod called version_{version}
|
||||
|
||||
* version is the number of since version, which can get from operator's specification
|
||||
```
|
||||
4. From within the `onnx_tf` directory, run `gen_opset.py`.
|
||||
5. From within the `onnx_tf` directory, run `gen_status.py`.
|
||||
6. From within the `onnx_tf` directory, run `gen_doc.py` if there is any update to CLI or API.
|
||||
7. Verify the operator's test cases in `test/backend/test_onnx_backend.py` all pass.
|
||||
8. Add any additional test cases to `test/backend/test_node.py`.
|
||||
@@ -0,0 +1,209 @@
|
||||
# ONNX-Tensorflow Support Status
|
||||
|||
|
||||
|-:|:-|
|
||||
|ONNX-Tensorflow Version|Master ( commit id: 8fea59a976e2d65eab2ab021864e2cab038bb7d5 )|
|
||||
|ONNX Version|v1.7.0|
|
||||
|Tensorflow Version|v1.15.0|
|
||||
|
||||
Notes:
|
||||
* Values that are new or updated from a previous opset version are in bold.
|
||||
* -: not defined in corresponding ONNX opset version
|
||||
* \*: the operator is deprecated
|
||||
* :small_red_triangle:: not supported yet
|
||||
* :small_orange_diamond:: partially supported
|
||||
* the rest are all supported
|
||||
|
||||
|||||||||||||||
|
||||
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|
||||
|**ONNX Operator**|**Opset 1**|**Opset 2**|**Opset 3**|**Opset 4**|**Opset 5**|**Opset 6**|**Opset 7**|**Opset 8**|**Opset 9**|**Opset 10**|**Opset 11**|**Opset 12**|**ONNX Operator**|
|
||||
|Abs|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Abs|
|
||||
|Acos|-|-|-|-|-|-|**7**|7|7|7|7|7|Acos|
|
||||
|Acosh|-|-|-|-|-|-|-|-|**9**|9|9|9|Acosh|
|
||||
|Add|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|Add|
|
||||
|And|**1**|1|1|1|1|1|**7**|7|7|7|7|7|And|
|
||||
|ArgMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|ArgMax|
|
||||
|ArgMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|ArgMin|
|
||||
|Asin|-|-|-|-|-|-|**7**|7|7|7|7|7|Asin|
|
||||
|Asinh|-|-|-|-|-|-|-|-|**9**|9|9|9|Asinh|
|
||||
|Atan|-|-|-|-|-|-|**7**|7|7|7|7|7|Atan|
|
||||
|Atanh|-|-|-|-|-|-|-|-|**9**|9|9|9|Atanh|
|
||||
|AveragePool|**1**|1|1|1|1|1|**7**|7|7|**10**|**11**|11|AveragePool|
|
||||
|BatchNormalization|**1**|1|1|1|1|**6**|**7**|7|**9**|9|9|9|BatchNormalization|
|
||||
|BitShift|-|-|-|-|-|-|-|-|-|-|**11**|11|BitShift|
|
||||
|Cast|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**9**:small_orange_diamond:|9:small_orange_diamond:|9:small_orange_diamond:|9:small_orange_diamond:|Cast|
|
||||
|Ceil|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Ceil|
|
||||
|Celu|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|Celu|
|
||||
|Clip|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**11**:small_orange_diamond:|**12**:small_orange_diamond:|Clip|
|
||||
|Compress|-|-|-|-|-|-|-|-|**9**|9|**11**|11|Compress|
|
||||
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|11|Concat|
|
||||
|ConcatFromSequence|-|-|-|-|-|-|-|-|-|-|**11**:small_red_triangle:|11:small_red_triangle:|ConcatFromSequence|
|
||||
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|**12**|Constant|
|
||||
|ConstantOfShape|-|-|-|-|-|-|-|-|**9**|9|9|9|ConstantOfShape|
|
||||
|Conv|**1**|1|1|1|1|1|1|1|1|1|**11**|11|Conv|
|
||||
|ConvInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|ConvInteger|
|
||||
|ConvTranspose|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**11**:small_orange_diamond:|11:small_orange_diamond:|ConvTranspose|
|
||||
|Cos|-|-|-|-|-|-|**7**|7|7|7|7|7|Cos|
|
||||
|Cosh|-|-|-|-|-|-|-|-|**9**|9|9|9|Cosh|
|
||||
|CumSum|-|-|-|-|-|-|-|-|-|-|**11**:small_orange_diamond:|11:small_orange_diamond:|CumSum|
|
||||
|DepthToSpace|**1**|1|1|1|1|1|1|1|1|1|**11**|11|DepthToSpace|
|
||||
|DequantizeLinear|-|-|-|-|-|-|-|-|-|**10**|10|10|DequantizeLinear|
|
||||
|Det|-|-|-|-|-|-|-|-|-|-|**11**|11|Det|
|
||||
|Div|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|Div|
|
||||
|Dropout|**1**|1|1|1|1|**6**|**7**|7|7|**10**|10|**12**:small_red_triangle:|Dropout|
|
||||
|DynamicQuantizeLinear|-|-|-|-|-|-|-|-|-|-|**11**|11|DynamicQuantizeLinear|
|
||||
|Einsum|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|Einsum|
|
||||
|Elu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Elu|
|
||||
|Equal|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**11**:small_orange_diamond:|11:small_orange_diamond:|Equal|
|
||||
|Erf|-|-|-|-|-|-|-|-|**9**|9|9|9|Erf|
|
||||
|Exp|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Exp|
|
||||
|Expand|-|-|-|-|-|-|-|**8**|8|8|8|8|Expand|
|
||||
|EyeLike|-|-|-|-|-|-|-|-|**9**|9|9|9|EyeLike|
|
||||
|Flatten|**1**|1|1|1|1|1|1|1|**9**|9|**11**|11|Flatten|
|
||||
|Floor|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Floor|
|
||||
|GRU|**1**:small_orange_diamond:|1:small_orange_diamond:|**3**:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|GRU|
|
||||
|Gather|**1**|1|1|1|1|1|1|1|1|1|**11**|11|Gather|
|
||||
|GatherElements|-|-|-|-|-|-|-|-|-|-|**11**|11|GatherElements|
|
||||
|GatherND|-|-|-|-|-|-|-|-|-|-|**11**|**12**:small_red_triangle:|GatherND|
|
||||
|Gemm|**1**|1|1|1|1|**6**|**7**|7|**9**|9|**11**|11|Gemm|
|
||||
|GlobalAveragePool|**1**|1|1|1|1|1|1|1|1|1|1|1|GlobalAveragePool|
|
||||
|GlobalLpPool|**1**|**2**|2|2|2|2|2|2|2|2|2|2|GlobalLpPool|
|
||||
|GlobalMaxPool|**1**|1|1|1|1|1|1|1|1|1|1|1|GlobalMaxPool|
|
||||
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|Greater|
|
||||
|GreaterOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|GreaterOrEqual|
|
||||
|HardSigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|6|HardSigmoid|
|
||||
|Hardmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|Hardmax|
|
||||
|Identity|**1**|1|1|1|1|1|1|1|1|1|1|1|Identity|
|
||||
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|If|
|
||||
|InstanceNormalization|**1**|1|1|1|1|**6**|6|6|6|6|6|6|InstanceNormalization|
|
||||
|IsInf|-|-|-|-|-|-|-|-|-|**10**|10|10|IsInf|
|
||||
|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|9|IsNaN|
|
||||
|LRN|**1**|1|1|1|1|1|1|1|1|1|1|1|LRN|
|
||||
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|LSTM|
|
||||
|LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|LeakyRelu|
|
||||
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|Less|
|
||||
|LessOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|LessOrEqual|
|
||||
|Log|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Log|
|
||||
|LogSoftmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|LogSoftmax|
|
||||
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|11|Loop|
|
||||
|LpNormalization|**1**|1|1|1|1|1|1|1|1|1|1|1|LpNormalization|
|
||||
|LpPool|**1**|**2**|2|2|2|2|2|2|2|2|**11**|11|LpPool|
|
||||
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|9|MatMul|
|
||||
|MatMulInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|MatMulInteger|
|
||||
|Max|**1**|1|1|1|1|**6**|6|**8**|8|8|8|**12**:small_red_triangle:|Max|
|
||||
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_orange_diamond:|**11**:small_orange_diamond:|**12**:small_orange_diamond:|MaxPool|
|
||||
|MaxRoiPool|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|MaxRoiPool|
|
||||
|MaxUnpool|-|-|-|-|-|-|-|-|**9**|9|**11**|11|MaxUnpool|
|
||||
|Mean|**1**|1|1|1|1|**6**|6|**8**|8|8|8|8|Mean|
|
||||
|MeanVarianceNormalization|-|-|-|-|-|-|-|-|**9**|9|9|9|MeanVarianceNormalization|
|
||||
|Min|**1**|1|1|1|1|**6**|6|**8**|8|8|8|**12**:small_red_triangle:|Min|
|
||||
|Mod|-|-|-|-|-|-|-|-|-|**10**:small_orange_diamond:|10:small_orange_diamond:|10:small_orange_diamond:|Mod|
|
||||
|Mul|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|Mul|
|
||||
|Multinomial|-|-|-|-|-|-|**7**:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|Multinomial|
|
||||
|Neg|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Neg|
|
||||
|NegativeLogLikelihoodLoss|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|NegativeLogLikelihoodLoss|
|
||||
|NonMaxSuppression|-|-|-|-|-|-|-|-|-|**10**|**11**|11|NonMaxSuppression|
|
||||
|NonZero|-|-|-|-|-|-|-|-|**9**|9|9|9|NonZero|
|
||||
|Not|**1**|1|1|1|1|1|1|1|1|1|1|1|Not|
|
||||
|OneHot|-|-|-|-|-|-|-|-|**9**:small_orange_diamond:|9:small_orange_diamond:|**11**:small_orange_diamond:|11:small_orange_diamond:|OneHot|
|
||||
|Or|**1**|1|1|1|1|1|**7**|7|7|7|7|7|Or|
|
||||
|PRelu|**1**|1|1|1|1|**6**|**7**|7|**9**|9|9|9|PRelu|
|
||||
|Pad|**1**|**2**|2|2|2|2|2|2|2|2|**11**|11|Pad|
|
||||
|Pow|**1**|1|1|1|1|1|**7**|7|7|7|7|**12**:small_red_triangle:|Pow|
|
||||
|QLinearConv|-|-|-|-|-|-|-|-|-|**10**|10|10|QLinearConv|
|
||||
|QLinearMatMul|-|-|-|-|-|-|-|-|-|**10**|10|10|QLinearMatMul|
|
||||
|QuantizeLinear|-|-|-|-|-|-|-|-|-|**10**|10|10|QuantizeLinear|
|
||||
|RNN|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|RNN|
|
||||
|RandomNormal|**1**|1|1|1|1|1|1|1|1|1|1|1|RandomNormal|
|
||||
|RandomNormalLike|**1**|1|1|1|1|1|1|1|1|1|1|1|RandomNormalLike|
|
||||
|RandomUniform|**1**|1|1|1|1|1|1|1|1|1|1|1|RandomUniform|
|
||||
|RandomUniformLike|**1**|1|1|1|1|1|1|1|1|1|1|1|RandomUniformLike|
|
||||
|Range|-|-|-|-|-|-|-|-|-|-|**11**|11|Range|
|
||||
|Reciprocal|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Reciprocal|
|
||||
|ReduceL1|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceL1|
|
||||
|ReduceL2|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceL2|
|
||||
|ReduceLogSum|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceLogSum|
|
||||
|ReduceLogSumExp|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceLogSumExp|
|
||||
|ReduceMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|ReduceMax|
|
||||
|ReduceMean|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceMean|
|
||||
|ReduceMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|ReduceMin|
|
||||
|ReduceProd|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceProd|
|
||||
|ReduceSum|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceSum|
|
||||
|ReduceSumSquare|**1**|1|1|1|1|1|1|1|1|1|**11**|11|ReduceSumSquare|
|
||||
|Relu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Relu|
|
||||
|Reshape|**1**|1|1|1|**5**|5|5|5|5|5|5|5|Reshape|
|
||||
|Resize|-|-|-|-|-|-|-|-|-|**10**:small_orange_diamond:|**11**:small_orange_diamond:|11:small_orange_diamond:|Resize|
|
||||
|ReverseSequence|-|-|-|-|-|-|-|-|-|**10**|10|10|ReverseSequence|
|
||||
|RoiAlign|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|RoiAlign|
|
||||
|Round|-|-|-|-|-|-|-|-|-|-|**11**|11|Round|
|
||||
|Scan|-|-|-|-|-|-|-|**8**|**9**|9|**11**|11|Scan|
|
||||
|Scatter|-|-|-|-|-|-|-|-|**9**|9|**11**\*|11\*|Scatter|
|
||||
|ScatterElements|-|-|-|-|-|-|-|-|-|-|**11**|11|ScatterElements|
|
||||
|ScatterND|-|-|-|-|-|-|-|-|-|-|**11**|11|ScatterND|
|
||||
|Selu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Selu|
|
||||
|SequenceAt|-|-|-|-|-|-|-|-|-|-|**11**|11|SequenceAt|
|
||||
|SequenceConstruct|-|-|-|-|-|-|-|-|-|-|**11**|11|SequenceConstruct|
|
||||
|SequenceEmpty|-|-|-|-|-|-|-|-|-|-|**11**|11|SequenceEmpty|
|
||||
|SequenceErase|-|-|-|-|-|-|-|-|-|-|**11**|11|SequenceErase|
|
||||
|SequenceInsert|-|-|-|-|-|-|-|-|-|-|**11**|11|SequenceInsert|
|
||||
|SequenceLength|-|-|-|-|-|-|-|-|-|-|**11**|11|SequenceLength|
|
||||
|Shape|**1**|1|1|1|1|1|1|1|1|1|1|1|Shape|
|
||||
|Shrink|-|-|-|-|-|-|-|-|**9**|9|9|9|Shrink|
|
||||
|Sigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Sigmoid|
|
||||
|Sign|-|-|-|-|-|-|-|-|**9**|9|9|9|Sign|
|
||||
|Sin|-|-|-|-|-|-|**7**|7|7|7|7|7|Sin|
|
||||
|Sinh|-|-|-|-|-|-|-|-|**9**|9|9|9|Sinh|
|
||||
|Size|**1**|1|1|1|1|1|1|1|1|1|1|1|Size|
|
||||
|Slice|**1**|1|1|1|1|1|1|1|1|**10**|**11**|11|Slice|
|
||||
|Softmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|Softmax|
|
||||
|SoftmaxCrossEntropyLoss|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|SoftmaxCrossEntropyLoss|
|
||||
|Softplus|**1**|1|1|1|1|1|1|1|1|1|1|1|Softplus|
|
||||
|Softsign|**1**|1|1|1|1|1|1|1|1|1|1|1|Softsign|
|
||||
|SpaceToDepth|**1**|1|1|1|1|1|1|1|1|1|1|1|SpaceToDepth|
|
||||
|Split|**1**|**2**|2|2|2|2|2|2|2|2|**11**|11|Split|
|
||||
|SplitToSequence|-|-|-|-|-|-|-|-|-|-|**11**:small_red_triangle:|11:small_red_triangle:|SplitToSequence|
|
||||
|Sqrt|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Sqrt|
|
||||
|Squeeze|**1**|1|1|1|1|1|1|1|1|1|**11**|11|Squeeze|
|
||||
|StringNormalizer|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|StringNormalizer|
|
||||
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|Sub|
|
||||
|Sum|**1**|1|1|1|1|**6**|6|**8**|8|8|8|8|Sum|
|
||||
|Tan|-|-|-|-|-|-|**7**|7|7|7|7|7|Tan|
|
||||
|Tanh|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Tanh|
|
||||
|TfIdfVectorizer|-|-|-|-|-|-|-|-|**9**|9|9|9|TfIdfVectorizer|
|
||||
|ThresholdedRelu|-|-|-|-|-|-|-|-|-|**10**|10|10|ThresholdedRelu|
|
||||
|Tile|**1**|1|1|1|1|**6**|6|6|6|6|6|6|Tile|
|
||||
|TopK|**1**|1|1|1|1|1|1|1|1|**10**|**11**|11|TopK|
|
||||
|Transpose|**1**|1|1|1|1|1|1|1|1|1|1|1|Transpose|
|
||||
|Unique|-|-|-|-|-|-|-|-|-|-|**11**:small_red_triangle:|11:small_red_triangle:|Unique|
|
||||
|Unsqueeze|**1**|1|1|1|1|1|1|1|1|1|**11**|11|Unsqueeze|
|
||||
|Upsample|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|**7**:small_orange_diamond:|7:small_orange_diamond:|**9**:small_orange_diamond:|**10**\*|10\*|10\*|Upsample|
|
||||
|Where|-|-|-|-|-|-|-|-|**9**|9|9|9|Where|
|
||||
|Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|7|Xor|
|
||||
|
||||
ONNX-TF Supported Operators / ONNX Operators: 144 / 162
|
||||
|
||||
Notes:
|
||||
1. Cast: Cast string to float32/float64/int32/int64 are not supported in Tensorflow.
|
||||
2. Clip: Clip input in uint64 is not supported in Tensorflow.
|
||||
3. ConvTranspose: ConvTranspose with dilations != 1, or transposed convolution for 4D or higher are not supported in Tensorflow.
|
||||
4. CumSum: CumSum inputs in uint32/uint64 are not supported in Tensorflow.
|
||||
5. Equal: Equal inputs in uint16/uint32/uint64 are not supported in Tensorflow.
|
||||
6. GRU: GRU with clip or GRU with linear_before_reset, or GRU not using sigmoid for z and r, or GRU using Elu as the activation function with alpha != 1, or GRU using HardSigmoid as the activation function with alpha != 0.2 or beta != 0.5 are not supported in TensorFlow.
|
||||
7. LSTM: LSTM not using sigmoid for `f`, or LSTM not using the same activation for `g` and `h` are not supported in Tensorflow.
|
||||
8. MaxPool: MaxPoolWithArgmax with pad is None or incompatible mode, or MaxPoolWithArgmax with 4D or higher input, orMaxPoolWithArgmax with column major are not supported in Tensorflow.
|
||||
9. Mod: Mod Dividend or Divisor in int8/int16/uint8/uint16/uint32/uint64 are not supported in Tensorflow.
|
||||
10. OneHot: OneHot indices in uint16/uint32/uint64/int8/int16/float16/float/double, or OneHot depth in uint8/uint16/uint32/uint64/int8/int16/int64/float16/float/double are not supported in Tensorflow.
|
||||
11. RNN: RNN with clip is not supported in Tensorflow.
|
||||
12. Resize: Resize required 4D input in Tensorflow. For opset 11, only the following attributes and inputs conbination are supported in Tensorflow:
|
||||
1. mode=nearest, coordinate_transformation_mode=align_corners, nearest_mode=round_prefer_ceil, can use scales(*) or sizes.
|
||||
2. mode=nearest, coordinate_transformation_mode=asymmetric, nearest_mode=floor, can use scales(*) or sizes.
|
||||
3. mode=nearest, coordinate_transformation_mode=tf_half_pixel_for_nn, nearest_mode=floor, can use scales(*) or sizes.
|
||||
4. mode=linear, coordinate_transformation_mode=align_corners, can use scales(*) or sizes.
|
||||
5. mode=linear, coordinate_transformation_mode=asymmetric, can use scales(*) or sizes.
|
||||
6. mode=linear, coordinate_transformation_mode=half_pixel, can use scales(*) or sizes.
|
||||
7. mode=cubic, coordinate_transformation_mode=align_corners, cubic_coeff_a=-0.5, exclude_outside=1, can use scales(*) or sizes.
|
||||
8. mode=cubic, coordinate_transformation_mode=asymmetric, cubic_coeff_a=-0.5, exclude_outside=1, can use scales(*) or sizes.
|
||||
9. mode=cubic, coordinate_transformation_mode=half_pixel, cubic_coeff_a=-0.5, exclude_outside=1, can use scales(*) or sizes.
|
||||
10. mode=nearest, coordinate_transformation_mode=tf_crop_and_resize, extrapolation_value=any_float_value, nearest_mode=round_prefer_ceil, can use scales or sizes.
|
||||
11. mode=linear, coordinate_transformation_mode=tf_crop_and_resize, extrapolation_value=any_float_value, can use scales or sizes.
|
||||
- Note (*): The accuracy of your model will go down, if the height and the width of the new sizes(scales * origial sizes) are not in whole numbers.
|
||||
13. Upsample: Upsample required 4D input in Tensorflow.
|
||||
@@ -0,0 +1,170 @@
|
||||
# ONNX-Tensorflow Support Status
|
||||
|||
|
||||
|-:|:-|
|
||||
|ONNX-Tensorflow Version|v1.5.0|
|
||||
|ONNX Version|v1.5.0|
|
||||
|Tensorflow Version|v1.15.0|
|
||||
|
||||
Notes:
|
||||
* Values that are new or updated from a previous opset version are in bold.
|
||||
* -: not defined in corresponding ONNX opset version
|
||||
* \*: the operator is deprecated
|
||||
* :small_red_triangle:: not supported yet
|
||||
* :small_orange_diamond:: partially supported
|
||||
* the rest are all supported
|
||||
|
||||
||||||||||||
|
||||
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|
||||
|**ONNX Operator**|**Opset 1**|**Opset 2**|**Opset 3**|**Opset 4**|**Opset 5**|**Opset 6**|**Opset 7**|**Opset 8**|**Opset 9**|**Opset 10**|
|
||||
|Abs|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Acos|-|-|-|-|-|-|**7**|7|7|7|
|
||||
|Acosh|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Add|**1**|1|1|1|1|**6**|**7**|7|7|7|
|
||||
|And|**1**|1|1|1|1|1|**7**|7|7|7|
|
||||
|ArgMax|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ArgMin|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Asin|-|-|-|-|-|-|**7**|7|7|7|
|
||||
|Asinh|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Atan|-|-|-|-|-|-|**7**|7|7|7|
|
||||
|Atanh|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|AveragePool|**1**|1|1|1|1|1|**7**|7|7|**10**|
|
||||
|BatchNormalization|**1**|1|1|1|1|**6**|**7**|7|**9**|9|
|
||||
|Cast|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**9**:small_orange_diamond:|9:small_orange_diamond:|
|
||||
|Ceil|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Clip|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Compress|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|
|
||||
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|
|
||||
|ConstantOfShape|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Conv|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ConvInteger|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|ConvTranspose|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|
|
||||
|Cos|-|-|-|-|-|-|**7**|7|7|7|
|
||||
|Cosh|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|DepthToSpace|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|DequantizeLinear|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|Div|**1**|1|1|1|1|**6**|**7**|7|7|7|
|
||||
|Dropout|**1**|1|1|1|1|**6**|**7**|7|7|**10**|
|
||||
|Elu|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Equal|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|
|
||||
|Erf|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Exp|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Expand|-|-|-|-|-|-|-|**8**|8|8|
|
||||
|EyeLike|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Flatten|**1**|1|1|1|1|1|1|1|**9**|9|
|
||||
|Floor|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|GRU|**1**:small_orange_diamond:|1:small_orange_diamond:|**3**:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|
|
||||
|Gather|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Gemm|**1**|1|1|1|1|**6**|**7**|7|**9**|9|
|
||||
|GlobalAveragePool|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|GlobalLpPool|**1**|**2**|2|2|2|2|2|2|2|2|
|
||||
|GlobalMaxPool|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|
|
||||
|HardSigmoid|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Hardmax|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Identity|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|If|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|
|
||||
|InstanceNormalization|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|IsInf|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|IsNaN|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|LRN|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|
|
||||
|LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|
|
||||
|Log|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|LogSoftmax|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Loop|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|
|
||||
|LpNormalization|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|LpPool|**1**:small_red_triangle:|**2**:small_red_triangle:|2:small_red_triangle:|2:small_red_triangle:|2:small_red_triangle:|2:small_red_triangle:|2:small_red_triangle:|2:small_red_triangle:|2:small_red_triangle:|2:small_red_triangle:|
|
||||
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|
|
||||
|MatMulInteger|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|Max|**1**|1|1|1|1|**6**|6|**8**|8|8|
|
||||
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_orange_diamond:|
|
||||
|MaxRoiPool|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|
|
||||
|MaxUnpool|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Mean|**1**|1|1|1|1|**6**|6|**8**|8|8|
|
||||
|MeanVarianceNormalization|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Min|**1**|1|1|1|1|**6**|6|**8**|8|8|
|
||||
|Mod|-|-|-|-|-|-|-|-|-|**10**:small_orange_diamond:|
|
||||
|Mul|**1**|1|1|1|1|**6**|**7**|7|7|7|
|
||||
|Multinomial|-|-|-|-|-|-|**7**:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|
|
||||
|Neg|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|NonMaxSuppression|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|NonZero|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Not|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|OneHot|-|-|-|-|-|-|-|-|**9**:small_orange_diamond:|9:small_orange_diamond:|
|
||||
|Or|**1**|1|1|1|1|1|**7**|7|7|7|
|
||||
|PRelu|**1**|1|1|1|1|**6**|**7**|7|**9**|9|
|
||||
|Pad|**1**|**2**|2|2|2|2|2|2|2|2|
|
||||
|Pow|**1**|1|1|1|1|1|**7**|7|7|7|
|
||||
|QLinearConv|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|QLinearMatMul|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|QuantizeLinear|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|RNN|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|
|
||||
|RandomNormal|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|RandomNormalLike|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|RandomUniform|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|RandomUniformLike|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Reciprocal|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|ReduceL1|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceL2|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceLogSum|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceLogSumExp|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceMax|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceMean|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceMin|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceProd|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceSum|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|ReduceSumSquare|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Relu|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Reshape|**1**|1|1|1|**5**|5|5|5|5|5|
|
||||
|Resize|-|-|-|-|-|-|-|-|-|**10**:small_orange_diamond:|
|
||||
|ReverseSequence|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|RoiAlign|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|
|
||||
|Scan|-|-|-|-|-|-|-|**8**|**9**|9|
|
||||
|Scatter|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Selu|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Shape|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Shrink|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Sigmoid|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Sign|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Sin|-|-|-|-|-|-|**7**|7|7|7|
|
||||
|Sinh|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Size|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Slice|**1**|1|1|1|1|1|1|1|1|**10**|
|
||||
|Softmax|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Softplus|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Softsign|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|SpaceToDepth|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Split|**1**|**2**|2|2|2|2|2|2|2|2|
|
||||
|Sqrt|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|Squeeze|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|StringNormalizer|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|
|
||||
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|
|
||||
|Sum|**1**|1|1|1|1|**6**|6|**8**|8|8|
|
||||
|Tan|-|-|-|-|-|-|**7**|7|7|7|
|
||||
|Tanh|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|TfIdfVectorizer|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|ThresholdedRelu|-|-|-|-|-|-|-|-|-|**10**|
|
||||
|Tile|**1**|1|1|1|1|**6**|6|6|6|6|
|
||||
|TopK|**1**|1|1|1|1|1|1|1|1|**10**|
|
||||
|Transpose|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Unsqueeze|**1**|1|1|1|1|1|1|1|1|1|
|
||||
|Upsample|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|**7**:small_orange_diamond:|7:small_orange_diamond:|**9**:small_orange_diamond:|**10**\*|
|
||||
|Where|-|-|-|-|-|-|-|-|**9**|9|
|
||||
|Xor|**1**|1|1|1|1|1|**7**|7|7|7|
|
||||
|
||||
ONNX-TF Supported Operators / ONNX Operators: 130 / 137
|
||||
|
||||
Notes:
|
||||
1. Cast: Cast string to float32/float64/int32/int64 are not supported in Tensorflow.
|
||||
2. ConvTranspose: ConvTranspose with dilations != 1, or transposed convolution for 4D or higher are not supported in Tensorflow.
|
||||
3. Equal: Equal inputs in uint16/uint32/uint64 are not supported in Tensorflow.
|
||||
4. GRU: GRU with clip or GRU with linear_before_reset, or GRU not using sigmoid for z and r, or GRU using Elu as the activation function with alpha != 1, or GRU using HardSigmoid as the activation function with alpha != 0.2 or beta != 0.5 are not supported in TensorFlow.
|
||||
5. LSTM: LSTM not using sigmoid for `f`, or LSTM not using the same activation for `g` and `h` are not supported in Tensorflow.
|
||||
6. MaxPool: MaxPoolWithArgmax with pad is None or incompatible mode, or MaxPoolWithArgmax with 4D or higher input, orMaxPoolWithArgmax with column major are not supported in Tensorflow.
|
||||
7. Mod: Mod Dividend or Divisor in int8/int16/uint8/uint16/uint32/uint64 are not supported in Tensorflow.
|
||||
8. OneHot: OneHot indices in uint16/uint32/uint64/int8/int16/float16/float/double, or OneHot depth in uint8/uint16/uint32/uint64/int8/int16/int64/float16/float/double are not supported in Tensorflow.
|
||||
9. RNN: RNN with clip is not supported in Tensorflow.
|
||||
10. Resize: Resize required 4D input in Tensorflow.
|
||||
11. Upsample: Upsample required 4D input in Tensorflow.
|
||||
@@ -0,0 +1,203 @@
|
||||
# ONNX-Tensorflow Support Status
|
||||
|||
|
||||
|-:|:-|
|
||||
|ONNX-Tensorflow Version|v1.6.0|
|
||||
|ONNX Version|v1.6.0|
|
||||
|Tensorflow Version|v1.15.0|
|
||||
|
||||
Notes:
|
||||
* Values that are new or updated from a previous opset version are in bold.
|
||||
* -: not defined in corresponding ONNX opset version
|
||||
* \*: the operator is deprecated
|
||||
* :small_red_triangle:: not supported yet
|
||||
* :small_orange_diamond:: partially supported
|
||||
* the rest are all supported
|
||||
|
||||
||||||||||||||
|
||||
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|
||||
|**ONNX Operator**|**Opset 1**|**Opset 2**|**Opset 3**|**Opset 4**|**Opset 5**|**Opset 6**|**Opset 7**|**Opset 8**|**Opset 9**|**Opset 10**|**Opset 11**|**ONNX Operator**|
|
||||
|Abs|**1**|1|1|1|1|**6**|6|6|6|6|6|Abs|
|
||||
|Acos|-|-|-|-|-|-|**7**|7|7|7|7|Acos|
|
||||
|Acosh|-|-|-|-|-|-|-|-|**9**|9|9|Acosh|
|
||||
|Add|**1**|1|1|1|1|**6**|**7**|7|7|7|7|Add|
|
||||
|And|**1**|1|1|1|1|1|**7**|7|7|7|7|And|
|
||||
|ArgMax|**1**|1|1|1|1|1|1|1|1|1|**11**|ArgMax|
|
||||
|ArgMin|**1**|1|1|1|1|1|1|1|1|1|**11**|ArgMin|
|
||||
|Asin|-|-|-|-|-|-|**7**|7|7|7|7|Asin|
|
||||
|Asinh|-|-|-|-|-|-|-|-|**9**|9|9|Asinh|
|
||||
|Atan|-|-|-|-|-|-|**7**|7|7|7|7|Atan|
|
||||
|Atanh|-|-|-|-|-|-|-|-|**9**|9|9|Atanh|
|
||||
|AveragePool|**1**|1|1|1|1|1|**7**|7|7|**10**|**11**|AveragePool|
|
||||
|BatchNormalization|**1**|1|1|1|1|**6**|**7**|7|**9**|9|9|BatchNormalization|
|
||||
|BitShift|-|-|-|-|-|-|-|-|-|-|**11**|BitShift|
|
||||
|Cast|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**9**:small_orange_diamond:|9:small_orange_diamond:|9:small_orange_diamond:|Cast|
|
||||
|Ceil|**1**|1|1|1|1|**6**|6|6|6|6|6|Ceil|
|
||||
|Clip|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**6**:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|6:small_orange_diamond:|**11**:small_orange_diamond:|Clip|
|
||||
|Compress|-|-|-|-|-|-|-|-|**9**|9|**11**|Compress|
|
||||
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|Concat|
|
||||
|ConcatFromSequence|-|-|-|-|-|-|-|-|-|-|**11**:small_red_triangle:|ConcatFromSequence|
|
||||
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|Constant|
|
||||
|ConstantOfShape|-|-|-|-|-|-|-|-|**9**|9|9|ConstantOfShape|
|
||||
|Conv|**1**|1|1|1|1|1|1|1|1|1|**11**|Conv|
|
||||
|ConvInteger|-|-|-|-|-|-|-|-|-|**10**|10|ConvInteger|
|
||||
|ConvTranspose|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**11**:small_orange_diamond:|ConvTranspose|
|
||||
|Cos|-|-|-|-|-|-|**7**|7|7|7|7|Cos|
|
||||
|Cosh|-|-|-|-|-|-|-|-|**9**|9|9|Cosh|
|
||||
|CumSum|-|-|-|-|-|-|-|-|-|-|**11**:small_orange_diamond:|CumSum|
|
||||
|DepthToSpace|**1**|1|1|1|1|1|1|1|1|1|**11**|DepthToSpace|
|
||||
|DequantizeLinear|-|-|-|-|-|-|-|-|-|**10**|10|DequantizeLinear|
|
||||
|Det|-|-|-|-|-|-|-|-|-|-|**11**|Det|
|
||||
|Div|**1**|1|1|1|1|**6**|**7**|7|7|7|7|Div|
|
||||
|Dropout|**1**|1|1|1|1|**6**|**7**|7|7|**10**|10|Dropout|
|
||||
|DynamicQuantizeLinear|-|-|-|-|-|-|-|-|-|-|**11**|DynamicQuantizeLinear|
|
||||
|Elu|**1**|1|1|1|1|**6**|6|6|6|6|6|Elu|
|
||||
|Equal|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**11**:small_orange_diamond:|Equal|
|
||||
|Erf|-|-|-|-|-|-|-|-|**9**|9|9|Erf|
|
||||
|Exp|**1**|1|1|1|1|**6**|6|6|6|6|6|Exp|
|
||||
|Expand|-|-|-|-|-|-|-|**8**|8|8|8|Expand|
|
||||
|EyeLike|-|-|-|-|-|-|-|-|**9**|9|9|EyeLike|
|
||||
|Flatten|**1**|1|1|1|1|1|1|1|**9**|9|**11**|Flatten|
|
||||
|Floor|**1**|1|1|1|1|**6**|6|6|6|6|6|Floor|
|
||||
|GRU|**1**:small_orange_diamond:|1:small_orange_diamond:|**3**:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|GRU|
|
||||
|Gather|**1**|1|1|1|1|1|1|1|1|1|**11**|Gather|
|
||||
|GatherElements|-|-|-|-|-|-|-|-|-|-|**11**|GatherElements|
|
||||
|GatherND|-|-|-|-|-|-|-|-|-|-|**11**|GatherND|
|
||||
|Gemm|**1**|1|1|1|1|**6**|**7**|7|**9**|9|**11**|Gemm|
|
||||
|GlobalAveragePool|**1**|1|1|1|1|1|1|1|1|1|1|GlobalAveragePool|
|
||||
|GlobalLpPool|**1**|**2**|2|2|2|2|2|2|2|2|2|GlobalLpPool|
|
||||
|GlobalMaxPool|**1**|1|1|1|1|1|1|1|1|1|1|GlobalMaxPool|
|
||||
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|Greater|
|
||||
|HardSigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|HardSigmoid|
|
||||
|Hardmax|**1**|1|1|1|1|1|1|1|1|1|**11**|Hardmax|
|
||||
|Identity|**1**|1|1|1|1|1|1|1|1|1|1|Identity|
|
||||
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|If|
|
||||
|InstanceNormalization|**1**|1|1|1|1|**6**|6|6|6|6|6|InstanceNormalization|
|
||||
|IsInf|-|-|-|-|-|-|-|-|-|**10**|10|IsInf|
|
||||
|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|IsNaN|
|
||||
|LRN|**1**|1|1|1|1|1|1|1|1|1|1|LRN|
|
||||
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|LSTM|
|
||||
|LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|6|LeakyRelu|
|
||||
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|Less|
|
||||
|Log|**1**|1|1|1|1|**6**|6|6|6|6|6|Log|
|
||||
|LogSoftmax|**1**|1|1|1|1|1|1|1|1|1|**11**|LogSoftmax|
|
||||
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|Loop|
|
||||
|LpNormalization|**1**|1|1|1|1|1|1|1|1|1|1|LpNormalization|
|
||||
|LpPool|**1**|**2**|2|2|2|2|2|2|2|2|**11**|LpPool|
|
||||
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|MatMul|
|
||||
|MatMulInteger|-|-|-|-|-|-|-|-|-|**10**|10|MatMulInteger|
|
||||
|Max|**1**|1|1|1|1|**6**|6|**8**|8|8|8|Max|
|
||||
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_orange_diamond:|**11**:small_orange_diamond:|MaxPool|
|
||||
|MaxRoiPool|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|MaxRoiPool|
|
||||
|MaxUnpool|-|-|-|-|-|-|-|-|**9**|9|**11**|MaxUnpool|
|
||||
|Mean|**1**|1|1|1|1|**6**|6|**8**|8|8|8|Mean|
|
||||
|MeanVarianceNormalization|-|-|-|-|-|-|-|-|**9**|9|9|MeanVarianceNormalization|
|
||||
|Min|**1**|1|1|1|1|**6**|6|**8**|8|8|8|Min|
|
||||
|Mod|-|-|-|-|-|-|-|-|-|**10**:small_orange_diamond:|10:small_orange_diamond:|Mod|
|
||||
|Mul|**1**|1|1|1|1|**6**|**7**|7|7|7|7|Mul|
|
||||
|Multinomial|-|-|-|-|-|-|**7**:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|Multinomial|
|
||||
|Neg|**1**|1|1|1|1|**6**|6|6|6|6|6|Neg|
|
||||
|NonMaxSuppression|-|-|-|-|-|-|-|-|-|**10**|**11**|NonMaxSuppression|
|
||||
|NonZero|-|-|-|-|-|-|-|-|**9**|9|9|NonZero|
|
||||
|Not|**1**|1|1|1|1|1|1|1|1|1|1|Not|
|
||||
|OneHot|-|-|-|-|-|-|-|-|**9**:small_orange_diamond:|9:small_orange_diamond:|**11**:small_orange_diamond:|OneHot|
|
||||
|Or|**1**|1|1|1|1|1|**7**|7|7|7|7|Or|
|
||||
|PRelu|**1**|1|1|1|1|**6**|**7**|7|**9**|9|9|PRelu|
|
||||
|Pad|**1**|**2**|2|2|2|2|2|2|2|2|**11**|Pad|
|
||||
|Pow|**1**|1|1|1|1|1|**7**|7|7|7|7|Pow|
|
||||
|QLinearConv|-|-|-|-|-|-|-|-|-|**10**|10|QLinearConv|
|
||||
|QLinearMatMul|-|-|-|-|-|-|-|-|-|**10**|10|QLinearMatMul|
|
||||
|QuantizeLinear|-|-|-|-|-|-|-|-|-|**10**|10|QuantizeLinear|
|
||||
|RNN|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|RNN|
|
||||
|RandomNormal|**1**|1|1|1|1|1|1|1|1|1|1|RandomNormal|
|
||||
|RandomNormalLike|**1**|1|1|1|1|1|1|1|1|1|1|RandomNormalLike|
|
||||
|RandomUniform|**1**|1|1|1|1|1|1|1|1|1|1|RandomUniform|
|
||||
|RandomUniformLike|**1**|1|1|1|1|1|1|1|1|1|1|RandomUniformLike|
|
||||
|Range|-|-|-|-|-|-|-|-|-|-|**11**|Range|
|
||||
|Reciprocal|**1**|1|1|1|1|**6**|6|6|6|6|6|Reciprocal|
|
||||
|ReduceL1|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceL1|
|
||||
|ReduceL2|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceL2|
|
||||
|ReduceLogSum|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceLogSum|
|
||||
|ReduceLogSumExp|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceLogSumExp|
|
||||
|ReduceMax|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceMax|
|
||||
|ReduceMean|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceMean|
|
||||
|ReduceMin|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceMin|
|
||||
|ReduceProd|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceProd|
|
||||
|ReduceSum|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceSum|
|
||||
|ReduceSumSquare|**1**|1|1|1|1|1|1|1|1|1|**11**|ReduceSumSquare|
|
||||
|Relu|**1**|1|1|1|1|**6**|6|6|6|6|6|Relu|
|
||||
|Reshape|**1**|1|1|1|**5**|5|5|5|5|5|5|Reshape|
|
||||
|Resize|-|-|-|-|-|-|-|-|-|**10**:small_orange_diamond:|**11**:small_orange_diamond:|Resize|
|
||||
|ReverseSequence|-|-|-|-|-|-|-|-|-|**10**|10|ReverseSequence|
|
||||
|RoiAlign|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|RoiAlign|
|
||||
|Round|-|-|-|-|-|-|-|-|-|-|**11**|Round|
|
||||
|Scan|-|-|-|-|-|-|-|**8**|**9**|9|**11**|Scan|
|
||||
|Scatter|-|-|-|-|-|-|-|-|**9**|9|**11**\*|Scatter|
|
||||
|ScatterElements|-|-|-|-|-|-|-|-|-|-|**11**|ScatterElements|
|
||||
|ScatterND|-|-|-|-|-|-|-|-|-|-|**11**|ScatterND|
|
||||
|Selu|**1**|1|1|1|1|**6**|6|6|6|6|6|Selu|
|
||||
|SequenceAt|-|-|-|-|-|-|-|-|-|-|**11**|SequenceAt|
|
||||
|SequenceConstruct|-|-|-|-|-|-|-|-|-|-|**11**|SequenceConstruct|
|
||||
|SequenceEmpty|-|-|-|-|-|-|-|-|-|-|**11**|SequenceEmpty|
|
||||
|SequenceErase|-|-|-|-|-|-|-|-|-|-|**11**|SequenceErase|
|
||||
|SequenceInsert|-|-|-|-|-|-|-|-|-|-|**11**|SequenceInsert|
|
||||
|SequenceLength|-|-|-|-|-|-|-|-|-|-|**11**|SequenceLength|
|
||||
|Shape|**1**|1|1|1|1|1|1|1|1|1|1|Shape|
|
||||
|Shrink|-|-|-|-|-|-|-|-|**9**|9|9|Shrink|
|
||||
|Sigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|Sigmoid|
|
||||
|Sign|-|-|-|-|-|-|-|-|**9**|9|9|Sign|
|
||||
|Sin|-|-|-|-|-|-|**7**|7|7|7|7|Sin|
|
||||
|Sinh|-|-|-|-|-|-|-|-|**9**|9|9|Sinh|
|
||||
|Size|**1**|1|1|1|1|1|1|1|1|1|1|Size|
|
||||
|Slice|**1**|1|1|1|1|1|1|1|1|**10**|**11**|Slice|
|
||||
|Softmax|**1**|1|1|1|1|1|1|1|1|1|**11**|Softmax|
|
||||
|Softplus|**1**|1|1|1|1|1|1|1|1|1|1|Softplus|
|
||||
|Softsign|**1**|1|1|1|1|1|1|1|1|1|1|Softsign|
|
||||
|SpaceToDepth|**1**|1|1|1|1|1|1|1|1|1|1|SpaceToDepth|
|
||||
|Split|**1**|**2**|2|2|2|2|2|2|2|2|**11**|Split|
|
||||
|SplitToSequence|-|-|-|-|-|-|-|-|-|-|**11**:small_red_triangle:|SplitToSequence|
|
||||
|Sqrt|**1**|1|1|1|1|**6**|6|6|6|6|6|Sqrt|
|
||||
|Squeeze|**1**|1|1|1|1|1|1|1|1|1|**11**|Squeeze|
|
||||
|StringNormalizer|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|StringNormalizer|
|
||||
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|Sub|
|
||||
|Sum|**1**|1|1|1|1|**6**|6|**8**|8|8|8|Sum|
|
||||
|Tan|-|-|-|-|-|-|**7**|7|7|7|7|Tan|
|
||||
|Tanh|**1**|1|1|1|1|**6**|6|6|6|6|6|Tanh|
|
||||
|TfIdfVectorizer|-|-|-|-|-|-|-|-|**9**|9|9|TfIdfVectorizer|
|
||||
|ThresholdedRelu|-|-|-|-|-|-|-|-|-|**10**|10|ThresholdedRelu|
|
||||
|Tile|**1**|1|1|1|1|**6**|6|6|6|6|6|Tile|
|
||||
|TopK|**1**|1|1|1|1|1|1|1|1|**10**|**11**|TopK|
|
||||
|Transpose|**1**|1|1|1|1|1|1|1|1|1|1|Transpose|
|
||||
|Unique|-|-|-|-|-|-|-|-|-|-|**11**:small_red_triangle:|Unique|
|
||||
|Unsqueeze|**1**|1|1|1|1|1|1|1|1|1|**11**|Unsqueeze|
|
||||
|Upsample|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|**7**:small_orange_diamond:|7:small_orange_diamond:|**9**:small_orange_diamond:|**10**\*|10\*|Upsample|
|
||||
|Where|-|-|-|-|-|-|-|-|**9**|9|9|Where|
|
||||
|Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|Xor|
|
||||
|
||||
ONNX-TF Supported Operators / ONNX Operators: 149 / 156
|
||||
|
||||
Notes:
|
||||
1. Cast: Cast string to float32/float64/int32/int64 are not supported in Tensorflow.
|
||||
2. Clip: Clip input in uint64 is not supported in Tensorflow.
|
||||
3. ConvTranspose: ConvTranspose with dilations != 1, or transposed convolution for 4D or higher are not supported in Tensorflow.
|
||||
4. CumSum: CumSum inputs in uint32/uint64 are not supported in Tensorflow.
|
||||
5. Equal: Equal inputs in uint16/uint32/uint64 are not supported in Tensorflow.
|
||||
6. GRU: GRU with clip or GRU with linear_before_reset, or GRU not using sigmoid for z and r, or GRU using Elu as the activation function with alpha != 1, or GRU using HardSigmoid as the activation function with alpha != 0.2 or beta != 0.5 are not supported in TensorFlow.
|
||||
7. LSTM: LSTM not using sigmoid for `f`, or LSTM not using the same activation for `g` and `h` are not supported in Tensorflow.
|
||||
8. MaxPool: MaxPoolWithArgmax with pad is None or incompatible mode, or MaxPoolWithArgmax with 4D or higher input, orMaxPoolWithArgmax with column major are not supported in Tensorflow.
|
||||
9. Mod: Mod Dividend or Divisor in int8/int16/uint8/uint16/uint32/uint64 are not supported in Tensorflow.
|
||||
10. OneHot: OneHot indices in uint16/uint32/uint64/int8/int16/float16/float/double, or OneHot depth in uint8/uint16/uint32/uint64/int8/int16/int64/float16/float/double are not supported in Tensorflow.
|
||||
11. RNN: RNN with clip is not supported in Tensorflow.
|
||||
12. Resize: Resize required 4D input in Tensorflow. For opset 11, only the following attributes and inputs conbination are supported in Tensorflow:
|
||||
1. mode=nearest, coordinate_transformation_mode=align_corners, nearest_mode=round_prefer_ceil, can use scales(*) or sizes.
|
||||
2. mode=nearest, coordinate_transformation_mode=asymmetric, nearest_mode=floor, can use scales(*) or sizes.
|
||||
3. mode=nearest, coordinate_transformation_mode=tf_half_pixel_for_nn, nearest_mode=floor, can use scales(*) or sizes.
|
||||
4. mode=linear, coordinate_transformation_mode=align_corners, can use scales(*) or sizes.
|
||||
5. mode=linear, coordinate_transformation_mode=asymmetric, can use scales(*) or sizes.
|
||||
6. mode=linear, coordinate_transformation_mode=half_pixel, can use scales(*) or sizes.
|
||||
7. mode=cubic, coordinate_transformation_mode=align_corners, cubic_coeff_a=-0.5, exclude_outside=1, can use scales(*) or sizes.
|
||||
8. mode=cubic, coordinate_transformation_mode=asymmetric, cubic_coeff_a=-0.5, exclude_outside=1, can use scales(*) or sizes.
|
||||
9. mode=cubic, coordinate_transformation_mode=half_pixel, cubic_coeff_a=-0.5, exclude_outside=1, can use scales(*) or sizes.
|
||||
10. mode=nearest, coordinate_transformation_mode=tf_crop_and_resize, extrapolation_value=any_float_value, nearest_mode=round_prefer_ceil, can use scales or sizes.
|
||||
11. mode=linear, coordinate_transformation_mode=tf_crop_and_resize, extrapolation_value=any_float_value, can use scales or sizes.
|
||||
- Note (*): The accuracy of your model will go down, if the height and the width of the new sizes(scales * origial sizes) are not in whole numbers.
|
||||
13. Upsample: Upsample required 4D input in Tensorflow.
|
||||
@@ -0,0 +1,7 @@
|
||||
import onnx
|
||||
|
||||
from onnx_tf.backend import prepare
|
||||
|
||||
onnx_model = onnx.load("input_path") # load onnx model
|
||||
tf_rep = prepare(onnx_model) # prepare tf representation
|
||||
tf_rep.export_graph("output_path") # export the model
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from onnx_tf.backend import run_node, prepare
|
||||
from onnx import helper
|
||||
|
||||
node_def = helper.make_node("Relu", ["X"], ["Y"])
|
||||
output = run_node(node_def, [[-0.1, 0.1]])
|
||||
print(output["Y"])
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import caffe2.python.onnx.backend as c2
|
||||
import onnx
|
||||
import onnx_tf.backend as tf
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
|
||||
def find_between(s, first, last):
|
||||
try:
|
||||
start = s.index(first) + len(first)
|
||||
end = s.index(last, start)
|
||||
return s[start:end]
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
|
||||
class TestLargeModel(unittest.TestCase):
|
||||
MODEL_PATH = "../../onnx_models/"
|
||||
|
||||
def test(self):
|
||||
_model = onnx.load(self.MODEL_PATH + "shufflenet/model.onnx")
|
||||
node_count = len(_model.graph.node)
|
||||
more_outputs = []
|
||||
output_to_check = []
|
||||
for node in _model.graph.node:
|
||||
more_outputs.append(
|
||||
helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT,
|
||||
(100, 100)))
|
||||
output_to_check.append(node.output[0])
|
||||
_model.graph.output.extend(more_outputs)
|
||||
|
||||
tf_rep = tf.prepare(_model)
|
||||
cf_rep = c2.prepare(_model)
|
||||
|
||||
sample = np.load(
|
||||
self.MODEL_PATH + "shufflenet/test_data_{}.npz".format(str(1)),
|
||||
encoding='bytes')
|
||||
inputs = list(sample['inputs'])
|
||||
outputs = list(sample['outputs'])
|
||||
|
||||
my_out = tf_rep.run(inputs)
|
||||
cf_out = cf_rep.run(inputs)
|
||||
|
||||
for op in output_to_check:
|
||||
try:
|
||||
np.savetxt(
|
||||
op.replace("/", "__") + ".cf", cf_out[op].flatten(), delimiter='\t')
|
||||
np.savetxt(
|
||||
op.replace("/", "__") + ".tf", my_out[op].flatten(), delimiter='\t')
|
||||
np.testing.assert_allclose(my_out[op], cf_out[op], rtol=1e-2)
|
||||
print(op, "results of this layer are correct within tolerence.")
|
||||
except Exception as e:
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
mismatch_percent = (find_between(str(e), "(mismatch", "%)"))
|
||||
print(op, "mismatch with percentage {} %".format(mismatch_percent))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
pass
|
||||
@@ -0,0 +1,12 @@
|
||||
Metadata-Version: 1.1
|
||||
Name: onnx-tf
|
||||
Version: 1.6.0
|
||||
Summary: Tensorflow backend for ONNX (Open Neural Network Exchange).
|
||||
Home-page: https://github.com/onnx/onnx-tensorflow/
|
||||
Author: Arpith Jacob, Tian Jin, Gheorghe-Teodor Bercea, Wenhao Hu
|
||||
Author-email: tian.jin1@ibm.com
|
||||
License: Apache License 2.0
|
||||
Description: UNKNOWN
|
||||
Platform: UNKNOWN
|
||||
Classifier: Programming Language :: Python :: 2
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
@@ -0,0 +1,209 @@
|
||||
LICENSE
|
||||
MANIFEST.in
|
||||
ONNX_VERSION_NUMBER
|
||||
README.md
|
||||
VERSION_NUMBER
|
||||
setup.cfg
|
||||
setup.py
|
||||
onnx_tf/__init__.py
|
||||
onnx_tf/backend.py
|
||||
onnx_tf/backend_rep.py
|
||||
onnx_tf/cli.py
|
||||
onnx_tf/converter.py
|
||||
onnx_tf/gen_doc.py
|
||||
onnx_tf/gen_opset.py
|
||||
onnx_tf/gen_status.py
|
||||
onnx_tf/opset_version.py
|
||||
onnx_tf/pb_wrapper.py
|
||||
onnx_tf/version.py
|
||||
onnx_tf.egg-info/PKG-INFO
|
||||
onnx_tf.egg-info/SOURCES.txt
|
||||
onnx_tf.egg-info/dependency_links.txt
|
||||
onnx_tf.egg-info/entry_points.txt
|
||||
onnx_tf.egg-info/not-zip-safe
|
||||
onnx_tf.egg-info/requires.txt
|
||||
onnx_tf.egg-info/top_level.txt
|
||||
onnx_tf/common/__init__.py
|
||||
onnx_tf/common/attr_converter.py
|
||||
onnx_tf/common/attr_translator.py
|
||||
onnx_tf/common/data_type.py
|
||||
onnx_tf/common/exception.py
|
||||
onnx_tf/common/handler_helper.py
|
||||
onnx_tf/common/legacy.py
|
||||
onnx_tf/common/pooling_helper.py
|
||||
onnx_tf/common/tf_helper.py
|
||||
onnx_tf/handlers/__init__.py
|
||||
onnx_tf/handlers/backend_handler.py
|
||||
onnx_tf/handlers/handler.py
|
||||
onnx_tf/handlers/backend/__init__.py
|
||||
onnx_tf/handlers/backend/abs.py
|
||||
onnx_tf/handlers/backend/acos.py
|
||||
onnx_tf/handlers/backend/acosh.py
|
||||
onnx_tf/handlers/backend/add.py
|
||||
onnx_tf/handlers/backend/and.py
|
||||
onnx_tf/handlers/backend/arg_max.py
|
||||
onnx_tf/handlers/backend/arg_min.py
|
||||
onnx_tf/handlers/backend/asin.py
|
||||
onnx_tf/handlers/backend/asinh.py
|
||||
onnx_tf/handlers/backend/atan.py
|
||||
onnx_tf/handlers/backend/atanh.py
|
||||
onnx_tf/handlers/backend/average_pool.py
|
||||
onnx_tf/handlers/backend/batch_normalization.py
|
||||
onnx_tf/handlers/backend/bitshift.py
|
||||
onnx_tf/handlers/backend/broadcast_mixin.py
|
||||
onnx_tf/handlers/backend/cast.py
|
||||
onnx_tf/handlers/backend/ceil.py
|
||||
onnx_tf/handlers/backend/clip.py
|
||||
onnx_tf/handlers/backend/compress.py
|
||||
onnx_tf/handlers/backend/concat.py
|
||||
onnx_tf/handlers/backend/constant.py
|
||||
onnx_tf/handlers/backend/constant_fill.py
|
||||
onnx_tf/handlers/backend/constant_of_shape.py
|
||||
onnx_tf/handlers/backend/control_flow_mixin.py
|
||||
onnx_tf/handlers/backend/conv.py
|
||||
onnx_tf/handlers/backend/conv_integer.py
|
||||
onnx_tf/handlers/backend/conv_mixin.py
|
||||
onnx_tf/handlers/backend/conv_transpose.py
|
||||
onnx_tf/handlers/backend/cos.py
|
||||
onnx_tf/handlers/backend/cosh.py
|
||||
onnx_tf/handlers/backend/cumsum.py
|
||||
onnx_tf/handlers/backend/depth_to_space.py
|
||||
onnx_tf/handlers/backend/dequantize_linear.py
|
||||
onnx_tf/handlers/backend/det.py
|
||||
onnx_tf/handlers/backend/dilated_pooling.py
|
||||
onnx_tf/handlers/backend/div.py
|
||||
onnx_tf/handlers/backend/dropout.py
|
||||
onnx_tf/handlers/backend/dynamic_quantize_linear.py
|
||||
onnx_tf/handlers/backend/elu.py
|
||||
onnx_tf/handlers/backend/equal.py
|
||||
onnx_tf/handlers/backend/erf.py
|
||||
onnx_tf/handlers/backend/exp.py
|
||||
onnx_tf/handlers/backend/expand.py
|
||||
onnx_tf/handlers/backend/eye_like.py
|
||||
onnx_tf/handlers/backend/flatten.py
|
||||
onnx_tf/handlers/backend/floor.py
|
||||
onnx_tf/handlers/backend/gather.py
|
||||
onnx_tf/handlers/backend/gather_and_scatter_mixin.py
|
||||
onnx_tf/handlers/backend/gather_elements.py
|
||||
onnx_tf/handlers/backend/gather_nd.py
|
||||
onnx_tf/handlers/backend/gemm.py
|
||||
onnx_tf/handlers/backend/global_average_pool.py
|
||||
onnx_tf/handlers/backend/global_lp_pool.py
|
||||
onnx_tf/handlers/backend/global_max_pool.py
|
||||
onnx_tf/handlers/backend/greater.py
|
||||
onnx_tf/handlers/backend/gru.py
|
||||
onnx_tf/handlers/backend/hard_sigmoid.py
|
||||
onnx_tf/handlers/backend/hardmax.py
|
||||
onnx_tf/handlers/backend/identity.py
|
||||
onnx_tf/handlers/backend/if.py
|
||||
onnx_tf/handlers/backend/image_scaler.py
|
||||
onnx_tf/handlers/backend/instance_normalization.py
|
||||
onnx_tf/handlers/backend/is_inf.py
|
||||
onnx_tf/handlers/backend/is_nan.py
|
||||
onnx_tf/handlers/backend/leaky_relu.py
|
||||
onnx_tf/handlers/backend/less.py
|
||||
onnx_tf/handlers/backend/log.py
|
||||
onnx_tf/handlers/backend/log_softmax.py
|
||||
onnx_tf/handlers/backend/loop.py
|
||||
onnx_tf/handlers/backend/lp_normalization.py
|
||||
onnx_tf/handlers/backend/lp_pool.py
|
||||
onnx_tf/handlers/backend/lrn.py
|
||||
onnx_tf/handlers/backend/lstm.py
|
||||
onnx_tf/handlers/backend/mat_mul.py
|
||||
onnx_tf/handlers/backend/mat_mul_integer.py
|
||||
onnx_tf/handlers/backend/math_mixin.py
|
||||
onnx_tf/handlers/backend/max.py
|
||||
onnx_tf/handlers/backend/max_pool.py
|
||||
onnx_tf/handlers/backend/max_unpool.py
|
||||
onnx_tf/handlers/backend/mean.py
|
||||
onnx_tf/handlers/backend/mean_variance_normalization.py
|
||||
onnx_tf/handlers/backend/min.py
|
||||
onnx_tf/handlers/backend/mod.py
|
||||
onnx_tf/handlers/backend/mul.py
|
||||
onnx_tf/handlers/backend/neg.py
|
||||
onnx_tf/handlers/backend/non_max_suppression.py
|
||||
onnx_tf/handlers/backend/non_zero.py
|
||||
onnx_tf/handlers/backend/not.py
|
||||
onnx_tf/handlers/backend/onehot.py
|
||||
onnx_tf/handlers/backend/or.py
|
||||
onnx_tf/handlers/backend/p_relu.py
|
||||
onnx_tf/handlers/backend/pad.py
|
||||
onnx_tf/handlers/backend/pad_mixin.py
|
||||
onnx_tf/handlers/backend/pool_mixin.py
|
||||
onnx_tf/handlers/backend/pow.py
|
||||
onnx_tf/handlers/backend/q_linear_conv.py
|
||||
onnx_tf/handlers/backend/q_linear_mat_mul.py
|
||||
onnx_tf/handlers/backend/quantize_linear.py
|
||||
onnx_tf/handlers/backend/random_normal.py
|
||||
onnx_tf/handlers/backend/random_normal_like.py
|
||||
onnx_tf/handlers/backend/random_uniform.py
|
||||
onnx_tf/handlers/backend/random_uniform_like.py
|
||||
onnx_tf/handlers/backend/range.py
|
||||
onnx_tf/handlers/backend/reciprocal.py
|
||||
onnx_tf/handlers/backend/reduce_l1.py
|
||||
onnx_tf/handlers/backend/reduce_l2.py
|
||||
onnx_tf/handlers/backend/reduce_log_sum.py
|
||||
onnx_tf/handlers/backend/reduce_log_sum_exp.py
|
||||
onnx_tf/handlers/backend/reduce_max.py
|
||||
onnx_tf/handlers/backend/reduce_mean.py
|
||||
onnx_tf/handlers/backend/reduce_min.py
|
||||
onnx_tf/handlers/backend/reduce_prod.py
|
||||
onnx_tf/handlers/backend/reduce_sum.py
|
||||
onnx_tf/handlers/backend/reduce_sum_square.py
|
||||
onnx_tf/handlers/backend/relu.py
|
||||
onnx_tf/handlers/backend/reshape.py
|
||||
onnx_tf/handlers/backend/resize.py
|
||||
onnx_tf/handlers/backend/reverse_sequence.py
|
||||
onnx_tf/handlers/backend/rnn.py
|
||||
onnx_tf/handlers/backend/rnn_mixin.py
|
||||
onnx_tf/handlers/backend/round.py
|
||||
onnx_tf/handlers/backend/scan.py
|
||||
onnx_tf/handlers/backend/scan_mixin.py
|
||||
onnx_tf/handlers/backend/scatter.py
|
||||
onnx_tf/handlers/backend/scatter_elements.py
|
||||
onnx_tf/handlers/backend/scatter_nd.py
|
||||
onnx_tf/handlers/backend/selu.py
|
||||
onnx_tf/handlers/backend/sequence_at.py
|
||||
onnx_tf/handlers/backend/sequence_construct.py
|
||||
onnx_tf/handlers/backend/sequence_empty.py
|
||||
onnx_tf/handlers/backend/sequence_erase.py
|
||||
onnx_tf/handlers/backend/sequence_insert.py
|
||||
onnx_tf/handlers/backend/sequence_length.py
|
||||
onnx_tf/handlers/backend/shape.py
|
||||
onnx_tf/handlers/backend/shrink.py
|
||||
onnx_tf/handlers/backend/sigmoid.py
|
||||
onnx_tf/handlers/backend/sign.py
|
||||
onnx_tf/handlers/backend/sin.py
|
||||
onnx_tf/handlers/backend/sinh.py
|
||||
onnx_tf/handlers/backend/size.py
|
||||
onnx_tf/handlers/backend/slice.py
|
||||
onnx_tf/handlers/backend/softmax.py
|
||||
onnx_tf/handlers/backend/softplus.py
|
||||
onnx_tf/handlers/backend/softsign.py
|
||||
onnx_tf/handlers/backend/space_to_depth.py
|
||||
onnx_tf/handlers/backend/split.py
|
||||
onnx_tf/handlers/backend/sqrt.py
|
||||
onnx_tf/handlers/backend/squeeze.py
|
||||
onnx_tf/handlers/backend/sub.py
|
||||
onnx_tf/handlers/backend/sum.py
|
||||
onnx_tf/handlers/backend/tan.py
|
||||
onnx_tf/handlers/backend/tanh.py
|
||||
onnx_tf/handlers/backend/tfidf_vectorizer.py
|
||||
onnx_tf/handlers/backend/thresholded_relu.py
|
||||
onnx_tf/handlers/backend/tile.py
|
||||
onnx_tf/handlers/backend/top_k.py
|
||||
onnx_tf/handlers/backend/transpose.py
|
||||
onnx_tf/handlers/backend/unpool_mixin.py
|
||||
onnx_tf/handlers/backend/unsqueeze.py
|
||||
onnx_tf/handlers/backend/upsample.py
|
||||
onnx_tf/handlers/backend/where.py
|
||||
onnx_tf/handlers/backend/xor.py
|
||||
test/__init__.py
|
||||
test/test_cli.py
|
||||
test/backend/__init__.py
|
||||
test/backend/test_dynamic_shape.py
|
||||
test/backend/test_model.py
|
||||
test/backend/test_node.py
|
||||
test/backend/test_onnx_backend.py
|
||||
third_party/__init__.py
|
||||
third_party/get_info.py
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
[console_scripts]
|
||||
onnx-tf = onnx_tf.cli:main
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
onnx>=1.6.0
|
||||
PyYAML
|
||||
@@ -0,0 +1,3 @@
|
||||
onnx_tf
|
||||
test
|
||||
third_party
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import backend
|
||||
from .version import version as __version__
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,352 @@
|
||||
"""Backend for running ONNX on Tensorflow
|
||||
|
||||
To run this, you will need to have Tensorflow installed as well.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
try:
|
||||
from itertools import izip as zip
|
||||
except ImportError: # will be 3.x series
|
||||
pass
|
||||
|
||||
from onnx import defs
|
||||
from onnx import numpy_helper
|
||||
from onnx.backend.base import Backend
|
||||
from onnx.backend.base import Device
|
||||
from onnx.backend.base import namedtupledict
|
||||
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
|
||||
from onnx.helper import make_opsetid
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.backend_rep import TensorflowRep
|
||||
from onnx_tf.common import data_type
|
||||
from onnx_tf.common import get_device_option
|
||||
from onnx_tf.common import get_unique_suffix
|
||||
from onnx_tf.common import supports_device as common_supports_device
|
||||
from onnx_tf.common.handler_helper import get_all_backend_handlers
|
||||
from onnx_tf.pb_wrapper import OnnxNode
|
||||
import onnx_tf.common as common
|
||||
|
||||
|
||||
class TensorflowBackend(Backend):
|
||||
""" Tensorflow Backend for ONNX
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def prepare(cls,
|
||||
model,
|
||||
device='CPU',
|
||||
strict=True,
|
||||
logging_level='INFO',
|
||||
**kwargs):
|
||||
"""Prepare an ONNX model for Tensorflow Backend.
|
||||
|
||||
This function converts an ONNX model to an internel representation
|
||||
of the computational graph called TensorflowRep and returns
|
||||
the converted representation.
|
||||
|
||||
:param model: The ONNX model to be converted.
|
||||
:param device: The device to execute this model on.
|
||||
:param strict: Whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
|
||||
Changing to False is strongly discouraged.
|
||||
Currently, the strict flag only affects the behavior of MaxPool and AveragePool ops.
|
||||
:param logging_level: The logging level, default is INFO. Change it to DEBUG
|
||||
to see more conversion details or to WARNING to see less
|
||||
|
||||
:returns: A TensorflowRep class object representing the ONNX model
|
||||
"""
|
||||
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
|
||||
common.logger.setLevel(logging_level)
|
||||
common.logger.handlers[0].setLevel(logging_level)
|
||||
|
||||
return cls.onnx_model_to_tensorflow_rep(model, strict)
|
||||
|
||||
@classmethod
|
||||
def onnx_model_to_tensorflow_rep(cls, model, strict):
|
||||
""" Convert ONNX model to TensorflowRep.
|
||||
|
||||
:param model: ONNX ModelProto object.
|
||||
:param strict: whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model.
|
||||
:return: TensorflowRep object.
|
||||
"""
|
||||
|
||||
# Models with IR_VERSION less than 3 does not have opset_import set.
|
||||
# We default to minimum opset, this behavior is consistent with
|
||||
# onnx checker.
|
||||
# c.f. https://github.com/onnx/onnx/blob/427ac0c1b792363d373e3d7e4eef97fa46458420/onnx/checker.cc#L478
|
||||
if model.ir_version < 3:
|
||||
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
|
||||
else:
|
||||
opset_import = model.opset_import
|
||||
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
|
||||
|
||||
@classmethod
|
||||
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
|
||||
""" Convert ONNX graph to TensorflowRep.
|
||||
|
||||
:param graph_def: ONNX GraphProto object.
|
||||
:param opset: ONNX OperatorSetIdProto list.
|
||||
:param strict: whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model.
|
||||
:return: TensorflowRep object.
|
||||
"""
|
||||
handlers = cls._get_handlers(opset)
|
||||
|
||||
tf_rep_graph = tf.Graph()
|
||||
with tf_rep_graph.as_default():
|
||||
# initializer: TensorProtos representing the values to initialize
|
||||
# a given tensor.
|
||||
# initialized: A list of names of the initialized tensors.
|
||||
if graph_def.initializer:
|
||||
input_dict_items = cls._onnx_initializer_to_input_dict_items(
|
||||
graph_def.initializer)
|
||||
initialized = {init.name for init in graph_def.initializer}
|
||||
else:
|
||||
input_dict_items = []
|
||||
initialized = set()
|
||||
|
||||
# creating placeholders for currently unknown inputs
|
||||
for value_info in graph_def.input:
|
||||
if value_info.name in initialized:
|
||||
continue
|
||||
shape = list(
|
||||
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
|
||||
for d in value_info.type.tensor_type.shape.dim)
|
||||
value_info_name = value_info.name.replace(
|
||||
":", "_tf_") + "_" + get_unique_suffix(
|
||||
) if ":" in value_info.name else value_info.name
|
||||
|
||||
x = tf.placeholder(data_type.onnx2tf(
|
||||
value_info.type.tensor_type.elem_type),
|
||||
name=value_info_name,
|
||||
shape=shape)
|
||||
input_dict_items.append((value_info.name, x))
|
||||
|
||||
# tensor dict: this dictionary is a map from variable names
|
||||
# to the latest produced TF tensors of the given name.
|
||||
# This dictionary will get updated as we build the graph to
|
||||
# record the names of newly produced tensors.
|
||||
tensor_dict = dict(input_dict_items)
|
||||
# Since tensor dict may be updated, we need to keep a copy
|
||||
# of the original input dict where we track the earliest
|
||||
# defined tensors so we can have access to the placeholders
|
||||
# to feed in input tensors when we run the graph.
|
||||
input_dict = dict(input_dict_items)
|
||||
|
||||
for node in graph_def.node:
|
||||
onnx_node = OnnxNode(node)
|
||||
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
|
||||
tensor_dict,
|
||||
handlers,
|
||||
opset=opset,
|
||||
strict=strict)
|
||||
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
|
||||
tensor_dict.update(curr_node_output_map)
|
||||
|
||||
tf_rep = TensorflowRep()
|
||||
tf_rep.graph = tf_rep_graph
|
||||
tf_rep.inputs = [
|
||||
value_info.name
|
||||
for value_info in graph_def.input
|
||||
if value_info.name not in initialized
|
||||
]
|
||||
tf_rep.outputs = [value_info.name for value_info in graph_def.output]
|
||||
tf_rep.tensor_dict = tensor_dict
|
||||
return tf_rep
|
||||
|
||||
@classmethod
|
||||
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
|
||||
""" Run ONNX node.
|
||||
|
||||
:param node: ONNX NodeProto object.
|
||||
:param inputs: Inputs.
|
||||
:param device: Device run on.
|
||||
:param outputs_info: None.
|
||||
:param kwargs: Other args.
|
||||
:return: Outputs.
|
||||
"""
|
||||
super(TensorflowBackend, cls).run_node(node, inputs, device)
|
||||
node_graph = tf.Graph()
|
||||
with node_graph.as_default():
|
||||
node = OnnxNode(node)
|
||||
device_option = get_device_option(Device(device))
|
||||
input_tensors = []
|
||||
for i in inputs:
|
||||
input_tensors.append(tf.constant(i))
|
||||
|
||||
if isinstance(inputs, dict):
|
||||
feed_dict_raw = inputs
|
||||
else:
|
||||
assert len(node.inputs) == len(inputs)
|
||||
feed_dict_raw = dict(zip(node.inputs, inputs))
|
||||
|
||||
# TODO: is constant the best way for feeding inputs?
|
||||
input_dict = dict([
|
||||
(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
|
||||
])
|
||||
ops = cls._onnx_node_to_tensorflow_op(node, input_dict)
|
||||
|
||||
with tf.Session() as sess:
|
||||
with tf.device(device_option):
|
||||
sess.run(tf.global_variables_initializer())
|
||||
output_vals = sess.run(ops)
|
||||
|
||||
return namedtupledict('Outputs', node.outputs)(*output_vals)
|
||||
|
||||
@classmethod
|
||||
def _onnx_initializer_to_input_dict_items(cls, initializer):
|
||||
""" Convert ONNX graph initializer to input dict items.
|
||||
|
||||
:param initializer: ONNX graph initializer, list of TensorProto.
|
||||
:return: List of input dict items.
|
||||
"""
|
||||
|
||||
def tensor2list(onnx_tensor):
|
||||
# Use the onnx.numpy_helper because the data may be raw
|
||||
return numpy_helper.to_array(onnx_tensor).flatten().tolist()
|
||||
|
||||
def validate_initializer_name(name):
|
||||
# Prepend a unique suffix if leading charater is "_"
|
||||
name = get_unique_suffix() + name if name[0] is "_" else name
|
||||
|
||||
# Replace ":" with "_tf_" and append a unique suffix for
|
||||
# traceability
|
||||
return name.replace(
|
||||
":", "_tf_") + "_" + get_unique_suffix() if ":" in name else name
|
||||
|
||||
return [(init.name,
|
||||
tf.constant(tensor2list(init),
|
||||
shape=init.dims,
|
||||
dtype=data_type.onnx2tf(init.data_type),
|
||||
name=validate_initializer_name(init.name)))
|
||||
for init in initializer]
|
||||
|
||||
@classmethod
|
||||
def _onnx_node_to_tensorflow_op(cls,
|
||||
node,
|
||||
tensor_dict,
|
||||
handlers=None,
|
||||
opset=None,
|
||||
strict=True):
|
||||
"""
|
||||
Convert onnx node to tensorflow op.
|
||||
|
||||
Args:
|
||||
node: Onnx node object.
|
||||
tensor_dict: Tensor dict of graph.
|
||||
opset: Opset version of the operator set. Default 0 means using latest version.
|
||||
strict: whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
|
||||
Changing to False is strongly discouraged.
|
||||
Returns:
|
||||
Tensorflow op
|
||||
"""
|
||||
handlers = handlers or cls._get_handlers(opset)
|
||||
if handlers:
|
||||
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
|
||||
if handler:
|
||||
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
|
||||
|
||||
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
|
||||
|
||||
@classmethod
|
||||
def _get_handlers(cls, opset):
|
||||
""" Get all backend handlers with opset.
|
||||
|
||||
:param opset: ONNX OperatorSetIdProto list.
|
||||
:return: All backend handlers.
|
||||
"""
|
||||
opset = opset or [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
|
||||
opset_dict = dict([(o.domain, o.version) for o in opset])
|
||||
return get_all_backend_handlers(opset_dict)
|
||||
|
||||
@classmethod
|
||||
def supports_device(cls, device):
|
||||
return common_supports_device(device)
|
||||
|
||||
@classmethod
|
||||
def onnx_graph_to_tensorflow_ops(cls,
|
||||
subgraph,
|
||||
input_values,
|
||||
tensor_dict,
|
||||
opset=None,
|
||||
strict=True):
|
||||
"""
|
||||
Converts ONNX graph to Tensorflow operations
|
||||
Args:
|
||||
subgraph: the ONNX graph to be converted
|
||||
input_values: dictionary with values/tensors to initialize
|
||||
the subgraph inputs. if the subgraph.input
|
||||
are send in as parameters then it is required,
|
||||
otherwise this can be empty dictionary.
|
||||
tensor_dict: the dictionary that contain values for all the
|
||||
node.inputs in the subgraph that are not defined
|
||||
in the subgraph or input_values.
|
||||
opset: opset version of the operator set.
|
||||
strict: whether to enforce semantic equivalence between the
|
||||
original model and the converted tensorflow model,
|
||||
defaults to True (yes, enforce semantic equivalence).
|
||||
Returns:
|
||||
array of Tensorflow Tensors
|
||||
"""
|
||||
# get the subgraph.input from input_values
|
||||
subgraph_tensor_dict = input_values.copy()
|
||||
# get the rest of the subgraph input from tensor_dict
|
||||
for i in subgraph.input:
|
||||
if i.name not in subgraph_tensor_dict.keys():
|
||||
subgraph_tensor_dict[i.name] = tensor_dict[i.name]
|
||||
# get the required initializer constant node(s) for the subgraph
|
||||
# Need to get the initializer constant nodes from tensor_dict here
|
||||
# because input from initializer will not be send in as inputs
|
||||
# to the subgraph and those nodes are not in the subgraph
|
||||
nodes_outputs = []
|
||||
for node in subgraph.node:
|
||||
for o_name in node.output:
|
||||
nodes_outputs.append(o_name)
|
||||
for node in subgraph.node:
|
||||
for i_name in node.input:
|
||||
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
|
||||
):
|
||||
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
|
||||
onnx_node = OnnxNode(node)
|
||||
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
|
||||
subgraph_tensor_dict,
|
||||
opset=opset,
|
||||
strict=strict)
|
||||
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
|
||||
subgraph_tensor_dict.update(curr_node_output_map)
|
||||
return subgraph_tensor_dict
|
||||
|
||||
@classmethod
|
||||
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
|
||||
"""
|
||||
Converts ONNX graph to TensorflowRep
|
||||
Args:
|
||||
graph_def: the ONNX graph to be converted
|
||||
strict: whether to enforce semantic equivalence between the
|
||||
original model and the converted tensorflow model,
|
||||
defaults to True (yes, enforce semantic equivalence).
|
||||
Returns:
|
||||
TensorflowRep object.
|
||||
"""
|
||||
# get the opset of the installed ONNX
|
||||
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
|
||||
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
|
||||
|
||||
|
||||
prepare = TensorflowBackend.prepare
|
||||
|
||||
run_node = TensorflowBackend.run_node
|
||||
|
||||
run_model = TensorflowBackend.run_model
|
||||
|
||||
supports_device = TensorflowBackend.supports_device
|
||||
|
||||
onnx_graph_to_tensorflow_ops = TensorflowBackend.onnx_graph_to_tensorflow_ops
|
||||
|
||||
onnx_graph_to_tensorflow_rep = TensorflowBackend.onnx_graph_to_tensorflow_rep
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx.backend.base import BackendRep, namedtupledict
|
||||
|
||||
|
||||
class TensorflowRep(BackendRep):
|
||||
|
||||
def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
|
||||
super(TensorflowRep, self).__init__()
|
||||
self._graph = graph
|
||||
self._inputs = inputs or []
|
||||
self._outputs = outputs or []
|
||||
self._tensor_dict = tensor_dict or {}
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
return self._graph
|
||||
|
||||
@graph.setter
|
||||
def graph(self, graph):
|
||||
self._graph = graph
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return self._inputs
|
||||
|
||||
@inputs.setter
|
||||
def inputs(self, inputs):
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return self._outputs
|
||||
|
||||
@outputs.setter
|
||||
def outputs(self, outputs):
|
||||
self._outputs = outputs
|
||||
|
||||
@property
|
||||
def tensor_dict(self):
|
||||
return self._tensor_dict
|
||||
|
||||
@tensor_dict.setter
|
||||
def tensor_dict(self, tensor_dict):
|
||||
self._tensor_dict = tensor_dict
|
||||
|
||||
def run(self, inputs, **kwargs):
|
||||
""" Run TensorflowRep.
|
||||
|
||||
:param inputs: Given inputs.
|
||||
:param kwargs: Other args.
|
||||
:return: Outputs.
|
||||
"""
|
||||
super(TensorflowRep, self).run(inputs, **kwargs)
|
||||
|
||||
# TODO: handle name scope if necessary
|
||||
with self.graph.as_default():
|
||||
with tf.Session() as sess:
|
||||
if isinstance(inputs, dict):
|
||||
feed_dict = inputs
|
||||
elif isinstance(inputs, list) or isinstance(inputs, tuple):
|
||||
if len(self.inputs) != len(inputs):
|
||||
raise RuntimeError('Expected {} values for uninitialized '
|
||||
'graph inputs ({}), but got {}.'.format(
|
||||
len(self.inputs), ', '.join(self.inputs),
|
||||
len(inputs)))
|
||||
feed_dict = dict(zip(self.inputs, inputs))
|
||||
else:
|
||||
# single input
|
||||
feed_dict = dict([(self.inputs[0], inputs)])
|
||||
|
||||
feed_dict = {
|
||||
self.tensor_dict[key]: feed_dict[key] for key in self.inputs
|
||||
}
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
outputs = [self.tensor_dict[output] for output in self.outputs]
|
||||
|
||||
output_values = sess.run(outputs, feed_dict=feed_dict)
|
||||
return namedtupledict('Outputs', self.outputs)(*output_values)
|
||||
|
||||
def export_graph(self, path):
|
||||
"""Export backend representation to a Tensorflow proto file.
|
||||
|
||||
This function obtains the graph proto corresponding to the ONNX
|
||||
model associated with the backend representation and serializes
|
||||
to a protobuf file.
|
||||
|
||||
:param path: The path to the output TF protobuf file.
|
||||
|
||||
:returns: none.
|
||||
"""
|
||||
graph_proto = self.graph.as_graph_def()
|
||||
# rename the output nodes
|
||||
meaningful_names = {}
|
||||
for output_name in self.outputs:
|
||||
meaningful_names[self.tensor_dict[output_name].name.replace(':0', '')] = output_name
|
||||
for node in graph_proto.node:
|
||||
if node.name in meaningful_names.keys():
|
||||
node.name = meaningful_names[node.name]
|
||||
|
||||
file = open(path, "wb")
|
||||
file.write(graph_proto.SerializeToString())
|
||||
file.close()
|
||||
@@ -0,0 +1,24 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import onnx_tf.converter
|
||||
|
||||
|
||||
def main():
|
||||
args = sys.argv[1:]
|
||||
parser = argparse.ArgumentParser(
|
||||
description="ONNX-Tensorflow Command Line Interface")
|
||||
parser.add_argument(
|
||||
"command",
|
||||
choices=["convert"],
|
||||
help="Available commands.")
|
||||
|
||||
if len(args) == 0:
|
||||
parser.parse_args(["-h"])
|
||||
cli_tool = parser.parse_args([args[0]])
|
||||
if cli_tool.command == "convert":
|
||||
return onnx_tf.converter.main(args[1:])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,196 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
import logging
|
||||
|
||||
from onnx.backend.base import DeviceType
|
||||
from tensorflow.python.client import device_lib
|
||||
|
||||
IS_PYTHON3 = sys.version_info > (3,)
|
||||
logger = logging.getLogger('onnx-tf')
|
||||
|
||||
# create console handler and formatter for logger
|
||||
console = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
console.setFormatter(formatter)
|
||||
logger.addHandler(console)
|
||||
|
||||
|
||||
class Deprecated:
|
||||
"""Add deprecated message when function is called.
|
||||
|
||||
Usage:
|
||||
from onnx_tf.common import deprecated
|
||||
|
||||
@deprecated
|
||||
def func():
|
||||
pass
|
||||
|
||||
UserWarning: func is deprecated. It will be removed in future release.
|
||||
|
||||
@deprecated("Message")
|
||||
def func():
|
||||
pass
|
||||
|
||||
UserWarning: Message
|
||||
|
||||
@deprecated({"arg": "Message",
|
||||
"arg_1": deprecated.MSG_WILL_REMOVE,
|
||||
"arg_2": "",})
|
||||
def func(arg, arg_1, arg_2):
|
||||
pass
|
||||
|
||||
UserWarning: Message
|
||||
UserWarning: arg_1 of func is deprecated. It will be removed in future release.
|
||||
UserWarning: arg_2 of func is deprecated.
|
||||
"""
|
||||
|
||||
MSG_WILL_REMOVE = " It will be removed in future release."
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.deprecated_decorator(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def messages():
|
||||
return {v for k, v in inspect.getmembers(Deprecated) if k.startswith("MSG")}
|
||||
|
||||
@staticmethod
|
||||
def deprecated_decorator(arg=None):
|
||||
# deprecate function with default message MSG_WILL_REMOVE
|
||||
# @deprecated
|
||||
if inspect.isfunction(arg):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
warnings.warn("{} is deprecated.{}".format(
|
||||
arg.__module__ + "." + arg.__name__, Deprecated.MSG_WILL_REMOVE))
|
||||
return arg(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE
|
||||
|
||||
def deco(func):
|
||||
# deprecate arg
|
||||
# @deprecated({...})
|
||||
if isinstance(deprecated_arg, dict):
|
||||
for name, message in deprecated_arg.items():
|
||||
if message in Deprecated.messages():
|
||||
message = "{} of {} is deprecated.{}".format(
|
||||
name, func.__module__ + "." + func.__name__, message or "")
|
||||
warnings.warn(message)
|
||||
# deprecate function with message
|
||||
# @deprecated("message")
|
||||
elif isinstance(deprecated_arg, str):
|
||||
message = deprecated_arg
|
||||
if message in Deprecated.messages():
|
||||
message = "{} is deprecated.{}".format(
|
||||
func.__module__ + "." + func.__name__, message)
|
||||
warnings.warn(message)
|
||||
return func
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
deprecated = Deprecated()
|
||||
|
||||
|
||||
# This function inserts an underscore before every upper
|
||||
# case letter and lowers that upper case letter except for
|
||||
# the first letter.
|
||||
def op_name_to_lower(name):
|
||||
return re.sub('(?<!^)(?=[A-Z])', '_', name).lower()
|
||||
|
||||
|
||||
def get_unique_suffix():
|
||||
""" Get unique suffix by using first 8 chars from uuid.uuid4
|
||||
to make unique identity name.
|
||||
|
||||
:return: Unique suffix string.
|
||||
"""
|
||||
return str(uuid.uuid4())[:8]
|
||||
|
||||
|
||||
def get_perm_from_formats(from_, to_):
|
||||
""" Get perm from data formats.
|
||||
For example:
|
||||
get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]
|
||||
|
||||
:param from_: From data format string.
|
||||
:param to_: To data format string.
|
||||
:return: Perm. Int list.
|
||||
"""
|
||||
return list(map(lambda x: from_.find(x), to_))
|
||||
|
||||
|
||||
# TODO: allow more flexible placement
|
||||
def get_device_option(device):
|
||||
m = {DeviceType.CPU: '/cpu', DeviceType.CUDA: '/gpu'}
|
||||
return m[device.type]
|
||||
|
||||
|
||||
def get_data_format(x_rank):
|
||||
""" Get data format by input rank.
|
||||
Channel first if support CUDA.
|
||||
|
||||
:param x_rank: Input rank.
|
||||
:return: Data format.
|
||||
"""
|
||||
sp_dim_names = ["D", "H", "W"]
|
||||
sp_dim_lst = []
|
||||
for i in range(x_rank - 2):
|
||||
sp_dim_lst.append(sp_dim_names[-i - 1])
|
||||
|
||||
sp_dim_string = "".join(reversed(sp_dim_lst))
|
||||
storage_format = "NC" + sp_dim_string
|
||||
|
||||
if supports_device("CUDA"):
|
||||
compute_format = "NC" + sp_dim_string
|
||||
else:
|
||||
compute_format = "N" + sp_dim_string + "C"
|
||||
return storage_format, compute_format
|
||||
|
||||
|
||||
def supports_device(device):
|
||||
""" Check if support target device.
|
||||
|
||||
:param device: CUDA or CPU.
|
||||
:return: If supports.
|
||||
"""
|
||||
if device == "CUDA":
|
||||
local_device_protos = device_lib.list_local_devices()
|
||||
return len([x.name for x in local_device_protos if x.device_type == 'GPU'
|
||||
]) > 0
|
||||
elif device == "CPU":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@deprecated("onnx_tf.common.get_outputs_names is deprecated.{} {}".format(
|
||||
deprecated.MSG_WILL_REMOVE,
|
||||
"Use TensorflowGraph.get_outputs_names instead."))
|
||||
def get_output_node_names(graph_def):
|
||||
"""Get output node names from GraphDef.
|
||||
Args:
|
||||
graph_def: GraphDef object.
|
||||
Returns:
|
||||
List of output node names.
|
||||
"""
|
||||
nodes, input_names = dict(), set()
|
||||
for node in graph_def.node:
|
||||
nodes[node.name] = node
|
||||
input_names.update(set(node.input))
|
||||
return list(set(nodes) - input_names)
|
||||
|
||||
|
||||
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
|
||||
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
|
||||
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
|
||||
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from onnx_tf.common import IS_PYTHON3
|
||||
|
||||
|
||||
def convert_tf(attr):
|
||||
return __convert_tf_attr_value(attr)
|
||||
|
||||
|
||||
def convert_onnx(attr):
|
||||
return __convert_onnx_attribute_proto(attr)
|
||||
|
||||
|
||||
def __convert_tf_attr_value(attr):
|
||||
""" convert Tensorflow AttrValue object to Python object
|
||||
"""
|
||||
if attr.HasField('list'):
|
||||
return __convert_tf_list_value(attr.list)
|
||||
if attr.HasField('s'):
|
||||
return attr.s
|
||||
elif attr.HasField('i'):
|
||||
return attr.i
|
||||
elif attr.HasField('f'):
|
||||
return attr.f
|
||||
elif attr.HasField('b'):
|
||||
return attr.b
|
||||
elif attr.HasField('type'):
|
||||
return attr.type
|
||||
elif attr.HasField('shape'):
|
||||
return attr.type
|
||||
elif attr.HasField('tensor'):
|
||||
return attr.tensor
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
|
||||
|
||||
|
||||
def __convert_tf_list_value(list_value):
|
||||
""" convert Tensorflow ListValue object to Python object
|
||||
"""
|
||||
if list_value.s:
|
||||
return list_value.s
|
||||
elif list_value.i:
|
||||
return list_value.i
|
||||
elif list_value.f:
|
||||
return list_value.f
|
||||
elif list_value.b:
|
||||
return list_value.b
|
||||
elif list_value.tensor:
|
||||
return list_value.tensor
|
||||
elif list_value.type:
|
||||
return list_value.type
|
||||
elif list_value.shape:
|
||||
return list_value.shape
|
||||
elif list_value.func:
|
||||
return list_value.func
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(list_value))
|
||||
|
||||
|
||||
def __convert_onnx_attribute_proto(attr_proto):
|
||||
"""
|
||||
Convert an ONNX AttributeProto into an appropriate Python object
|
||||
for the type.
|
||||
NB: Tensor attribute gets returned as the straight proto.
|
||||
"""
|
||||
if attr_proto.HasField('f'):
|
||||
return attr_proto.f
|
||||
elif attr_proto.HasField('i'):
|
||||
return attr_proto.i
|
||||
elif attr_proto.HasField('s'):
|
||||
return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
|
||||
elif attr_proto.HasField('t'):
|
||||
return attr_proto.t # this is a proto!
|
||||
elif attr_proto.HasField('g'):
|
||||
return attr_proto.g
|
||||
elif attr_proto.floats:
|
||||
return list(attr_proto.floats)
|
||||
elif attr_proto.ints:
|
||||
return list(attr_proto.ints)
|
||||
elif attr_proto.strings:
|
||||
str_list = list(attr_proto.strings)
|
||||
if IS_PYTHON3:
|
||||
str_list = list(map(lambda x: str(x, 'utf-8'), str_list))
|
||||
return str_list
|
||||
elif attr_proto.HasField('sparse_tensor'):
|
||||
return attr_proto.sparse_tensor
|
||||
else:
|
||||
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
|
||||
@@ -0,0 +1,37 @@
|
||||
from tensorflow.python.framework.tensor_util import MakeNdarray
|
||||
|
||||
from onnx_tf.common import data_type
|
||||
|
||||
# Keyed by old attribute names.
|
||||
__tf_attr_translator = {
|
||||
"_output_shapes": lambda x: list(map(lambda shape: get_tf_shape_as_list(shape.dim), x.list.shape)),
|
||||
"shape": lambda x: get_tf_shape_as_list(x.shape.dim),
|
||||
"T": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"dtype": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"component_types": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"value": lambda x: MakeNdarray(x.tensor),
|
||||
"seed2": lambda x: float(x.i),
|
||||
"seed": lambda x: float(x.i),
|
||||
"keep_dims": lambda x: int(x.b),
|
||||
"squeeze_dims": lambda x: list(x.list.i),
|
||||
}
|
||||
|
||||
__onnx_attr_translator = {
|
||||
"axis": lambda x: int(x),
|
||||
"axes": lambda x: [int(a) for a in x],
|
||||
"dtype": lambda x: data_type.onnx2tf(x),
|
||||
"keepdims": lambda x: bool(x),
|
||||
"to": lambda x: data_type.onnx2tf(x),
|
||||
}
|
||||
|
||||
|
||||
def translate_tf(key, val):
|
||||
return __tf_attr_translator.get(key, lambda x: x)(val)
|
||||
|
||||
|
||||
def translate_onnx(key, val):
|
||||
return __onnx_attr_translator.get(key, lambda x: x)(val)
|
||||
|
||||
|
||||
def get_tf_shape_as_list(tf_shape_dim):
|
||||
return list(map(lambda x: x.size, list(tf_shape_dim)))
|
||||
@@ -0,0 +1,71 @@
|
||||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
from onnx import mapping
|
||||
from onnx import TensorProto
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def tf2onnx(dtype):
|
||||
if isinstance(dtype, Number):
|
||||
tf_dype = tf.as_dtype(dtype)
|
||||
elif isinstance(dtype, tf.DType):
|
||||
tf_dype = dtype
|
||||
elif isinstance(dtype, list):
|
||||
return [tf2onnx(t) for t in dtype]
|
||||
else:
|
||||
raise RuntimeError("dtype should be number or tf.DType.")
|
||||
|
||||
# Usually, tf2onnx is done via tf_type->numpy_type->onnx_type
|
||||
# to leverage existing type conversion infrastructure;
|
||||
# However, we need to intercept the string type early because
|
||||
# lowering tf.string type to numpy dtype results in loss of
|
||||
# information. <class 'object'> is returned instead of the
|
||||
# numpy string type desired.
|
||||
if tf_dype is tf.string:
|
||||
return TensorProto.STRING
|
||||
|
||||
onnx_dtype = None
|
||||
try:
|
||||
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(
|
||||
tf_dype.as_numpy_dtype)]
|
||||
finally:
|
||||
if onnx_dtype is None:
|
||||
common.logger.warning(
|
||||
"Can't convert tf dtype {} to ONNX dtype. Return 0 (TensorProto.UNDEFINED)."
|
||||
.format(tf_dype))
|
||||
onnx_dtype = TensorProto.UNDEFINED
|
||||
return onnx_dtype
|
||||
|
||||
|
||||
def onnx2tf(dtype):
|
||||
return tf.as_dtype(mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)])
|
||||
|
||||
|
||||
def onnx2field(dtype):
|
||||
return mapping.STORAGE_TENSOR_TYPE_TO_FIELD[_onnx_dtype(dtype)]
|
||||
|
||||
|
||||
def _onnx_dtype(dtype):
|
||||
if isinstance(dtype, Number):
|
||||
onnx_dype = dtype
|
||||
elif isinstance(dtype, str):
|
||||
onnx_dype = TensorProto.DataType.Value(dtype)
|
||||
else:
|
||||
raise RuntimeError("dtype should be number or str.")
|
||||
return onnx_dype
|
||||
|
||||
|
||||
# TODO (tjingrant) unify _onnx_dtype into any_dtype_to_onnx_dtype
|
||||
def any_dtype_to_onnx_dtype(np_dtype=None, tf_dtype=None, onnx_dtype=None):
|
||||
dtype_mask = [1 if val else 0 for val in [np_dtype, tf_dtype, onnx_dtype]]
|
||||
num_type_set = sum(dtype_mask)
|
||||
assert num_type_set == 1, "One and only one type must be set. However, {} set.".format(
|
||||
sum(num_type_set))
|
||||
|
||||
if np_dtype:
|
||||
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np_dtype]
|
||||
if tf_dtype:
|
||||
onnx_dtype = tf2onnx(tf_dtype)
|
||||
|
||||
return onnx_dtype
|
||||
@@ -0,0 +1,73 @@
|
||||
import inspect
|
||||
import onnx_tf.common as common
|
||||
|
||||
|
||||
class CustomException(object):
|
||||
|
||||
def __init__(self):
|
||||
self._func = RuntimeError
|
||||
self._message = ""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if inspect.isclass(self._func) and issubclass(self._func, Exception):
|
||||
raise self._func(self.get_message(*args, **kwargs))
|
||||
elif callable(self._func):
|
||||
self._func(self.get_message(*args, **kwargs))
|
||||
|
||||
def get_message(self, *args, **kwargs):
|
||||
return self._message
|
||||
|
||||
|
||||
class OpUnimplementedException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnimplementedException, self).__init__()
|
||||
self._func = NotImplementedError
|
||||
self._message = "{} is not implemented."
|
||||
|
||||
def __call__(self, op, version=None, domain=None):
|
||||
if IGNORE_UNIMPLEMENTED:
|
||||
self._func = common.logger.warning
|
||||
super(OpUnimplementedException, self).__call__(op, version, domain)
|
||||
|
||||
def get_message(self, op, version=None, domain=None):
|
||||
insert_message = op
|
||||
if version is not None:
|
||||
insert_message += " version {}".format(version)
|
||||
if domain is not None:
|
||||
insert_message += " in domain `{}`".format(domain)
|
||||
return self._message.format(insert_message)
|
||||
|
||||
|
||||
class OpUnsupportedException(object):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnsupportedException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} is not supported in {}."
|
||||
|
||||
def __call__(self, op, framework):
|
||||
raise self._func(self.get_message(op, framework))
|
||||
|
||||
def get_message(self, op, framework):
|
||||
return self._message.format(op, framework)
|
||||
|
||||
|
||||
class ConstNotFoundException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(ConstNotFoundException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} of {} is not found in graph consts."
|
||||
|
||||
def __call__(self, name, op):
|
||||
super(ConstNotFoundException, self).__call__(name, op)
|
||||
|
||||
def get_message(self, name, op):
|
||||
return self._message.format(name, op)
|
||||
|
||||
|
||||
IGNORE_UNIMPLEMENTED = False
|
||||
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
|
||||
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
|
||||
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
|
||||
@@ -0,0 +1,76 @@
|
||||
from onnx import defs
|
||||
|
||||
import onnx_tf.common as common
|
||||
from onnx_tf.handlers.backend import * # noqa
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
|
||||
import onnx_tf.common as common
|
||||
|
||||
def get_all_backend_handlers(opset_dict):
|
||||
""" Get a dict of all backend handler classes.
|
||||
e.g. {'domain': {'Abs': Abs handler class}, ...}, }.
|
||||
|
||||
:param opset_dict: A dict of opset. e.g. {'domain': version, ...}
|
||||
:return: Dict.
|
||||
"""
|
||||
handlers = {}
|
||||
for handler in BackendHandler.__subclasses__():
|
||||
handler.check_cls()
|
||||
|
||||
domain = handler.DOMAIN
|
||||
version = opset_dict[domain] if domain in opset_dict else 1
|
||||
handler.VERSION = version
|
||||
|
||||
since_version = 1
|
||||
if defs.has(handler.ONNX_OP, domain=handler.DOMAIN):
|
||||
try:
|
||||
since_version = defs.get_schema(
|
||||
handler.ONNX_OP,
|
||||
domain=handler.DOMAIN,
|
||||
max_inclusive_version=version).since_version
|
||||
except RuntimeError:
|
||||
common.logger.debug("Fail to get since_version of {} in domain `{}` "
|
||||
"with max_inclusive_version={}. Set to 1.".format(
|
||||
handler.ONNX_OP, handler.DOMAIN, version))
|
||||
else:
|
||||
common.logger.debug("Unknown op {} in domain `{}`.".format(
|
||||
handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
|
||||
handler.SINCE_VERSION = since_version
|
||||
handlers.setdefault(domain, {})[handler.ONNX_OP] = handler
|
||||
return handlers
|
||||
|
||||
|
||||
def get_backend_coverage():
|
||||
""" Get backend coverage for document.
|
||||
|
||||
:return: onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
|
||||
"""
|
||||
|
||||
onnx_coverage = {}
|
||||
experimental_op = set()
|
||||
for handler in BackendHandler.__subclasses__():
|
||||
handler.check_cls()
|
||||
|
||||
versions = handler.get_versions()
|
||||
domain = handler.DOMAIN
|
||||
if getattr(handler, "EXPERIMENTAL", False):
|
||||
experimental_op.add(handler.ONNX_OP)
|
||||
_update_coverage(onnx_coverage, domain, handler.ONNX_OP, versions)
|
||||
return onnx_coverage, experimental_op
|
||||
|
||||
|
||||
def _update_coverage(coverage, domain, key, versions):
|
||||
domain_coverage = coverage.setdefault(domain, {})
|
||||
vers = domain_coverage.get(key, [])
|
||||
vers.extend(versions)
|
||||
domain_coverage[key] = sorted(list(set(vers)))
|
||||
|
||||
|
||||
def get_backend_partial_support_detail():
|
||||
ps_dict = {}
|
||||
opset_dict = dict([(defs.ONNX_DOMAIN, defs.onnx_opset_version())])
|
||||
handlers = get_all_backend_handlers(opset_dict)[defs.ONNX_DOMAIN]
|
||||
for op_name in handlers:
|
||||
if handlers[op_name].PARTIAL_SUPPORT:
|
||||
ps_dict[op_name] = handlers[op_name].PS_DESCRIPTION
|
||||
return ps_dict
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def get_onnx_version():
|
||||
return tuple(map(int, onnx.version.version.split(".")))
|
||||
|
||||
|
||||
# Returns whether onnx version is prior to major.minor.patch
|
||||
def legacy_onnx_pre_ver(major=0, minor=0, patch=0):
|
||||
return get_onnx_version() < (major, minor, patch)
|
||||
|
||||
|
||||
# Returns whether the opset version accompanying the
|
||||
# onnx installation is prior to version passed.
|
||||
def legacy_opset_pre_ver(version):
|
||||
return onnx.defs.onnx_opset_version() < version
|
||||
@@ -0,0 +1,263 @@
|
||||
from __future__ import division
|
||||
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import itertools
|
||||
|
||||
|
||||
pad_ops = namedtuple("pad_ops",
|
||||
["max_op", "ceil_op", "floor_op", "cast_int_op"])
|
||||
|
||||
pad_numpy_ops = pad_ops(np.maximum, np.ceil, np.floor,
|
||||
lambda arr: arr.astype(np.int64))
|
||||
pad_tf_ops = pad_ops(tf.maximum, tf.math.ceil, tf.math.floor,
|
||||
lambda tensor: tf.cast(tensor, tf.int64))
|
||||
|
||||
|
||||
def calc_pads_same(in_spatial_shape, kernel_shape, strides,
|
||||
dilations, padding, padding_ops=pad_numpy_ops,
|
||||
pads_order=1):
|
||||
"""
|
||||
Calculates the SAME paddings that need to be added to the input
|
||||
|
||||
Args:
|
||||
in_spatial_shape: input spatial shape
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis
|
||||
padding: padding to calculate: SAME_UPPER or
|
||||
SAME_LOWER
|
||||
padding_ops: namedtuple with ops to be used during
|
||||
calculations. there are two sets of ops
|
||||
defined pad_numpy_ops and pad_tf_ops with
|
||||
numpy and tensorflow ops
|
||||
pads_order: order of returned pads. possible options are:
|
||||
1 - b1, b2, ..., bn, e1, e2, ..., en
|
||||
2 - b1, e1, b2, e2, ..., bn, en
|
||||
where n = len(kernel_shape) * 2,
|
||||
b1, b2, ..., bn define pads at the begging of
|
||||
axis
|
||||
e1, e2, ..., en define pads at the end of
|
||||
axis
|
||||
Return:
|
||||
pads: array with calculated pads. the order of the
|
||||
values is determined by `pads_order`
|
||||
|
||||
"""
|
||||
spatial_size = len(kernel_shape)
|
||||
pads = [0] * (spatial_size * 2)
|
||||
for i in range(spatial_size):
|
||||
in_size = in_spatial_shape[i]
|
||||
filter_size = (kernel_shape[i] - 1) * dilations[i] + 1
|
||||
|
||||
out_size = padding_ops.ceil_op(in_size / strides[i])
|
||||
out_size = padding_ops.cast_int_op(out_size)
|
||||
pad_along_axis = \
|
||||
padding_ops.max_op((out_size - 1) * strides[i] +
|
||||
filter_size - in_size, 0)
|
||||
if padding.lower() == "same_lower":
|
||||
pad_op = padding_ops.ceil_op
|
||||
else:
|
||||
pad_op = padding_ops.floor_op
|
||||
pad_begin = pad_op(pad_along_axis / 2)
|
||||
|
||||
pad_begin = padding_ops.cast_int_op(pad_begin)
|
||||
pad_along_axis = padding_ops.cast_int_op(pad_along_axis)
|
||||
|
||||
pad_end = pad_along_axis - pad_begin
|
||||
|
||||
pads[i * pads_order] = pad_begin
|
||||
pads[i * pads_order +
|
||||
(spatial_size if pads_order == 1 else 1)] = pad_end
|
||||
|
||||
return pads
|
||||
|
||||
|
||||
def calc_output_shape(input_spatial_shape, kernel_shape, strides, dilations,
|
||||
padding, ceil_mode=False):
|
||||
"""
|
||||
Calculate output shape
|
||||
|
||||
Args:
|
||||
input_spatial_shape: input spatial shape
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis
|
||||
padding: can be explicit paddings, "SAME_UPPER" or
|
||||
"SAME_LOWER"
|
||||
Return:
|
||||
output_shape: calculated output shape
|
||||
"""
|
||||
spatial_size = len(input_spatial_shape)
|
||||
|
||||
if type(padding) is not list and type(padding) is not np.ndarray:
|
||||
if padding.lower().startswith("same"):
|
||||
padding = calc_pads_same(input_spatial_shape, kernel_shape,
|
||||
strides, dilations, padding)
|
||||
else:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
output_shape = []
|
||||
for dim in range(spatial_size):
|
||||
output_shape.append(_pooling_output_shape(input_spatial_shape[dim],
|
||||
kernel_shape[dim], strides[dim], dilations[dim],
|
||||
padding[dim] + padding[dim + spatial_size],
|
||||
ceil_mode))
|
||||
|
||||
return output_shape
|
||||
|
||||
|
||||
def _pooling_output_shape(input_size, ksize, stride, dilation, pad, ceil_mode):
|
||||
output_size = (input_size + pad - ((ksize - 1) * dilation + 1) +
|
||||
((stride-1) if ceil_mode else 0)) // stride + 1
|
||||
if (pad):
|
||||
if ((output_size - 1) * stride >= input_size + pad):
|
||||
output_size -= 1
|
||||
return output_size
|
||||
|
||||
|
||||
def py_pool(input, kernel_shape, strides=None, dilations=None,
|
||||
padding=None, ceil_mode=False, pooling_type="MAX",
|
||||
include_indices=True, p=2):
|
||||
"""
|
||||
Implementation of Max and Average pool operations in Python
|
||||
Args:
|
||||
input: input N-D data array in NC* format
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis of filter
|
||||
padding: padding for the beginning and ending along each
|
||||
spatial axis. `padding` format should be as follow
|
||||
[x1_begin, x2_begin...x1_end, x2_end,...]
|
||||
ceil_mode: whether to use ceil or floor (default) to compute
|
||||
the output shape.
|
||||
pooling_type: specifies pooling type. Values can be "MAX", "AVG" or
|
||||
"LP"
|
||||
include_indices: should indices be included in the output
|
||||
p: specifies the p parameter for LpPooling
|
||||
Return:
|
||||
pooled: output data from max pooling across the input
|
||||
ind: indices of the selected max values from the input
|
||||
"""
|
||||
|
||||
if type(pooling_type) is not str:
|
||||
pooling_type = pooling_type.decode("UTF-8")
|
||||
|
||||
input_shape = np.shape(input)
|
||||
inp_sp_shape = input_shape[2:]
|
||||
input_dtype = input.dtype
|
||||
if np.issubdtype(input_dtype, np.integer):
|
||||
input_dtype_min = np.iinfo(input_dtype).min
|
||||
else:
|
||||
input_dtype_min = np.finfo(input_dtype).min
|
||||
|
||||
if pooling_type == "LP":
|
||||
rootN = (1.0 / p)
|
||||
|
||||
def _loop_over_output(batch, channel):
|
||||
dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
|
||||
for counters in itertools.product(*dims):
|
||||
input_ranges = []
|
||||
for dim in range(spatial_size):
|
||||
dim_start = \
|
||||
counters[dim] * strides[dim] - pads[dim * 2]
|
||||
dim_end = \
|
||||
min(dim_start + (kernel_shape[dim] - 1) * dilations[dim]
|
||||
+ 1, inp_sp_shape[dim])
|
||||
while dim_start < 0:
|
||||
dim_start += dilations[dim]
|
||||
|
||||
cur_range = [i for i in range(dim_start,
|
||||
dim_end, dilations[dim])]
|
||||
input_ranges.append(cur_range)
|
||||
if pooling_type in ["AVG", "LP"]:
|
||||
val_sum = 0
|
||||
val_count = 0
|
||||
else:
|
||||
maxval = input_dtype_min
|
||||
maxind = -1
|
||||
for input_ind in itertools.product(*input_ranges):
|
||||
ind = (batch, channel) + input_ind
|
||||
val = input[ind]
|
||||
if pooling_type == "AVG":
|
||||
val_sum += val
|
||||
val_count += 1
|
||||
elif pooling_type == "LP":
|
||||
val_sum += abs(val ** p)
|
||||
else:
|
||||
if val > maxval:
|
||||
maxval = val
|
||||
ind = 0
|
||||
for i in range(spatial_size):
|
||||
coef = 1
|
||||
for j in range(i+1, spatial_size):
|
||||
coef *= inp_sp_shape[j]
|
||||
ind += input_ind[i] * coef
|
||||
maxind = ind
|
||||
ind = (batch, channel) + counters
|
||||
if pooling_type == "AVG":
|
||||
out_pool[ind] = val_sum / val_count
|
||||
elif pooling_type == "LP":
|
||||
out_pool[ind] = val_sum ** rootN
|
||||
else:
|
||||
out_pool[ind] = maxval
|
||||
out_ind[ind] = maxind
|
||||
|
||||
spatial_size = len(kernel_shape)
|
||||
|
||||
batch_size = input_shape[0]
|
||||
channels_num = input_shape[1]
|
||||
|
||||
if strides is None:
|
||||
strides = kernel_shape
|
||||
|
||||
if dilations is None:
|
||||
dilations = [1] * spatial_size
|
||||
|
||||
if padding is None:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
if type(padding) is bytes:
|
||||
padding = padding.decode()
|
||||
|
||||
if type(padding) is not list and type(padding) is not np.ndarray:
|
||||
if type(padding) is not str:
|
||||
padding = padding.decode("UTF-8")
|
||||
|
||||
if padding.lower().startswith("same"):
|
||||
padding = calc_pads_same(inp_sp_shape, kernel_shape, strides,
|
||||
dilations, padding)
|
||||
else:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
pads = []
|
||||
pad_along_axis = []
|
||||
output_sp_shape = []
|
||||
|
||||
for dim in range(spatial_size):
|
||||
pads.append(padding[dim])
|
||||
pads.append(padding[dim + spatial_size])
|
||||
pad_along_axis.append(padding[dim] + padding[dim + spatial_size])
|
||||
|
||||
input_size = input_shape[dim + 2]
|
||||
output_size = \
|
||||
_pooling_output_shape(input_size, kernel_shape[dim],
|
||||
strides[dim], dilations[dim],
|
||||
pad_along_axis[dim], ceil_mode)
|
||||
output_sp_shape.append(output_size)
|
||||
|
||||
out_pool = np.zeros([input_shape[0], input_shape[1]] +
|
||||
output_sp_shape, input_dtype)
|
||||
out_ind = np.zeros([input_shape[0], input_shape[1]] +
|
||||
output_sp_shape, np.int64)
|
||||
|
||||
for batch in range(batch_size):
|
||||
for channel in range(channels_num):
|
||||
_loop_over_output(batch, channel)
|
||||
|
||||
if not include_indices:
|
||||
return out_pool
|
||||
else:
|
||||
return out_pool, out_ind
|
||||
@@ -0,0 +1,45 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
def tf_shape(tensor):
|
||||
"""
|
||||
Helper function returning the shape of a Tensor.
|
||||
The function will check for fully defined shape and will return
|
||||
numpy array or if the shape is not fully defined will use tf.shape()
|
||||
to return the shape as a Tensor.
|
||||
"""
|
||||
if tensor.shape.is_fully_defined():
|
||||
return np.array(tensor.shape.as_list(), dtype=np.int64)
|
||||
else:
|
||||
return tf.shape(tensor, out_type=tf.int64)
|
||||
|
||||
|
||||
def tf_product(a, b):
|
||||
"""
|
||||
Calculates the cartesian product of two column vectors a and b
|
||||
|
||||
Example:
|
||||
|
||||
a = [[1]
|
||||
[2]
|
||||
[3]]
|
||||
|
||||
b = [[0]
|
||||
[1]]
|
||||
|
||||
result = [[1 0]
|
||||
[1 1]
|
||||
[2 0]
|
||||
[2 1]
|
||||
[3 0]
|
||||
[3 1]]
|
||||
"""
|
||||
tile_a = tf.tile(a, [1, tf.shape(b)[0]])
|
||||
tile_a = tf.expand_dims(tile_a, 2)
|
||||
tile_a = tf.reshape(tile_a, [-1, 1])
|
||||
|
||||
b = tf.tile(b, [tf.shape(a)[0], 1])
|
||||
b = tf.concat([tile_a, b], axis=1)
|
||||
|
||||
return b
|
||||
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import onnx
|
||||
import tensorflow as tf
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.tools import freeze_graph
|
||||
|
||||
import onnx_tf.backend as backend
|
||||
import onnx_tf.common as common
|
||||
from onnx_tf.common import get_unique_suffix
|
||||
from onnx_tf.pb_wrapper import TensorflowGraph
|
||||
|
||||
|
||||
def main(args):
|
||||
args = parse_args(args)
|
||||
convert(**{k: v for k, v in vars(args).items() if v is not None})
|
||||
|
||||
|
||||
def parse_args(args):
|
||||
|
||||
class ListAction(argparse.Action):
|
||||
""" Define how to convert command line list strings to Python objects.
|
||||
"""
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
values = values if values[0] not in ("(", "[") or values[-1] not in (
|
||||
")", "]") else values[1:-1]
|
||||
res = []
|
||||
for value in values.split(","):
|
||||
if value.isdigit():
|
||||
res.append(int(value))
|
||||
else:
|
||||
res.append(value)
|
||||
setattr(namespace, self.dest, res)
|
||||
|
||||
class OpsetAction(argparse.Action):
|
||||
""" Define how to convert command line opset strings to Python objects.
|
||||
"""
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if values.isdigit():
|
||||
setattr(namespace, "opset", int(values))
|
||||
else:
|
||||
res = []
|
||||
while values and values[0] in ("(", "["):
|
||||
values = values[1:]
|
||||
while values and values[-1] in (")", "]"):
|
||||
values = values[:-1]
|
||||
for value in values.split("),("):
|
||||
l, r = value.split(",")
|
||||
res.append((l, int(r)))
|
||||
setattr(namespace, "opset", res)
|
||||
|
||||
def get_param_doc_dict(funcs):
|
||||
"""Get doc of funcs params.
|
||||
|
||||
Args:
|
||||
funcs: Target funcs.
|
||||
|
||||
Returns:
|
||||
Dict of params doc.
|
||||
"""
|
||||
|
||||
# TODO(fumihwh): support google doc format
|
||||
def helper(doc, func):
|
||||
first_idx = doc.find(":param")
|
||||
last_idx = doc.find(":return")
|
||||
last_idx = last_idx if last_idx != -1 else len(doc)
|
||||
param_doc = doc[first_idx:last_idx]
|
||||
params_doc = param_doc.split(":param ")[1:]
|
||||
return {
|
||||
p[:p.find(": ")]: p[p.find(": ") + len(": "):] +
|
||||
" (from {})".format(func.__module__ + "." + func.__name__)
|
||||
for p in params_doc
|
||||
}
|
||||
|
||||
param_doc_dict = {}
|
||||
for func, persists in funcs:
|
||||
doc = inspect.getdoc(func)
|
||||
doc_dict = helper(doc, func)
|
||||
for k, v in doc_dict.items():
|
||||
if k not in persists:
|
||||
continue
|
||||
param_doc_dict[k] = {"doc": v, "params": persists[k]}
|
||||
return param_doc_dict
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=
|
||||
"This is the converter for converting protocol buffer between tf and onnx."
|
||||
)
|
||||
|
||||
# required two args, source and destination path
|
||||
parser.add_argument("--infile", "-i", help="Input file path.", required=True)
|
||||
parser.add_argument(
|
||||
"--outfile", "-o", help="Output file path.", required=True)
|
||||
|
||||
def add_argument_group(parser, group_name, funcs):
|
||||
group = parser.add_argument_group(group_name)
|
||||
param_doc_dict = get_param_doc_dict(funcs)
|
||||
for k, v in param_doc_dict.items():
|
||||
group.add_argument("--{}".format(k), help=v["doc"], **v["params"])
|
||||
|
||||
# backend args
|
||||
# Args must be named consistently with respect to backend.prepare.
|
||||
add_argument_group(parser, "backend arguments (onnx -> tf)",
|
||||
[(backend.prepare, {
|
||||
"device": {},
|
||||
"strict": {},
|
||||
"logging_level": {}
|
||||
})])
|
||||
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def convert(infile, outfile, **kwargs):
|
||||
"""Convert pb.
|
||||
|
||||
Args:
|
||||
infile: Input path.
|
||||
outfile: Output path.
|
||||
**kwargs: Other args for converting.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
logging_level = kwargs.get("logging_level", "INFO")
|
||||
common.logger.setLevel(logging_level)
|
||||
common.logger.handlers[0].setLevel(logging_level)
|
||||
|
||||
common.logger.info("Start converting onnx pb to tf pb:")
|
||||
onnx_model = onnx.load(infile)
|
||||
tf_rep = backend.prepare(onnx_model, **kwargs)
|
||||
tf_rep.export_graph(outfile)
|
||||
common.logger.info("Converting completes successfully.")
|
||||
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import onnx_tf.backend
|
||||
import onnx_tf.backend_rep
|
||||
from third_party import get_info
|
||||
|
||||
|
||||
def main(docs_dir):
|
||||
gen_api(docs_dir)
|
||||
gen_cli(docs_dir)
|
||||
|
||||
|
||||
def gen_api(docs_dir):
|
||||
gen_doc_for = {
|
||||
'onnx_tf.backend': [
|
||||
onnx_tf.backend.prepare,
|
||||
],
|
||||
'onnx_tf.backend_rep.TensorflowRep': [
|
||||
onnx_tf.backend_rep.TensorflowRep.export_graph,
|
||||
]
|
||||
}
|
||||
with open(os.path.join(docs_dir, 'API.md'), 'w') as doc_file:
|
||||
doc_file.write('ONNX-Tensorflow API\n')
|
||||
doc_file.write('======\n\n')
|
||||
|
||||
for scope, funcs in sorted(gen_doc_for.items()):
|
||||
for func in funcs:
|
||||
doc_parsed = get_info.parse_docstring(func.__doc__)
|
||||
doc_file.write('#### `' + scope + '.' + func.__name__ + '`\n\n')
|
||||
doc_file.write('<details>\n')
|
||||
doc_file.write(' <summary>')
|
||||
doc_file.write(doc_parsed['short_description'] + '\n\n')
|
||||
doc_file.write(' </summary>\n')
|
||||
doc_file.write(doc_parsed['long_description'] + '\n\n')
|
||||
doc_file.write('</details>\n\n\n\n')
|
||||
|
||||
doc_file.write('_params_:\n\n')
|
||||
for param in doc_parsed['params']:
|
||||
doc_file.write('`' + param['name'] + '` : ' + param['doc'] + '\n\n')
|
||||
|
||||
doc_file.write('_returns_:\n\n')
|
||||
doc_file.write(doc_parsed['returns'] + '\n\n')
|
||||
|
||||
|
||||
def gen_cli(docs_dir):
|
||||
with open(os.path.join(docs_dir, 'CLI_template.md'), 'r') as cli_temp_file:
|
||||
temp_lines = cli_temp_file.readlines()
|
||||
|
||||
lines = []
|
||||
for line in temp_lines:
|
||||
matched = re.match(r"{onnx-tf.*}", line)
|
||||
if matched:
|
||||
command = matched.string.strip()[1:-1]
|
||||
output = subprocess.check_output(command.split(" ")).decode("UTF-8")
|
||||
lines.append(output)
|
||||
else:
|
||||
lines.append(line)
|
||||
|
||||
with open(os.path.join(docs_dir, 'CLI.md'), 'w') as cli_file:
|
||||
cli_file.writelines(lines)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
docs_dir = os.path.join(base_dir, 'doc')
|
||||
main(docs_dir)
|
||||
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pprint
|
||||
|
||||
from onnx import defs
|
||||
|
||||
from onnx_tf.common.handler_helper import get_backend_coverage
|
||||
from onnx_tf.common.handler_helper import get_backend_partial_support_detail
|
||||
|
||||
|
||||
def main():
|
||||
backend_opset_dict = {}
|
||||
|
||||
for schema in defs.get_all_schemas():
|
||||
op_name = schema.name
|
||||
backend_opset_dict[op_name] = []
|
||||
|
||||
backend_onnx_coverage, backend_experimental_op = get_backend_coverage()
|
||||
backend_opset_dict.update(backend_onnx_coverage.get(defs.ONNX_DOMAIN, {}))
|
||||
backend_ps_dict = get_backend_partial_support_detail()
|
||||
|
||||
with open('opset_version.py', 'w') as version_file:
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
version_file.write("backend_opset_version = {\n " +
|
||||
pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
|
||||
version_file.write("backend_partial_support = {\n " +
|
||||
pp.pformat(backend_ps_dict)[1:-1] + "\n}\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import getopt
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import onnx
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf import opset_version, __version__
|
||||
|
||||
|
||||
def main(docs_dir):
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
docs_dir = os.path.join(base_dir, 'doc')
|
||||
onnx_version = onnx.__version__
|
||||
onnx_tf_release_build = False
|
||||
|
||||
try:
|
||||
opts, args = getopt.getopt(sys.argv[1:], 'h:mr',
|
||||
['onnx_master', 'onnx_tf_release_build'])
|
||||
except getopt.GetoptError:
|
||||
print('Usage:')
|
||||
print(' gen_status.py [-m -r]')
|
||||
print(' gen_status.py -h')
|
||||
print('Description:')
|
||||
print(' -m, --onnx_master installed ONNX is the latest master code')
|
||||
print(' if omitted, ONNX version is onnx.__version__')
|
||||
print(' -r, --onnx_tf_release_build create report for ONNX-TF release with version')
|
||||
print(' stated in the VERSION_NUMBER file')
|
||||
print(' if omitted, the report is for ONNX-TF master')
|
||||
print(' -h show this help message and exit')
|
||||
print('eg. 1. generate support_status.md for ONNX-TF master and ONNX onnx.__version__')
|
||||
print(' gen_status.py')
|
||||
print(' 2. generate support_status.md for ONNX-TF master and ONNX master')
|
||||
print(' gen_status.py -m')
|
||||
print(' 3. generate support_status_<onnx_tf_version>.md for ONNX-TF version')
|
||||
print(' stated in the VERSION_NUMBER file and ONNX onnx.__version__ ')
|
||||
print(' gen_status.py -r')
|
||||
sys.exit(2)
|
||||
for opt, arg in opts:
|
||||
if opt == '-h':
|
||||
print('Usage:')
|
||||
print(' gen_status.py [-m -r]')
|
||||
print(' gen_status.py -h')
|
||||
print('Description:')
|
||||
print(' -m, --onnx_master installed ONNX is the latest master code')
|
||||
print(' if omitted, ONNX version is onnx.__version__')
|
||||
print(' -r, --onnx_tf_release_build create report for ONNX-TF release with version')
|
||||
print(' stated in the VERSION_NUMBER file')
|
||||
print(' if omitted, the report is for ONNX-TF master')
|
||||
print(' -h show this help message and exit')
|
||||
print('eg. 1. generate support_status.md for ONNX-TF master and ONNX onnx.__version__')
|
||||
print(' gen_status.py')
|
||||
print(' 2. generate support_status.md for ONNX-TF master and ONNX master')
|
||||
print(' gen_status.py -m')
|
||||
print(' 3. generate support_status_<onnx_tf_version>.md for ONNX-TF version')
|
||||
print(' stated in the VERSION_NUMBER file and ONNX onnx.__version__ ')
|
||||
print(' gen_status.py -r')
|
||||
sys.exit()
|
||||
elif opt in ('-m', '--onnx_master'):
|
||||
onnx_version = 'master'
|
||||
elif opt in ('-r', '--onnx_tf_release_build'):
|
||||
onnx_tf_release_build = True
|
||||
|
||||
gen_support_status(docs_dir, onnx_version, onnx_tf_release_build)
|
||||
|
||||
|
||||
def gen_support_status(docs_dir, onnx_version, onnx_tf_release_build):
|
||||
|
||||
# set filename
|
||||
if onnx_tf_release_build:
|
||||
onnx_tf_version = 'v' + __version__
|
||||
filename = 'support_status_' + onnx_tf_version.replace('.', '_') + '.md'
|
||||
else: # onnx-tf = master
|
||||
# get onnx-tf commit id
|
||||
onnx_tf_commit_id = subprocess.check_output('git rev-parse HEAD',
|
||||
shell=True)
|
||||
onnx_tf_commit_id = onnx_tf_commit_id.decode().strip('\n')
|
||||
onnx_tf_version = 'Master ( commit id: {} )'.format(onnx_tf_commit_id)
|
||||
filename = 'support_status.md'
|
||||
|
||||
with open(os.path.join(docs_dir, filename), 'w') as status_file:
|
||||
status_file.write('# ONNX-Tensorflow Support Status\n')
|
||||
status_file.write('|||\n')
|
||||
status_file.write('|-:|:-|\n')
|
||||
status_file.write('|ONNX-Tensorflow Version|{}|\n'.format(onnx_tf_version))
|
||||
|
||||
# get onnx commit id
|
||||
if onnx_version == 'master':
|
||||
onnx_commit_id = onnx.version.git_version
|
||||
status_file.write(
|
||||
'|ONNX Version|Master ( commit id: {} )|\n'.format(onnx_commit_id))
|
||||
else:
|
||||
status_file.write('|ONNX Version|v{}|\n'.format(onnx_version))
|
||||
|
||||
# get tf_version
|
||||
status_file.write('|Tensorflow Version|v{}|\n\n'.format(tf.__version__))
|
||||
|
||||
# display the table legend
|
||||
status_file.write('Notes:\n')
|
||||
status_file.write('* Values that are new or updated from a ')
|
||||
status_file.write('previous opset version are in bold.\n')
|
||||
status_file.write('* -: not defined in corresponding ONNX ')
|
||||
status_file.write('opset version\n')
|
||||
status_file.write('* \*: the operator is deprecated\n')
|
||||
status_file.write('* :small_red_triangle:: not supported yet\n')
|
||||
status_file.write('* :small_orange_diamond:: partially supported\n')
|
||||
status_file.write('* the rest are all supported\n\n')
|
||||
|
||||
# get oll onnx ops
|
||||
onnx_ops = {}
|
||||
for schema in onnx.defs.get_all_schemas():
|
||||
if schema.domain == '': # only get onnx ops
|
||||
onnx_ops[schema.name] = {
|
||||
'versions': [],
|
||||
'deprecated': schema.since_version if schema.deprecated else -1
|
||||
}
|
||||
for schema in onnx.defs.get_all_schemas_with_history():
|
||||
if schema.domain == '': # only get onnx ops
|
||||
op = onnx_ops[schema.name]
|
||||
if schema.deprecated:
|
||||
if schema.since_version <= op['deprecated']:
|
||||
op['versions'].append(schema.since_version)
|
||||
op['deprecated'] = schema.since_version
|
||||
else:
|
||||
op['versions'].append(schema.since_version)
|
||||
|
||||
# get all onnx-tf supported ops
|
||||
onnx_tf_ops = opset_version.backend_opset_version
|
||||
onnx_tf_ops_ps = opset_version.backend_partial_support
|
||||
|
||||
# get the cureent opset version
|
||||
current_opset = onnx.defs.onnx_opset_version()
|
||||
|
||||
# setup table header
|
||||
status_file.write('|||')
|
||||
for i in range(current_opset):
|
||||
status_file.write('|')
|
||||
status_file.write('\n|:-:|:-:|')
|
||||
for i in range(current_opset):
|
||||
status_file.write(':-:|')
|
||||
status_file.write('\n|**ONNX Operator**|')
|
||||
for opset in range(1, current_opset + 1):
|
||||
status_file.write('**Opset {}**|'.format(opset))
|
||||
status_file.write('**ONNX Operator**|')
|
||||
|
||||
ops_count = len(onnx_ops)
|
||||
# fill in data for the table
|
||||
for key, val in sorted(onnx_ops.items()):
|
||||
try:
|
||||
status_file.write('\n|{}|'.format(key))
|
||||
i = 0
|
||||
vers = val['versions']
|
||||
deprecated = val['deprecated']
|
||||
for opset in range(1, current_opset + 1):
|
||||
if i <= len(vers) - 1:
|
||||
lb = vers[i]
|
||||
ub = vers[i + 1] if i < len(vers) - 1 else vers[i]
|
||||
if opset < lb:
|
||||
if i == 0:
|
||||
status_file.write('-')
|
||||
elif opset == lb:
|
||||
status_file.write('**{}**'.format(lb))
|
||||
if lb >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif lb not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
else: # opset > lb
|
||||
if opset < ub:
|
||||
status_file.write('{}'.format(lb))
|
||||
if lb >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif lb not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
elif opset == ub:
|
||||
status_file.write('**{}**'.format(ub))
|
||||
if ub >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif ub not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
i += 1
|
||||
else: #opset > ub
|
||||
status_file.write('{}'.format(ub))
|
||||
if ub >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif ub not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
status_file.write('|')
|
||||
status_file.write('{}|'.format(key))
|
||||
except:
|
||||
# ops defined in onnx but not in opset_version.backend_opset_versionn
|
||||
status_file.write(':small_red_triangle:|')
|
||||
|
||||
status_file.write(
|
||||
'\n\nONNX-TF Supported Operators / ONNX Operators: {} / {}'.format(
|
||||
ops_count, len(onnx_ops)))
|
||||
|
||||
# display partial support footnote
|
||||
status_file.write('\n\nNotes:\n')
|
||||
index = 1
|
||||
for key in onnx_tf_ops_ps:
|
||||
status_file.write(
|
||||
str(index) + '. ' + key + ': ' + onnx_tf_ops_ps[key] + '\n')
|
||||
index += 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,7 @@
|
||||
import os
|
||||
import pkgutil
|
||||
|
||||
__all__ = [
|
||||
modname for _, modname, _ in pkgutil.walk_packages(
|
||||
path=[os.path.split(__file__)[0]])
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user