87 lines
2.9 KiB
Python
87 lines
2.9 KiB
Python
from functools import partial
|
|
|
|
import tensorflow as tf
|
|
# import tensorflow_probability as tfp
|
|
from tensorflow.python.ops import array_ops
|
|
|
|
from onnx_tf.common import exception
|
|
|
|
|
|
class RNNMixin(object):
|
|
|
|
ONNX_ACTIVATION_MAPPING = {
|
|
# Added from tf 1.8
|
|
# "affine": tf.contrib.distributions.bijectors.AffineScalar,
|
|
# tf.contrib was removed since tf 2.0,
|
|
# Class Affine had been move to the following module
|
|
# "affine": tfp.bijectors.Affine,
|
|
"elu": tf.nn.elu,
|
|
"hard_sigmoid": tf.keras.backend.hard_sigmoid,
|
|
"leaky_relu": tf.nn.leaky_relu,
|
|
"relu": tf.nn.relu,
|
|
"sigmoid": tf.sigmoid,
|
|
"softsign": tf.nn.softsign,
|
|
"softplus": tf.nn.softplus,
|
|
"tanh": tf.tanh,
|
|
"thresholded_relu": tf.keras.layers.ThresholdedReLU,
|
|
}
|
|
|
|
@classmethod
|
|
def rnn(cls, x, cell_class, cell_kwargs, rnn_kwargs, activations, direction):
|
|
cell_kwargs["activation"] = activations[0]
|
|
|
|
rnn_cell = [cell_class(**cell_kwargs)]
|
|
cell_fw = tf.nn.rnn_cell.MultiRNNCell(rnn_cell)
|
|
|
|
if direction == "bidirectional":
|
|
cell_kwargs["activation"] = activations[1]
|
|
rnn_cell_bw = [cell_class(**cell_kwargs)]
|
|
cell_bw = tf.nn.rnn_cell.MultiRNNCell(rnn_cell_bw)
|
|
|
|
if direction == "forward":
|
|
outputs, states = tf.nn.dynamic_rnn(cell_fw, x, **rnn_kwargs)
|
|
elif direction == "bidirectional":
|
|
outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, x,
|
|
**rnn_kwargs)
|
|
elif direction == "reverse":
|
|
|
|
def _reverse(input_, seq_dim):
|
|
return array_ops.reverse(input_, axis=[seq_dim])
|
|
|
|
time_dim = 0
|
|
inputs_reverse = _reverse(x, time_dim)
|
|
outputs, states = tf.nn.dynamic_rnn(cell_fw, inputs_reverse, **rnn_kwargs)
|
|
outputs = _reverse(outputs, time_dim)
|
|
|
|
return outputs, states
|
|
|
|
@classmethod
|
|
def rnn_get_activation(cls, name, alpha, beta):
|
|
if name not in cls.ONNX_ACTIVATION_MAPPING:
|
|
exception.OP_UNSUPPORTED_EXCEPT(
|
|
"Activation function {} for {}".format(name, cls.__name__),
|
|
"Tensorflow")
|
|
activation = cls.ONNX_ACTIVATION_MAPPING[name]
|
|
kwargs = {}
|
|
if name == "affine":
|
|
kwargs["scale"] = alpha
|
|
kwargs["shift"] = beta
|
|
activation = activation(**kwargs)
|
|
elif name == "elu":
|
|
if alpha != 1:
|
|
exception.OP_UNSUPPORTED_EXCEPT(
|
|
"Activation function {} with alpha={} for {}".format(
|
|
name, alpha, cls.__name__), "Tensorflow")
|
|
elif name == "hard_sigmoid":
|
|
if alpha != 0.2 or beta != 0.5:
|
|
exception.OP_UNSUPPORTED_EXCEPT(
|
|
"Activation function {} with alpha={}, beta={} for {}".format(
|
|
name, alpha, beta, cls.__name__), "Tensorflow")
|
|
elif name == "leaky_relu":
|
|
kwargs["alpha"] = alpha or 0.01
|
|
activation = partial(activation, **kwargs)
|
|
elif name == "thresholded_relu":
|
|
kwargs["theta"] = alpha
|
|
activation = activation(**kwargs)
|
|
return activation
|