本文简要谈一下深度学习模型的空间复杂度,初学者在第一次写神经网络模型时很容易忽视空间复杂度导致内存溢出。

最近因为工作需要尝试一些深度学习的模型,之前少有做深度学习的经验,所以踩了一些坑。刚开始写了一个简单的DNN模型,然后放到集群上跑,很快跑出了结果,于是就例行化了。

在例行化第一个版本的时候,按照默认设置,申请了大概10g的内存。但是集群是动态扩容的,在模型训练过程中会根据需要自动扩充内存。由于第一次写深度学习模型,并没有去估计实际的空间复杂度,所以对其空间复杂度并不清晰。模型实际用了大概100g+的内存,但是我是无感的,集群自动帮我做了扩容。

最近因为效果不好排查问题,发现模型好久没有跑出来了,看了下日志,并没有发现有异常。一度以为是同事占用的资源太多被影响了,于是减少了训练集发现还是不行。然后我询问了运维。运维看了下日志,说:是宿主机的内存不够了,gpu虚拟集群 会动态根据程序的需求进行动态扩容,如果宿主机还存在剩余的内存,那么进行扩容。如果内存不足,则无法扩容。在宿主机内存耗尽之后监控进程会根据需要杀死超出预设内存最大的任务,且不会自动恢复。由于我的程序超出了大概100g,所以被杀死了。

之所以我的模型占用了很大内存,是因为对特征进行one-hot之后空间复杂度陡增,仅特征大小就是40w*7w*4byte=104g,无疑是超了,我都怀疑之前是怎么跑完模型的。解决:

  1. 按需one-hot。为了方便,还是直接加载全部特征到内存,分批次梯度求导的时候,对当前批次进行one-hot,设置batch小一些,大概500-1000左右。设置batch=500,特征大小为0.13g,参数数量基本上也是这个量级,估计0.5g的空间复杂度,基本上符合内存需求。
  2. 也考虑过使用稀疏矩阵的方式,这样更具泛化性。但是没有找到怎么获取稀疏矩阵的子矩阵,所以没有采用该方式。在tensorflow中,类SparseTensor表是一个稀疏张量。函数为 SparseTensor(values, indices, dense_shape):
    • indices: 一个二维的张量,数据类型是int64,数据维度是[N, ndims]。
    • values: 一个一维的张量,数据类型是任意的,数据维度是[N]。
    • dense_shape: 一个一维的张量,数据类型是int64,数据维度是[ndims]。其中,N表示稀疏张量中存在N个值,ndims表示SparseTensor的维度。
  3. 因为特征很稀疏,所以也可以考虑使用特征矩阵非零列id 乘 权重矩阵对应的行。个人猜想,tf.sparse_tensor_dense_matmul 可能使用了相似的思想,之后研究一下。

附上通过list生成稀疏矩阵的代码,来自GCN作者,github可搜到。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 生成稀疏向量的元组表示
def sparse_to_tuple(sparse_mx):
sparse_mx = sp.csc_matrix(sparse_mx)
"""Convert sparse matrix to tuple representation."""
def to_tuple(mx):
if not sp.isspmatrix_coo(mx):
mx = mx.tocoo()
coords = np.vstack((mx.row, mx.col)).transpose()
values = mx.data
shape = mx.shape
return coords, values, shape
print("sparse_mx shape:")
print(sparse_mx.shape) # (2708, 1433)

if isinstance(sparse_mx, list):
for i in range(len(sparse_mx)):
sparse_mx[i] = to_tuple(sparse_mx[i])
else:
sparse_mx = to_tuple(sparse_mx)

return sparse_mx

data = [[1,0,0],[0,0,2]]
sparse_data = sparse_to_tuple(data)
sess=tf.Session()

# 根据稀疏向量的元组表示生成稀疏向量
a=tf.SparseTensor(sparse_data[0], sparse_data[1], sparse_data[2])
print(sess.run(a))

接下来专攻一下tensorflow语法及深度学习模型,顺便看下pytorch。立flag为证!

相关文献

  1. Tensorflow Python API 翻译(sparse_ops)这篇文章很不错。