PyTorchのDataParallelのモデルを保存する
PyTorchで複数GPUで学習させる場合,
model = nn.DataParallel(model, device_ids=[0,1,2])
のようにDataParallelで保存しますが,このモデルを保存したい場合にcuda runtime error : out of memoryが出ることがあります.
その場合は,下のようにDataParallelから元のモデルを取り出してCPUのモデルに変えてあげることで保存できるようになります.
torch.save(model.module.cpu(),file_path)
読み込み時はこうすればOK
new_model = torch.load(file_path)
参考
Optional: Data Parallelism — PyTorch Tutorials 1.0.0.dev20181002 documentation