解决tf1.15中tf.scatter_update()函数没有定义梯度的问题
在本地使用了tf.compat.v1.scatter_update()更新tensor的值,测试OK放到服务器上跑一直失败,报错LookupError: No gradient defined for operation 'ScatterUpdate' (op type: ScatterUpdate)
,最后发现疑似tf1.15的tf.scatter_update()没有实现反向求导逻辑,挺奇怪的。
因为本地和服务器的代码是相同的,只有服务器端报错,所以首先判断是环境的问题。在本地开发使用的tf版本是tf2.2+python3.6,服务器使用的是tf1.15+python3.6。所以首先考虑将本地环境切换至tf1.15+python3.6,复现该问题。切换了环境后果然复现了该问题。网上Google了下该问题,知乎用户smallsunsun说是没有实现反向传播的逻辑,竟然会发生这种事情???看来应该看下tf反向求导的代码逻辑。
因为时间紧张,所以打算先绕过这个问题,既然tf1.15的tf.scatter_update()没有实现反向求导逻辑,那么一个合理的思路是用别的函数实现tf.scatter_update()的逻辑。stackoverflow用户Dmytro Prylipko给出了代码,如下所示。
1 | def scatter_update_tensor(x, indices, updates): |
尝试了下是OK的(可以反向传播)。但是要注意,tf.scatter_nd()第一个参数的shape维度必须>=2,下面给出tf.scatter_nd()的函数声明和使用示例。
1 | tf.scatter_nd( |
下面代码测试了函数scatter_update_tensor()和函数tf.scatter_update()的功能,结果是相同的。
1 | def scatter_update_tensor(x, indices, updates): |