PyTorch에서 왜 zero_grad ()를 호출해야합니까?


답변:


144

에서는 PyTorchPyTorch 가 후속 역방향 패스 에서 그라디언트축적 하기 때문에 역 전파를 시작하기 전에 그라디언트를 0으로 설정해야 합니다. 이것은 RNN을 훈련하는 동안 편리합니다. 따라서 기본 작업은 모든 호출 에서 그라디언트누적 (즉, 합계)하는 것loss.backward() 입니다.

따라서 훈련 루프를 시작할 때 이상적으로 zero out the gradients는 매개 변수 업데이트를 올바르게 수행 해야 합니다. 그렇지 않으면 기울기가 의도 한 방향이 아닌 최소 방향 (또는 최대화 목표의 경우 최대 )을 가리 킵니다 .

다음은 간단한 예입니다.

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

또는 바닐라 경사 하강 법을 수행하는 경우 :

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

참고 : 텐서 에서가 호출 될 때 그래디언트 의 누적 (즉, 합계 )이 발생합니다 ..backward()loss


3
대단히 감사합니다. 이것은 정말 도움이됩니다! tensorflow에 동작이 있는지 알고 있습니까?
layser

확실하게 .. 이렇게하지 않으면 폭발적인 그라디언트 문제가 발생합니다. 그렇죠?
zwep

2
@zwep 그래디언트를 누적한다고해서 크기가 증가하는 것은 아닙니다. 그래디언트의 부호가 계속 뒤집히는 경우가 그 예입니다. 따라서 폭발적인 그래디언트 문제가 발생한다는 보장은 없습니다. 게다가, 정확하게 제로를 설정하더라도 폭발하는 그라디언트가 존재합니다.
Tom Roth

바닐라 경사 하강 법을 실행할 때 가중치를 업데이트하려고 할 때 "grad가 필요한 리프 변수가 제자리 작업에 사용되었습니다"오류가 표시되지 않습니까?
MUAS

1

zero_grad ()는 오류 (또는 손실)를 줄이기 위해 그래디언트 메서드를 사용하는 경우 마지막 단계에서 손실없이 루프를 다시 시작합니다.

zero_grad ()를 사용하지 않으면 필요에 따라 손실이 증가하지 않고 감소합니다.

예를 들어 zero_grad ()를 사용하면 다음 출력을 찾을 수 있습니다.

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

zero_grad ()를 사용하지 않으면 다음 출력을 찾을 수 있습니다.

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.