이 메서드 zero_grad()
는 훈련 중에 호출해야합니다. 그러나 문서 는별로 도움이되지 않습니다.
| zero_grad(self)
| Sets gradients of all model parameters to zero.
이 메서드를 호출해야하는 이유는 무엇입니까?
이 메서드 zero_grad()
는 훈련 중에 호출해야합니다. 그러나 문서 는별로 도움이되지 않습니다.
| zero_grad(self)
| Sets gradients of all model parameters to zero.
이 메서드를 호출해야하는 이유는 무엇입니까?
답변:
에서는 PyTorch
PyTorch 가 후속 역방향 패스 에서 그라디언트 를 축적 하기 때문에 역 전파를 시작하기 전에 그라디언트를 0으로 설정해야 합니다. 이것은 RNN을 훈련하는 동안 편리합니다. 따라서 기본 작업은 모든 호출 에서 그라디언트 를 누적 (즉, 합계)하는 것loss.backward()
입니다.
따라서 훈련 루프를 시작할 때 이상적으로 zero out the gradients
는 매개 변수 업데이트를 올바르게 수행 해야 합니다. 그렇지 않으면 기울기가 의도 한 방향이 아닌 최소 방향 (또는 최대화 목표의 경우 최대 )을 가리 킵니다 .
다음은 간단한 예입니다.
import torch
from torch.autograd import Variable
import torch.optim as optim
def linear_model(x, W, b):
return torch.matmul(x, W) + b
data, targets = ...
W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)
optimizer = optim.Adam([W, b])
for sample, target in zip(data, targets):
# clear out the gradients of all Variables
# in this optimizer (i.e. W, b)
optimizer.zero_grad()
output = linear_model(sample, W, b)
loss = (output - target) ** 2
loss.backward()
optimizer.step()
또는 바닐라 경사 하강 법을 수행하는 경우 :
W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)
for sample, target in zip(data, targets):
# clear out the gradients of Variables
# (i.e. W, b)
W.grad.data.zero_()
b.grad.data.zero_()
output = linear_model(sample, W, b)
loss = (output - target) ** 2
loss.backward()
W -= learning_rate * W.grad.data
b -= learning_rate * b.grad.data
참고 : 텐서 에서가 호출 될 때 그래디언트 의 누적 (즉, 합계 )이 발생합니다 ..backward()
loss
zero_grad ()는 오류 (또는 손실)를 줄이기 위해 그래디언트 메서드를 사용하는 경우 마지막 단계에서 손실없이 루프를 다시 시작합니다.
zero_grad ()를 사용하지 않으면 필요에 따라 손실이 증가하지 않고 감소합니다.
예를 들어 zero_grad ()를 사용하면 다음 출력을 찾을 수 있습니다.
model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2
zero_grad ()를 사용하지 않으면 다음 출력을 찾을 수 있습니다.
model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5