Files
2020-10-14 08:55:07 +08:00

18 lines
446 B
Python

import numpy as np
import tensorflow as tf
class PadMixin(object):
@classmethod
def get_padding_as_op(cls, x, pads):
num_dim = int(len(pads) / 2)
tf_pads = np.transpose(np.array(pads).reshape([2, num_dim]))
tf_pads = [0, 0, 0, 0] + tf_pads.flatten().tolist()
padding = tf.constant(
np.array(tf_pads).reshape([num_dim + 2, 2])
.astype(np.int32)) # tf requires int32 paddings
return tf.pad(x, padding)