1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
def torch_gather_2d(input, indices):
    """Gather values along axis 1 of a 2-D tensor, like ``torch.gather(input, 1, indices)``.

    For every ``i`` and ``j`` the result holds ``input[i, indices[i, j]]``.
    Each output row therefore follows the order of the matching ``indices``
    row, and repeated indices are allowed.

    Args:
        input: 2-D tensor of values. (The name shadows the builtin ``input``;
            it is kept so keyword callers keep working.)
        indices: 2-D integer tensor (int32 or int64) of column positions into
            ``input``. Row ``i`` of ``indices`` reads from row ``i`` of
            ``input``, so ``indices`` must not have more rows than ``input``.
            Out-of-range column positions are not checked here; see
            ``tf.gather_nd`` for how they are handled.

    Returns:
        Tensor with the shape of ``indices`` and the dtype of ``input``.
    """
    indices = tf.convert_to_tensor(indices)
    # Dynamic shape, so the column count need not be known at graph-build time.
    index_shape = tf.shape(indices)
    num_rows, num_cols = index_shape[0], index_shape[1]
    # Row coordinate for every position of `indices`: row_ids[i, j] == i.
    row_ids = tf.tile(tf.expand_dims(tf.range(num_rows), 1), [1, num_cols])
    # tf.stack needs matching dtypes; tf.range yields int32 but indices may be int64.
    row_ids = tf.cast(row_ids, indices.dtype)
    # coords[i, j] == [i, indices[i, j]], so gather_nd picks input[i, indices[i, j]]
    # in the caller's order (a boolean mask would return columns sorted and
    # would drop duplicates).
    coords = tf.stack([row_ids, indices], axis=-1)
    return tf.gather_nd(input, coords)
# Demo: build a 6x8 tensor holding 1..48 and gather two columns per row.
x = tf.reshape(tf.range(1, 49), [6, 8])
y = tf.constant([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
out = torch_gather_2d(x, y)

# The session runs on the default graph that already holds the ops above;
# each tensor is evaluated and printed in turn.
with tf.Session() as sess:
    for label, tensor in (("x:\n", x), ("y:\n", y), ("out:\n", out)):
        print(label, sess.run(tensor))
|