개발/파이썬
[Python] 파이토치(PyTorch) 학습한 모델 저장 & 불러오기
크리쓰마스
2020. 9. 8. 17:35
핸드폰 카메라로 촬영한 데이터셋을 사용하다보니 (약 30GB)
데이터를 한번 학습시키는데도 한시간은 걸리는 것 같다.
그래서 loss가 갱신될 때 마다 해당 모델을 저장하도록 만들었다.
그 중 학습한 모델을 저장하는 부분이다.
소스코드
1
|
torch.save(model.state_dict(), 'model_transfer.pth')
|
cs |
딱 한줄이면 된다. EZ
이렇게하면 model_trasnfer.pth라는 파일에 학습한 모델이 저장된다.
확장자명은 pt나 pth를 사용한다.
1
|
model_transfer.load_state_dict((torch.load('model_transfer.pth')))
|
cs |
불러오는 것도 간단하다.
load_state_dict함수를 사용하면 된다.
하지만, 모델의 구조는 똑같이 만들어줘야 한다.
여기서 추가 학습을 하고 싶으면 model_trasnfer.train()을 사용하고,
학습한 모델로 예측을 하고싶으면 model_trasnfer.eval()을 사용하면 된다.