28 lines
831 B
Python
28 lines
831 B
Python
import tensorflow as tf
|
|
|
|
from onnx_tf.handlers.backend_handler import BackendHandler
|
|
from onnx_tf.handlers.handler import onnx_op
|
|
from onnx_tf.handlers.handler import tf_func
|
|
|
|
@onnx_op("IsInf")
|
|
@tf_func(tf.math.is_inf)
|
|
class IsInf(BackendHandler):
|
|
|
|
@classmethod
|
|
def version_10(cls, node, **kwargs):
|
|
inp = kwargs["tensor_dict"][node.inputs[0]]
|
|
dtype = inp.dtype
|
|
shape = inp.shape
|
|
zero = tf.zeros(shape, dtype)
|
|
dn = node.attrs.get("detect_negative", 1)
|
|
dp = node.attrs.get("detect_positive", 1)
|
|
# detecting only positive infinity, zero out elements < 0
|
|
if dn == 0:
|
|
inp = tf.maximum(zero, inp)
|
|
# detecting only negative infinity, zero out elements > 0
|
|
if dp == 0:
|
|
inp = tf.minimum(zero, inp)
|
|
return [
|
|
cls.make_tensor_from_onnx_node(node, inputs=[inp], **kwargs)
|
|
]
|