Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/handlers/backend/rnn_mixin.py
T
2020-10-14 08:55:07 +08:00

87 lines
2.9 KiB
Python

from functools import partial
import tensorflow as tf
# import tensorflow_probability as tfp
from tensorflow.python.ops import array_ops
from onnx_tf.common import exception
class RNNMixin(object):
    """Mixin with helpers shared by ONNX recurrent-operator handlers.

    ``rnn`` builds the TensorFlow cell(s) for a direction and runs them;
    ``rnn_get_activation`` turns an ONNX activation name and its optional
    ``alpha``/``beta`` parameters into a TensorFlow callable.
    """

    # ONNX activation name -> TensorFlow function or Keras layer class.
    # "affine" is deliberately missing: tf.contrib (which supplied
    # distributions.bijectors.AffineScalar from TF 1.8) was removed in
    # TF 2.0, and its replacement tfp.bijectors.Affine lives in
    # tensorflow_probability, which this module does not import.
    ONNX_ACTIVATION_MAPPING = {
        "elu": tf.nn.elu,
        "hard_sigmoid": tf.keras.backend.hard_sigmoid,
        "leaky_relu": tf.nn.leaky_relu,
        "relu": tf.nn.relu,
        "sigmoid": tf.sigmoid,
        "softsign": tf.nn.softsign,
        "softplus": tf.nn.softplus,
        "tanh": tf.tanh,
        "thresholded_relu": tf.keras.layers.ThresholdedReLU,
    }

    # Absolute tolerance for comparing activation parameters with the only
    # value TensorFlow supports. ONNX float attributes are single precision,
    # so e.g. 0.2 typically arrives as 0.2000000029... and would fail an
    # exact comparison.
    _ACTIVATION_PARAM_TOLERANCE = 1e-6

    @classmethod
    def _is_default_param(cls, value, default):
        """Return True if ``value`` is unset (None) or close to ``default``."""
        if value is None:
            return True
        return abs(value - default) <= cls._ACTIVATION_PARAM_TOLERANCE

    @classmethod
    def rnn(cls, x, cell_class, cell_kwargs, rnn_kwargs, activations,
            direction):
        """Build the RNN cell(s) and run them over ``x``.

        Args:
          x: Input tensor handed to the TensorFlow RNN runner. The reverse
            direction flips axis 0, so x is presumably time-major
            (TODO confirm against callers).
          cell_class: Callable creating an RNN cell from keyword arguments.
          cell_kwargs: Keyword arguments for ``cell_class``. Not modified;
            the ``activation`` entry is set on a per-cell copy.
          rnn_kwargs: Extra keyword arguments forwarded to
            ``tf.nn.dynamic_rnn`` / ``tf.nn.bidirectional_dynamic_rnn``.
          activations: Sequence of activation callables; index 0 is used
            for the forward (or only) cell, index 1 for the backward cell
            of a bidirectional RNN.
          direction: One of "forward", "bidirectional" or "reverse".

        Returns:
          The ``(outputs, states)`` pair produced by the TensorFlow runner.
          For "reverse", outputs are flipped back to the input order.

        Raises:
          ValueError: If ``direction`` is not one of the supported values.
        """
        if direction not in ("forward", "bidirectional", "reverse"):
            raise ValueError(
                "Unsupported RNN direction {!r}; expected 'forward', "
                "'bidirectional' or 'reverse'".format(direction))

        # Work on copies so the caller's cell_kwargs dict stays untouched.
        fw_kwargs = dict(cell_kwargs, activation=activations[0])
        cell_fw = tf.nn.rnn_cell.MultiRNNCell([cell_class(**fw_kwargs)])

        if direction == "forward":
            outputs, states = tf.nn.dynamic_rnn(cell_fw, x, **rnn_kwargs)
        elif direction == "bidirectional":
            bw_kwargs = dict(cell_kwargs, activation=activations[1])
            cell_bw = tf.nn.rnn_cell.MultiRNNCell([cell_class(**bw_kwargs)])
            outputs, states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, x, **rnn_kwargs)
        else:  # "reverse"
            # Run a forward RNN over the flipped input, then flip the
            # outputs back so they line up with the original order.
            # NOTE(review): a plain reverse ignores any per-example
            # sequence_length passed in rnn_kwargs - confirm callers only
            # rely on this for full-length sequences.
            time_dim = 0
            reversed_x = array_ops.reverse(x, axis=[time_dim])
            outputs, states = tf.nn.dynamic_rnn(cell_fw, reversed_x,
                                                **rnn_kwargs)
            outputs = array_ops.reverse(outputs, axis=[time_dim])
        return outputs, states

    @classmethod
    def rnn_get_activation(cls, name, alpha, beta):
        """Translate an ONNX activation spec into a TensorFlow callable.

        Args:
          name: Activation name; must be a key of ONNX_ACTIVATION_MAPPING.
          alpha: Optional first activation parameter. ``None`` means "not
            specified" and selects the ONNX default for that activation.
          beta: Optional second activation parameter, same convention.

        Returns:
          A callable applying the activation: the mapped TensorFlow
          function, a ``functools.partial`` of it (leaky_relu), or an
          instance of the mapped class (thresholded_relu, affine).

        Raises:
          Whatever ``exception.OP_UNSUPPORTED_EXCEPT`` raises when the name
          is unknown or the parameters differ from the only values the
          TensorFlow implementation supports.
        """
        if name not in cls.ONNX_ACTIVATION_MAPPING:
            exception.OP_UNSUPPORTED_EXCEPT(
                "Activation function {} for {}".format(name, cls.__name__),
                "Tensorflow")
        activation = cls.ONNX_ACTIVATION_MAPPING[name]

        if name == "affine":
            # Only reachable when a mapping provides an "affine" entry;
            # the table defined on this class does not.
            activation = activation(scale=alpha, shift=beta)
        elif name == "elu":
            # tf.nn.elu has no alpha parameter, so only alpha == 1 works.
            if not cls._is_default_param(alpha, 1):
                exception.OP_UNSUPPORTED_EXCEPT(
                    "Activation function {} with alpha={} for {}".format(
                        name, alpha, cls.__name__), "Tensorflow")
        elif name == "hard_sigmoid":
            # The Keras backend implementation is fixed at slope 0.2 and
            # offset 0.5.
            if not (cls._is_default_param(alpha, 0.2) and
                    cls._is_default_param(beta, 0.5)):
                exception.OP_UNSUPPORTED_EXCEPT(
                    "Activation function {} with alpha={}, beta={} for {}"
                    .format(name, alpha, beta, cls.__name__), "Tensorflow")
        elif name == "leaky_relu":
            # Fall back to the ONNX default only when alpha is unset; an
            # explicit alpha of 0 must be honoured.
            slope = 0.01 if alpha is None else alpha
            activation = partial(activation, alpha=slope)
        elif name == "thresholded_relu":
            # ONNX ThresholdedRelu defaults alpha to 1.0 when unspecified.
            theta = 1.0 if alpha is None else alpha
            activation = activation(theta=theta)
        return activation