重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class PadMixin(object):
|
||||
|
||||
@classmethod
|
||||
def get_padding_as_op(cls, x, pads):
|
||||
num_dim = int(len(pads) / 2)
|
||||
|
||||
tf_pads = np.transpose(np.array(pads).reshape([2, num_dim]))
|
||||
tf_pads = [0, 0, 0, 0] + tf_pads.flatten().tolist()
|
||||
|
||||
padding = tf.constant(
|
||||
np.array(tf_pads).reshape([num_dim + 2, 2])
|
||||
.astype(np.int32)) # tf requires int32 paddings
|
||||
return tf.pad(x, padding)
|
||||
Reference in New Issue
Block a user