1 回答

TA贡献1824条经验 获得超6个赞
不是使用稀疏张量来制作“除 softmaxed top-K 值之外的所有零”的张量,而是使用tf.scatter_nd:
import tensorflow as tf
def softmax_top_k(logits, k=10):
values, indices = tf.nn.top_k(logits, k, sorted=False)
softmax = tf.nn.softmax(values)
logits_shape = tf.shape(logits)
# Assuming that logits is 2D
rows = tf.tile(tf.expand_dims(tf.range(logits_shape[0]), 1), [1, k])
scatter_idx = tf.stack([rows, indices], axis=-1)
return tf.scatter_nd(scatter_idx, softmax, logits_shape)
编辑:这是具有任意维数的张量的稍微复杂的版本。但是,代码仍然要求在图构建时知道维数。
import tensorflow as tf
def softmax_top_k(logits, k=10):
values, indices = tf.nn.top_k(logits, k, sorted=False)
softmax = tf.nn.softmax(values)
# Make nd indices
logits_shape = tf.shape(logits)
dims = [tf.range(logits_shape[i]) for i in range(logits_shape.shape.num_elements() - 1)]
grid = tf.meshgrid(*dims, tf.range(k), indexing='ij')
scatter_idx = tf.stack(grid[:-1] + [indices], axis=-1)
return tf.scatter_nd(scatter_idx, softmax, logits_shape)
添加回答
举报