1 回答
![?](http://img1.sycdn.imooc.com/545867280001ed6402200220-100-100.jpg)
TA贡献1856条经验 获得超11个赞
对于其他正在寻找答案的人,以下是我最终得出的方案:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import time
class CustLayer(keras.layers.Layer):
    """Keras layer that expands integer indices into rows of a fixed table.

    Every integer in the input tensor is used as an index into
    ``info_matrix``; the selected entries are written into a float32 tensor
    with one extra trailing axis of length ``info_dim``.

    The lookup runs in NumPy through ``tf.numpy_function``, i.e. as Python
    code outside the TensorFlow graph.
    """

    def __init__(self, info_matrix, first_dim, second_dim, info_dim, batch_size):
        """Store the lookup table and the layer's dimensions.

        Args:
            info_matrix: Lookup table indexed along its first axis by the
                integer input. Presumably a NumPy array of shape
                (num_entries, info_dim) -- TODO confirm against the caller.
            first_dim: Stored on the instance; not read by the code in this
                class.
            second_dim: Stored on the instance; not read by the code in this
                class.
            info_dim: Length of the trailing axis of the output.
            batch_size: Stored on the instance; not read by the code in this
                class.
        """
        super().__init__()
        # Non-trainable float32 copy of the table held as a TF variable. The
        # NumPy lookup below reads ``self.info_matrix`` directly, not this.
        self.w = tf.Variable(
            initial_value=info_matrix,
            trainable=False,
            dtype=tf.dtypes.float32,
        )
        self.info_matrix = info_matrix
        self.info_dim = info_dim
        self.first_dim = first_dim
        self.second_dim = second_dim
        self.batch_size = batch_size

    def my_numpy_func(self, x):
        """Look up ``info_matrix`` entries for every index in ``x``.

        Args:
            x: NumPy integer array with (at least) three axes; it holds the
                contents of the tensor passed to ``tf_function``.

        Returns:
            A tuple ``(goal_arr, shape_arr)`` where ``goal_arr`` is a float32
            array of shape ``(x.shape[0], x.shape[1], x.shape[2], info_dim)``
            and ``shape_arr`` is an int32 array holding the first three
            dimensions of ``x``.
        """
        shape = x.shape
        goal_arr = np.zeros(
            shape=(shape[0], shape[1], shape[2], self.info_dim),
            dtype=np.float32,
        )
        # Indexing the table with the whole index array performs the lookup
        # for every position at once; assigning into the preallocated buffer
        # keeps the float32 cast and the shape check against ``info_dim``.
        goal_arr[...] = self.info_matrix[x]
        # int32, not int8: an int8 shape overflows for any dimension > 127
        # (e.g. a batch of 128) and would corrupt the reshape downstream.
        shape_arr = np.array([shape[0], shape[1], shape[2]], dtype=np.int32)
        return goal_arr, shape_arr

    @tf.function(input_signature=[tf.TensorSpec((None, 39, 25), tf.int64)])
    def tf_function(self, input):
        """Run the NumPy lookup on ``input`` and restore the output shape.

        The input signature fixes the accepted tensors to int64 with shape
        ``(None, 39, 25)``.

        Args:
            input: int64 tensor of indices.

        Returns:
            float32 tensor of shape
            ``(batch, input.shape[1], input.shape[2], info_dim)``.
        """
        # ``name`` is passed by keyword: in TensorFlow releases whose
        # signature is (func, inp, Tout, stateful, name), a fourth positional
        # argument would be taken as ``stateful`` instead of the op name.
        y, shape_arr = tf.numpy_function(
            self.my_numpy_func,
            [input],
            [tf.float32, tf.int32],
            name="Nameless",
        )
        # The numpy_function output carries no static shape information, so
        # reshape it using the dimensions reported by the NumPy side.
        y = tf.reshape(
            y,
            shape=(shape_arr[0], shape_arr[1], shape_arr[2], self.info_dim),
        )
        return y

    def call(self, orig_arr):
        """Apply the layer: delegate to ``tf_function``."""
        return self.tf_function(orig_arr)
注意事项:该实现依赖 tf.numpy_function,可以在 GPU 上运行,但无法在 TPU 上运行。
添加回答
举报