핸드폰 카메라로 촬영한 데이터셋을 사용하다보니 (약 30GB)
데이터를 한번 학습시키는데도 한시간은 걸리는 것 같다.
그래서 loss가 갱신될 때 마다 해당 모델을 저장하도록 만들었다.
그 중 학습한 모델을 저장하는 부분이다.
소스코드
1
|
torch.save(model.state_dict(), 'model_transfer.pth')
|
cs |
딱 한줄이면 된다. EZ
이렇게하면 model_trasnfer.pth라는 파일에 학습한 모델이 저장된다.
확장자명은 pt나 pth를 사용한다.
1
|
model_transfer.load_state_dict((torch.load('model_transfer.pth')))
|
cs |
불러오는 것도 간단하다.
load_state_dict함수를 사용하면 된다.
하지만, 모델의 구조는 똑같이 만들어줘야 한다.
여기서 추가 학습을 하고 싶으면 model_trasnfer.train()을 사용하고,
학습한 모델로 예측을 하고싶으면 model_trasnfer.eval()을 사용하면 된다.
'개발 > 파이썬' 카테고리의 다른 글
[Python] YOLOv5 Custom dataset 으로 학습하기 (26) | 2020.09.17 |
---|---|
[Python] OpenCV로 이미지 배경 제거하기 (0) | 2020.09.16 |
[Python] MNIST 예제를 통한 Keras, PyTorch 비교 (2) | 2020.08.25 |
[Python] 케라스(Keras)를 사용한 MNIST 문자인식 구현 예제 (0) | 2020.08.25 |
[Python] 파이토치(PyTorch)를 사용한 MNIST 문자인식 구현 예제 (0) | 2020.08.25 |