REFERENCE

1. prepare_data

  • 데이터 다운로드 등의 기능을 수행할 메서드.
  • setup 전에 호출 됨.
  • main process + single process니까 device 별 실행 되야하는 기능은 해당 메서드에서 구현을 피해야 함.
  • DO NOT set state to the model (use setup instead) since this is NOT called on every device
    (e.g. self.split='train')

2. setup


 

아래처럼 깔끔하게 사용할 수도 있지만, Dataset등의 내부 멤버 변수가 필요하면 수동으로 콜해도 문제 없다.

dm = MNISTDataModule()
model = Model()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")

model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

dm.setup(stage="test")
trainer.test(datamodule=dm)

 

Save

torch.save(model.state_dict(), PATH)

보통 torch.save를 사용해서 pytorch model을 저장하는데 이때 보통 .pt, .pth의 확장자를 쓴다. 그러나 .pth의 경우 python path와 충돌 위험이 있기때문에 .pt 확장자를 사용하는 것을 추천한다. 

Save for resume

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

일반적으로 모델을 저장할 때 학습 재개(resuming)를 위해 모델 파라미터 뿐만 아니라 optimizer, loss, epoch, scheduler등 다양한 정보를 함께 저장한다. 그래서 이러한 checkpoint는 모델만 저장할 때에 비해서 용량이 훨씬 커진다. 이럴때는 .tar 확장자를 사용한다. 나는 주로 .pth.tar를 사용한다.

To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension.

reference

'Deep Learning > pytorch' 카테고리의 다른 글

torch.backends.cudnn.benchmark = True 의미  (0) 2021.08.23
[Pytorch] yolo to pytorch(0)  (0) 2021.05.17
# TENSORRT
set(TRT_VERSION 7.2.3.4)
set(TRT_PATH "path to tensorrt"/TensorRT-${TRT_VERSION})

# TensorRT
MESSAGE("\nTensorRT " ${TRT_VERSION})
MESSAGE(STATUS "${TRT_PATH}\n")

set(TRT_INCLUDE_PATH ${TRT_PATH}/include)
set(TRT_LIB_PATH ${TRT_PATH}/lib)
set(TRT_LIBS nvinfer nvonnxparser)
set(TRT_DLLS ${TRT_LIB_PATH}/nvinfer.dll
			 ${TRT_LIB_PATH}/nvonnxparser.dll
			 ${TRT_LIB_PATH}/nvinfer_plugin.dll
			 ${TRT_LIB_PATH}/myelin64_1.dll
			 )
include_directories(${TRT_INCLUDE_PATH})
link_directories(${TRT_LIB_PATH})
link_libraries(${TRT_LIBS})

# do something 
# ...
# ...

# when build excutable
add_custom_command(TARGET ${project}
				   POST_BUILD
				   COMMAND ${CMAKE_COMMAND} -E copy 
				   ${TRT_DLLS}
				   ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${CMAKE_BUILD_TYPE}
)

+ Recent posts