介绍

本文介绍了使用混合精度训练验证禁用梯度来优化显存的占用。根据笔者实测,混合精度训练对网络的影响几乎可以忽略不及,但是显存可以降低一半以上。

混合精度训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
scaler = GradScaler()

optimizer = optim.SGD(model.parameters(), lr=0.04, momentum=0.7, weight_decay=5e-4)

for epoch in range(0, n_epochs):
train_loss = 0.0
valid_loss = 0.0
model.train()
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
with autocast():
output = model(data).to(device)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  1. 初始化梯度放大器 scaler = GradScaler()
  2. 在模型推理部分加上 with autocast()
  3. 使用Scale防止半精度发生的数据溢出!这点非常重要
  4. 遗憾的是并没有发现存在加速效果。

禁用梯度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
with torch.no_grad():
for data, target in test_loader:
data = data.to(device)
target = target.to(device)
output = model(data).to(device)
loss = criterion(output, target)
valid_loss += loss.item()*data.size(0)
_, pred = torch.max(output, 1)
correct_tensor = pred.eq(target.data.view_as(pred))
total_sample += data.size(0)
right_sample += list(correct_tensor).count(True)
print()
print("Accuracy:",100*right_sample/total_sample,"%")
print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
epoch, train_loss, valid_loss))
# 如果验证集损失函数减少,就保存模型。
if valid_loss <= valid_loss_min:
print('Validation loss decreased ({:.6f} --> {:.6f}). Saving Person ...'.format(valid_loss_min,valid_loss))
# torch.save(model.state_dict(), 'resnet18_cifar10.pt')
valid_loss_min = valid_loss

开启禁用梯度是因为,pytorch即便在不调用backward的情况下也会存在梯度,这导致了测试也带来了大量的显存占用,通过torch.no_grad()来禁止梯度,可以有效防止显存爆炸。

后记

又发现有时候并不降低显存使用,降低了GPU负载。