重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
from onnx_tf.handlers.handler import tf_func
|
||||
|
||||
@onnx_op("IsInf")
|
||||
@tf_func(tf.math.is_inf)
|
||||
class IsInf(BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def version_10(cls, node, **kwargs):
|
||||
inp = kwargs["tensor_dict"][node.inputs[0]]
|
||||
dtype = inp.dtype
|
||||
shape = inp.shape
|
||||
zero = tf.zeros(shape, dtype)
|
||||
dn = node.attrs.get("detect_negative", 1)
|
||||
dp = node.attrs.get("detect_positive", 1)
|
||||
# detecting only positive infinity, zero out elements < 0
|
||||
if dn == 0:
|
||||
inp = tf.maximum(zero, inp)
|
||||
# detecting only negative infinity, zero out elements > 0
|
||||
if dp == 0:
|
||||
inp = tf.minimum(zero, inp)
|
||||
return [
|
||||
cls.make_tensor_from_onnx_node(node, inputs=[inp], **kwargs)
|
||||
]
|
||||
Reference in New Issue
Block a user