PyTorch에서 훈련 된 모델을 저장하는 가장 좋은 방법은?


193

PyTorch에서 훈련 된 모델을 저장하는 다른 방법을 찾고있었습니다. 지금까지 두 가지 대안을 찾았습니다.

  1. torch.save () 는 모델을 저장하고 torch.load () 는 모델을 저장합니다 .
  2. 훈련 된 모델을 저장하려면 model.state_dict () , 저장된 모델을로드하려면 model.load_state_dict ()

접근법 1보다 접근법 2가 권장되는 이 토론을 보았습니다 .

내 질문은 왜 두 번째 접근 방식이 선호됩니까? torch.nn 모듈 이이 두 가지 기능을 가지고 있기 때문에 사용합니까?


2
나는 torch.save ()가 역 전파 사용을위한 중간 출력과 같은 모든 중간 변수를 저장하기 때문이라고 생각합니다. 그러나 무게 / 바이어스와 같은 모델 매개 변수 만 저장하면됩니다. 때로는 전자가 후자보다 훨씬 클 수 있습니다.
Dawei Yang

2
나는 시험 torch.save(model, f)torch.save(model.state_dict(), f). 저장된 파일의 크기가 동일합니다. 이제 혼란 스러워요. 또한 pickle을 사용하여 model.state_dict ()를 저장하는 것이 매우 느립니다. torch.save(model.state_dict(), f)모델 생성을 처리하고 토치가 모델 웨이트의 로딩을 처리하므로 가능한 문제를 제거하는 것이 가장 좋은 방법이라고 생각합니다 . 참조 : 토론 .pytorch.org
Dawei Yang

PyTorch가 자습서 섹션 에서이 문제를 좀 더 명확하게 해결 한 것 같습니다. 한 번에 두 개 이상의 모델 저장 및 웜 스타트 모델을 포함하여 여기에 답변에 나열되지 않은 좋은 정보가 많이 있습니다.
whlteXbread 2016 년

사용에 어떤 문제가 pickle있습니까?
찰리 파커

1
@CharlieParker torch.save는 피클을 기반으로합니다. "[torch.save]는 Python의 pickle 모듈을 사용하여 전체 모듈을 저장합니다.이 방법의 단점은 직렬화 된 데이터가 모델에 사용될 때 사용되는 특정 클래스 및 정확한 디렉토리 구조에 바인딩되어 있다는 것입니다 pickle은 모델 클래스 자체를 저장하지 않기 때문에로드 시간 동안 사용되는 클래스를 포함하는 파일의 경로를 저장하기 때문에 코드가 여러 가지 방식으로 중단 될 수 있습니다. 다른 프로젝트 또는 리팩터링 후에 사용됩니다. "
David Miller

답변:


215

github 저장소 에서이 페이지 를 찾았 습니다. 여기에 내용을 붙여 넣을 것입니다.


모델 저장을위한 권장 접근법

모델을 직렬화하고 복원하는 데는 두 가지 주요 접근 방식이 있습니다.

첫 번째 (권장)는 모델 매개 변수 만 저장하고로드합니다.

torch.save(the_model.state_dict(), PATH)

그런 다음 나중에 :

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

두 번째는 전체 모델을 저장하고로드합니다.

torch.save(the_model, PATH)

그런 다음 나중에 :

the_model = torch.load(PATH)

그러나이 경우 직렬화 된 데이터는 사용 된 특정 클래스 및 정확한 디렉토리 구조에 바인딩되므로 다른 프로젝트에서 사용하거나 심각한 리팩터링 후에 다양한 방식으로 중단 될 수 있습니다.


8
@smth에 따르면 discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/... 모델을 다시로드하는 것은 기본적으로 모델을 학습합니다. 따라서 학습을 재개하지 않고 추론을 위해로드하는 경우로드 후 수동으로 the_model.eval ()을 호출해야합니다.
WillZ

두 번째 방법은 stackoverflow.com/questions/53798009/… Windows 10에서 오류를 발생시킵니다.이 문제를 해결할 수 없습니다
Gulzar

모델 클래스에 액세스 할 필요없이 저장할 수있는 옵션이 있습니까?
Michael D

이 접근 방식을 사용하면로드 사례에 대해 전달해야하는 * args 및 ** kwargs를 어떻게 추적합니까?
Mariano Kamp

사용에 어떤 문제가 pickle있습니까?
찰리 파커

144

그것은 당신이하고 싶은 것에 달려 있습니다.

사례 # 1 : 추론에 사용하기 위해 모델 저장 : 모델을 저장하고 복원 한 다음 모델을 평가 모드로 변경합니다. 이것은 일반적 으로 건설시 기본적으로 열차 모드에 BatchNorm있고 Dropout레이어 가 있기 때문에 수행됩니다 .

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

사례 # 2 : 나중에 교육을 다시 시작하기 위해 모델 저장 : 저장하려는 모델을 계속 교육 해야하는 경우 모델 이상의 것을 저장해야합니다. 또한 옵티 마이저, 에포크, 스코어 등의 상태를 저장해야합니다. 다음과 같이하십시오.

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

훈련을 재개하려면 state = torch.load(filepath)다음과 같은 작업을 수행 한 다음 각 개별 객체의 상태를 복원하십시오.

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

훈련을 재개하고 있으므로 로드 할 때 상태를 복원 한 후에는 전화 하지 마십시오model.eval() .

사례 # 3 : 코드에 액세스 할 수없는 다른 사람이 사용할 모델 : Tensorflow .pb에서 모델의 아키텍처와 가중치를 모두 정의 하는 파일을 만들 수 있습니다 . 이것은 특히 사용할 때 매우 편리합니다 Tensorflow serve. Pytorch에서이를 수행하는 동등한 방법은 다음과 같습니다.

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

이 방법은 여전히 ​​방탄이 아니며 pytorch는 여전히 많은 변화를 겪고 있기 때문에 권장하지 않습니다.


1
세 가지 경우에 권장되는 파일 종료가 있습니까? 아니면 항상 .pth입니까?
Verena Haunschmid

1
사례 # 3 torch.load에서는 OrderedDict 만 반환합니다. 예측을하기 위해 모델을 어떻게 얻습니까?
Alber8295

안녕하세요. 언급 된 "사례 # 2 : 모델을 저장하여 나중에 훈련을 다시 시작"하는 방법을 알고 있습니까? 검사 점을 모델에로드 한 다음 "model.to (device) model = train_model_epoch (model, criterion, optimizer, sched, epochs)"와 같은 모델을 실행하거나 다시 시작할 수 없습니다.
dnez

1
안녕하세요, 추론의 경우 공식 pytorch 문서에서 추론이나 훈련을 완료하기 위해 옵티 마이저 state_dict를 저장해야한다고 말합니다. "추론 또는 재개에 사용하기 위해 일반 체크 포인트를 저장하는 경우 모델의 state_dict 이상을 저장해야합니다. 여기에는 모델 트레인으로 업데이트되는 버퍼 및 매개 변수가 포함되어 있으므로 옵티마이 저의 state_dict도 저장해야합니다. "
Mohammed Awney

1
# 3의 경우 모델 클래스를 어딘가에 정의해야합니다.
Michael D

12

피클은 직렬화와 파이썬 객체를 드 - 직렬화 파이썬 라이브러리가 구현하는 바이너리 프로토콜을.

때 당신 import torch이 (또는 당신은 PyTorch를 사용할 때) import pickle당신을 위해 당신이 호출 할 필요는 없습니다 pickle.dump()pickle.load()방법은 저장하고 개체를로드 할 수있는 직접.

사실, torch.save()그리고 torch.load()포장 것 pickle.dump()pickle.load()당신을 위해.

state_dict다른 대답이 언급은 단지 몇 메모를 가치가있다.

무엇을 state_dict우리는 PyTorch 내부해야합니까? 실제로 두 가지 state_dict가 있습니다.

PyTorch 모델이되고 torch.nn.Module있다 model.parameters()학습 가능 매개 변수 (w와 b)를 얻기 위해 전화를. 이 학습 가능한 매개 변수는 무작위로 설정되면 학습하면서 시간이 지남에 따라 업데이트됩니다. 학습 가능한 매개 변수가 첫 번째 state_dict입니다.

두 번째 state_dict는 최적화 상태 dict입니다. 학습자는 학습 가능한 매개 변수를 개선하는 데 사용됩니다. 그러나 옵티마이 저는 state_dict고정되어 있습니다. 거기에서 배울 것이 없습니다.

state_dict객체는 Python 사전 이므로 PyTorch 모델 및 최적화 프로그램에 많은 모듈성을 추가하여 쉽게 저장, 업데이트, 변경 및 복원 할 수 있습니다.

이것을 설명하기 위해 매우 간단한 모델을 만들어 봅시다 :

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

이 코드는 다음을 출력합니다 :

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

이것은 최소 모델입니다. 순차적 스택을 추가하려고 할 수 있습니다

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

학습 가능한 매개 변수 (콘볼 루션 레이어, 선형 레이어 등)가있는 레이어와 등록 된 버퍼 (배치 노어 레이어) 만 모델의 항목에 포함 state_dict됩니다.

학습 할 수없는 것은 옵티 마이저 객체에 속하며 state_dict옵티마이 저의 상태 및 사용 된 하이퍼 파라미터에 대한 정보가 들어 있습니다.

이야기의 나머지 부분은 동일합니다. 예측 단계 (추론 후 모델을 사용할 때의 단계)에서 예측; 우리는 우리가 배운 매개 변수를 기반으로 예측합니다. 따라서 추론을 위해 매개 변수를 저장하면됩니다 model.state_dict().

torch.save(model.state_dict(), filepath)

나중에 사용하는 model.load_state_dict (torch.load (filepath)) model.eval ()

참고 : model.eval()모델을로드 한 후 마지막 줄을 잊지 마십시오 .

또한 저장하지 마십시오 torch.save(model.parameters(), filepath). 는 model.parameters()단지 발전기 개체입니다.

다른 한편으로, torch.save(model, filepath)모델 객체 자체를 저장하지만 모델에는 옵티마이 저가 없습니다 state_dict. @Jadiel de Armas의 다른 훌륭한 답변을 확인하여 최적화 프로그램의 상태 표시를 저장하십시오.


간단한 해결책은 아니지만 문제의 본질을 깊이 분석합니다! 공감.
Jason Young

7

일반적인 PyTorch 규칙은 .pt 또는 .pth 파일 확장자를 사용하여 모델을 저장하는 것입니다.

전체 모델 저장 /로드 저장 :

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

하중:

모델 클래스는 어딘가에 정의되어야합니다

model = torch.load(PATH)
model.eval()

4

모델을 저장하고 나중에 교육을 다시 시작하려는 경우 :

단일 GPU : 저장 :

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

하중:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

다중 GPU : 저장

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

하중:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.