pytorch에서 torch.no_grad의 사용법은 무엇입니까?


21

나는 pytorch를 처음 사용 하고이 github 코드로 시작했습니다 . 코드의 60-61 행에있는 주석을 이해하지 못합니다 "because weights have requires_grad=True, but we don't need to track this in autograd". requires_grad=Trueautograd를 사용하기 위해 기울기를 계산 해야하는 변수에 대해 언급했지만 이해하는 것은 무엇 "tracked by autograd"입니까?

답변:


24

"with torch.no_grad ()"래퍼는 일시적으로 모든 require_grad 플래그를 false로 설정합니다. 공식 PyTorch 튜토리얼의 예제 ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ) :

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

밖:

True
True
False

위의 웹 사이트에서 모든 자습서를 읽는 것이 좋습니다.

귀하의 예 : 저자는 PyTorch가 값을 업데이트하기 때문에 새로운 정의 변수 w1 및 w2의 그라디언트를 계산하지 않기를 바랍니다.


6
with torch.no_grad()

블록의 모든 작업에 그라디언트가 없도록합니다.

pytorch에서는와 함께 두 변수 인 w1과 w2의 위치 변경을 수행 할 수 없습니다 require_grad = True. w1과 w2의 위치 변경을 피하는 것은 백 전파 계산에 오류가 발생할 것이기 때문이라고 생각합니다. 위치 변경은 w1과 w2를 완전히 바꿉니다.

그러나이를 사용 no_grad()하면 새 w1 및 새 w2가 조작에 의해 생성되므로 그라디언트가없는 것을 제어 할 수 있습니다. 즉, 그라디언트 부분이 아닌 w1 및 w2의 값만 변경하며 이전에 정의 된 변수 그라디언트 정보가 있습니다. 역전 파는 계속 될 수 있습니다.

당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.