핸드폰 카메라로 촬영한 데이터셋을 사용하다보니 (약 30GB)

데이터를 한번 학습시키는데도 한시간은 걸리는 것 같다.

그래서 loss가 갱신될 때 마다 해당 모델을 저장하도록 만들었다.

그 중 학습한 모델을 저장하는 부분이다.


소스코드

1
torch.save(model.state_dict(), 'model_transfer.pth')
cs

 

딱 한줄이면 된다. EZ

이렇게하면 model_trasnfer.pth라는 파일에 학습한 모델이 저장된다.

확장자명은 pt나 pth를 사용한다.

 

1
model_transfer.load_state_dict((torch.load('model_transfer.pth')))
cs

 

불러오는 것도 간단하다.

load_state_dict함수를 사용하면 된다.

하지만, 모델의 구조는 똑같이 만들어줘야 한다.

 

여기서 추가 학습을 하고 싶으면 model_trasnfer.train()을 사용하고,

학습한 모델로 예측을 하고싶으면 model_trasnfer.eval()을 사용하면 된다. 

+ Recent posts