add pt2tf tool

This commit is contained in:
zxros10
2020-09-23 09:09:49 +08:00
parent 7f7b7df65d
commit 18aefa4dd0
407 changed files with 16211 additions and 0 deletions
@@ -0,0 +1,17 @@
import numpy as np
import tensorflow as tf
class PadMixin(object):
@classmethod
def get_padding_as_op(cls, x, pads):
num_dim = int(len(pads) / 2)
tf_pads = np.transpose(np.array(pads).reshape([2, num_dim]))
tf_pads = [0, 0, 0, 0] + tf_pads.flatten().tolist()
padding = tf.constant(
np.array(tf_pads).reshape([num_dim + 2, 2])
.astype(np.int32)) # tf requires int32 paddings
return tf.pad(x, padding)