I used tf.compat.v1.scatter_update() locally to update tensor values and everything tested OK, but the same code kept failing on the server with LookupError: No gradient defined for operation 'ScatterUpdate' (op type: ScatterUpdate). It turned out that tf.scatter_update() in TF 1.15 apparently has no backward (gradient) implementation, which is rather surprising.

Since the local and server code was identical and only the server failed, my first guess was an environment difference. Local development used TF 2.2 + Python 3.6, while the server ran TF 1.15 + Python 3.6, so I switched my local environment to TF 1.15 + Python 3.6 to try to reproduce the problem, and it did reproduce. Googling around, Zhihu user smallsunsun says the op simply never had its backpropagation logic implemented. How can that even happen??? Looks like TensorFlow's gradient-registration code is worth a read at some point.

Because time was tight, I decided to work around the problem first: since tf.scatter_update() in TF 1.15 has no gradient implementation, a reasonable approach is to reimplement its logic using ops that do have gradients. Stack Overflow user Dmytro Prylipko gives the code below.

import tensorflow as tf

def scatter_update_tensor(x, indices, updates):
    '''
    Utility function similar to `tf.scatter_update`, but performing on Tensor
    '''
    x_shape = tf.shape(x)
    patch = tf.scatter_nd(indices, updates, x_shape)
    mask = tf.greater(tf.scatter_nd(indices, tf.ones_like(updates), x_shape), 0)
    return tf.where(mask, patch, x)
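To see why this works, the same mask-and-patch idea can be sketched in plain NumPy (an illustration of the logic, not the TensorFlow implementation): `patch` holds the new values at the scattered positions, `mask` marks which positions were written, and the final select keeps the original `x` wherever the mask is False.

```python
import numpy as np

def scatter_update_np(x, indices, updates):
    """NumPy sketch of the mask-and-patch trick (illustration only)."""
    patch = np.zeros_like(x)
    patch[indices] = updates          # like tf.scatter_nd(indices, updates, x_shape)
    mask = np.zeros(x.shape, dtype=bool)
    mask[indices] = True              # like tf.scatter_nd(indices, ones, x_shape) > 0
    return np.where(mask, patch, x)   # like tf.where(mask, patch, x)

out = scatter_update_np(np.ones((4, 2)),
                        np.array([0, 2]),
                        np.array([[9., 9.], [7., 7.]]))
print(out)  # rows 0 and 2 replaced, rows 1 and 3 left untouched
```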

I tried it and it works (gradients do propagate). Note, however, that the first argument (indices) of tf.scatter_nd() must have rank >= 2; its last dimension is the index depth. The function signature and usage examples of tf.scatter_nd() are shown below.

tf.scatter_nd(
    indices, updates, shape, name=None
)

indices = tf.constant([[4], [3], [1], [7]])  # note: indices is 2-D
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
# [0, 11, 0, 10, 9, 0, 0, 12]

indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
# [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
#  [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
#  [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
#  [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]

The code below tests scatter_update_tensor() against tf.scatter_update(); the results are identical.

import tensorflow as tf

def scatter_update_tensor(x, indices, updates):
    '''
    Utility function similar to `tf.scatter_update`, but performing on Tensor
    '''
    x_shape = tf.shape(x)
    patch = tf.scatter_nd(indices, updates, x_shape)
    mask = tf.greater(tf.scatter_nd(indices, tf.ones_like(updates), x_shape), 0)
    return tf.where(mask, patch, x)

with tf.compat.v1.Session() as sess:
    x = tf.Variable(tf.ones(shape=[4, 2]))
    indices = tf.constant([0, 1, 2, 3], dtype=tf.int32)
    updates = tf.reshape(tf.range(8, dtype=tf.float32), [4, 2])
    sess.run(tf.global_variables_initializer())
    print("scatter_update_tensor:")
    # scatter_update_tensor needs 2-D indices, hence the reshape to [-1, 1]
    print(sess.run(scatter_update_tensor(x, tf.reshape(indices, [-1, 1]), updates)))
    print("tf.scatter_update:")
    print(sess.run(tf.scatter_update(x, indices, updates)))
# scatter_update_tensor:
# [[0. 1.]
#  [2. 3.]
#  [4. 5.]
#  [6. 7.]]
# tf.scatter_update:
# [[0. 1.]
#  [2. 3.]
#  [4. 5.]
#  [6. 7.]]
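As for why this rewrite fixes the LookupError: tf.scatter_nd and tf.where both have registered gradients, so TensorFlow can differentiate the composed expression. Conceptually (a NumPy sketch of the gradient structure, not TensorFlow's actual gradient code), the gradient w.r.t. x is the upstream gradient with the updated positions zeroed out, and the gradient w.r.t. updates is the upstream gradient gathered at those positions:

```python
import numpy as np

# Hypothetical upstream gradient flowing into out = where(mask, patch, x)
grad_out = np.arange(8, dtype=np.float32).reshape(4, 2)
indices = np.array([0, 2])          # rows that were updated

mask = np.zeros((4, 2), dtype=bool)
mask[indices] = True

# x only reaches the output where mask is False, so its gradient is zeroed
# at the updated rows; updates only reach the output at the masked rows.
grad_x = np.where(mask, 0.0, grad_out)
grad_updates = grad_out[indices]

print(grad_x)        # rows 0 and 2 zeroed
print(grad_updates)  # upstream gradient at rows 0 and 2
```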
