반응형
안녕하세요
pretrained된 모델을 로드하고 학습 및 추론을 할 때 다음과 같은 에러를 보신적이 한두번은 있을겁니다.
size mismatch가 아닌 없는 layer또는 있는 layer를 무시할 때는
state_dict = torch.load(cached_file)
mdl.load_state_dict(state_dict,strict=False)
위에 코드처럼 load_state_dict()에 strict=False만 기입해주면 해결이 됩니다.
하지만 맨 위에 에러처럼 같은 layer에서의 size mismatch는 strict=False로 해결이 안됩니다...
외국 포럼 사이트에서 해결방법을 찾게되어 소개를 해볼려고 합니다.
우선 코드부터 보드리겠습니다.
def on_load_checkpoint(self, checkpoint: dict) -> None:
state_dict = checkpoint
model_state_dict = self.state_dict()
is_changed = False
for k in state_dict:
if k in model_state_dict:
if state_dict[k].shape != model_state_dict[k].shape:
logging.info(f"Skip loading parameter: {k}, "
f"required shape: {model_state_dict[k].shape}, "
f"loaded shape: {state_dict[k].shape}")
state_dict[k] = model_state_dict[k]
is_changed = True
else:
logging.info(f"Dropping parameter {k}")
is_changed = True
if is_changed:
checkpoint.pop("optimizer_states", None)
해당 메소드를 python이 설치되어있는 경로에서 python3.6/site-packages/torch/nn/modules/module.py에서 Modul클래스 밑에 메소드로 추가해주면 됩니다.
그 후에는 load_state_dict함수가 아닌 on_load_checkpoint를 사용하시면 됩니다~
감사합니다~
반응형
'AI' 카테고리의 다른 글
XML to TXT annotation file format 변환 (0) | 2022.07.26 |
---|---|
Gradient Descent? (0) | 2022.07.24 |
[Pytorch] LENET5 모델 학습 및 추론 코드(마스크 구별 프로그램) (0) | 2021.11.11 |
tensorflow zoo && tensorflow detection api 사용하여 tflite파일 만드는 과정 (0) | 2021.11.10 |
RuntimeError: stack expects each tensor to be equal size, but got ~ (0) | 2021.11.09 |