在tensorflow和pytorch中都有一个gather函数,其作用相似但是用法不同。关于tf.gather的用法可以参考知乎作者Towser的文章《TF 中的 indexing 和 slicing》。torch.gather函数的用法也很简单,就是给定indices获取tensor对应元素。给个例子就明白了。

1
2
3
4
5
6
tensor = torch.Tensor([[1,2,3],[4,5,6]])
indexs = torch.LongTensor([[0,1],[1,2]]) #就是获取tensor中的[1,0],[1,1],[2,1],[2,2]对应位置的元素
print(torch.gather(tensor, 1, indexs))

#tensor([[1., 2.],
# [5., 6.]])

下面直接给出torch.gather的tf实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def torch_gather_2d(input, indices):
len_1d, len_2d = tf.shape(input)[0], tf.shape(input)[1]
idx_matrix = tf.tile(tf.expand_dims(tf.range(0, len_2d), 0), [len_1d,1])

indices_t = tf.transpose(indices)
len = indices_t.get_shape()[0]

for i in range(len):
coln = tf.nn.embedding_lookup(indices_t, [i])
idx_mask_new = tf.equal(idx_matrix, tf.transpose(coln))

if i == 0:
idx_mask = idx_mask_new
idx_mask = tf.logical_or(idx_mask_new, idx_mask)

input = tf.reshape(tf.boolean_mask(input, idx_mask), tf.shape(indices))
return input

with tf.Session() as sess:
x = tf.reshape(tf.range(1,49), [6,8])
y = tf.constant([[0,1], [1,2], [2,3,], [3,4], [4,5], [5,6]])
out = torch_gather_2d(x, y)

print("x:\n", sess.run(x))
print("y:\n", sess.run(y))
print("out:\n", sess.run(out))

# x:
# [[ 1 2 3 4 5 6 7 8]
# [ 9 10 11 12 13 14 15 16]
# [17 18 19 20 21 22 23 24]
# [25 26 27 28 29 30 31 32]
# [33 34 35 36 37 38 39 40]
# [41 42 43 44 45 46 47 48]]
# y:
# [[0 1]
# [1 2]
# [2 3]
# [3 4]
# [4 5]
# [5 6]]
# out:
# [[ 1 2]
# [10 11]
# [19 20]
# [28 29]
# [37 38]
# [46 47]]

就酱。还是写tf太少了,不过mask真的是一种很有用的技巧。