Apr, 2024

降低 PyTorch 对选择性微分的内存消耗

TL;DR在深度学习任务中,内存是一个有限的资源。除了神经网络权重外,另一个主要消耗内存的因素是通过自动微分(AD)建立的计算图,用于反向传播。我们观察到 PyTorch 当前的 AD 实现在存储计算图时忽略了参数可微性的信息。然而,这些信息对于在许多现代微调任务中请求参数子集的梯度时减少内存是有用的。特别是,对于在参数上线性操作的层(如全连接、卷积或归一化层)的输入可以在参数被标记为不可微时丢弃。我们提供了一个适用于不可微性的层的可替换的实现,并展示它在不影响运行时间的情况下如何减少内存。