피클은 직렬화와 파이썬 객체를 드 - 직렬화 파이썬 라이브러리가 구현하는 바이너리 프로토콜을.
때 당신 import torch
이 (또는 당신은 PyTorch를 사용할 때) import pickle
당신을 위해 당신이 호출 할 필요는 없습니다 pickle.dump()
및 pickle.load()
방법은 저장하고 개체를로드 할 수있는 직접.
사실, torch.save()
그리고 torch.load()
포장 것 pickle.dump()
및 pickle.load()
당신을 위해.
state_dict
다른 대답이 언급은 단지 몇 메모를 가치가있다.
무엇을 state_dict
우리는 PyTorch 내부해야합니까? 실제로 두 가지 state_dict
가 있습니다.
PyTorch 모델이되고 torch.nn.Module
있다 model.parameters()
학습 가능 매개 변수 (w와 b)를 얻기 위해 전화를. 이 학습 가능한 매개 변수는 무작위로 설정되면 학습하면서 시간이 지남에 따라 업데이트됩니다. 학습 가능한 매개 변수가 첫 번째 state_dict
입니다.
두 번째 state_dict
는 최적화 상태 dict입니다. 학습자는 학습 가능한 매개 변수를 개선하는 데 사용됩니다. 그러나 옵티마이 저는 state_dict
고정되어 있습니다. 거기에서 배울 것이 없습니다.
state_dict
객체는 Python 사전 이므로 PyTorch 모델 및 최적화 프로그램에 많은 모듈성을 추가하여 쉽게 저장, 업데이트, 변경 및 복원 할 수 있습니다.
이것을 설명하기 위해 매우 간단한 모델을 만들어 봅시다 :
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
이 코드는 다음을 출력합니다 :
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
이것은 최소 모델입니다. 순차적 스택을 추가하려고 할 수 있습니다
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
학습 가능한 매개 변수 (콘볼 루션 레이어, 선형 레이어 등)가있는 레이어와 등록 된 버퍼 (배치 노어 레이어) 만 모델의 항목에 포함 state_dict
됩니다.
학습 할 수없는 것은 옵티 마이저 객체에 속하며 state_dict
옵티마이 저의 상태 및 사용 된 하이퍼 파라미터에 대한 정보가 들어 있습니다.
이야기의 나머지 부분은 동일합니다. 예측 단계 (추론 후 모델을 사용할 때의 단계)에서 예측; 우리는 우리가 배운 매개 변수를 기반으로 예측합니다. 따라서 추론을 위해 매개 변수를 저장하면됩니다 model.state_dict()
.
torch.save(model.state_dict(), filepath)
나중에 사용하는 model.load_state_dict (torch.load (filepath)) model.eval ()
참고 : model.eval()
모델을로드 한 후 마지막 줄을 잊지 마십시오 .
또한 저장하지 마십시오 torch.save(model.parameters(), filepath)
. 는 model.parameters()
단지 발전기 개체입니다.
다른 한편으로, torch.save(model, filepath)
모델 객체 자체를 저장하지만 모델에는 옵티마이 저가 없습니다 state_dict
. @Jadiel de Armas의 다른 훌륭한 답변을 확인하여 최적화 프로그램의 상태 표시를 저장하십시오.