PyTorchのImageFolderで読み込み済みの画像をキャッシュする

PyTorchでCNNを組んで画像を識別する学習を回していて,ふとプロファイリング(http://shiba6v.hatenablog.com/entry/2018/05/15/215211)をとってみると,

         424113812 function calls (419713376 primitive calls) in 25020.907 seconds

   Ordered by: internal time
   List reduced from 466 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1 8992.740 8992.740 25020.906 25020.906 <ipython-input-2-0a878bac24c8>:186(<module>)
  2661000 6994.380    0.003 6994.380    0.003 {method 'decode' of 'ImagingDecoder' objects}
  7012500 3792.518    0.001 3792.518    0.001 {method 'nonzero' of 'torch._C._TensorBase' objects}
    68750 1436.746    0.021 1436.746    0.021 {method 'run_backward' of 'torch._C._EngineBase' objects}
  1099950 1280.304    0.001 1280.304    0.001 {method 'resize' of 'ImagingCore' objects}
   412500  767.758    0.002 4908.700    0.012 <ipython-input-2-0a878bac24c8>:137(unary)
  8774486  231.978    0.000  231.978    0.000 {method 'mean' of 'torch._C._TensorBase' objects}
    68750  163.959    0.002  190.920    0.003 <ipython-input-2-0a878bac24c8>:144(kde_batch)
  1168700  143.284    0.000  143.284    0.000 {method 'float' of 'torch._C._TensorBase' objects}
  1091900  108.316    0.000  108.316    0.000 {method 'copy' of 'ImagingCore' objects}

ん?? ImagingDecoderがめっちゃ時間を食っている・・・
画像の読み込みのデータセットtorchvision.datasets.folder.ImageFolderを使っていたんですが,どうやら毎回画像ファイルを読みに行っているっぽい・・・?
50epoch回してImagingDecoderが一番時間がかかっているなら,PyTorch側でキャッシュはされていなそうですね.

今回は学習データがメモリに乗り切りそうだったので,データセットを全部メモリにキャッシュしてあげて高速化します.
transformにランダムに切り出すような処理を今回は書いていないので,今回は__getitem__の出力をそのままdictionaryに入れます. ImageFolderと変わらない感じで使えるようにしました.(好きにパクって使ってください.)

class CachedImageFolder(torchvision.datasets.folder.ImageFolder):
    def __init__(self, root, transform=None, target_transform=None,loader=torchvision.datasets.folder.default_loader):
        super(CachedImageFolder, self).__init__(root,transform=transform,target_transform=target_transform,loader = loader)
        self.cache = {}
    def __getitem__(self,index):
        if index in self.cache:
            return self.cache[index]
        item = super(CachedImageFolder,self).__getitem__(index)
        self.cache[index] = item
        return item

CachedImageFolderを使って10epoch程度回した結果,

         22530602 function calls (21791255 primitive calls) in 3017.442 seconds

   Ordered by: internal time
   List reduced from 796 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1 1508.861 1508.861 3008.220 3008.220 <ipython-input-10-ed3ddfcf5be8>:186(<module>)
  1177999  637.229    0.001  637.229    0.001 {method 'nonzero' of 'torch._C._TensorBase' objects}
    11549  346.354    0.030  346.354    0.030 {method 'run_backward' of 'torch._C._EngineBase' objects}
    69295  149.182    0.002  853.208    0.012 <ipython-input-10-ed3ddfcf5be8>:137(unary)
    53220  140.252    0.003  140.252    0.003 {method 'decode' of 'ImagingDecoder' objects}
  1475069   42.256    0.000   42.256    0.000 {method 'mean' of 'torch._C._TensorBase' objects}
    11549   32.591    0.003   37.861    0.003 <ipython-input-10-ed3ddfcf5be8>:144(kde_batch)
    21999   25.677    0.001   25.677    0.001 {method 'resize' of 'ImagingCore' objects}
    11550   20.518    0.002   20.518    0.002 {built-in method stack}
    23100   14.861    0.001   14.861    0.001 {method 'to' of 'torch._C._TensorBase' objects}

おっ,ImagingDecoderの順位が下がっていますね. 元が50epochで25000sec,キャッシュを利用したら10epochで3000secなので,(ちゃんとした比較ではないですが)速くなっていそうですね

これ,もしメモリに載らなくてもSwap領域を増やしたらいけそうな気がしますが,それは必要になったらやりたいと思います. (大きな画像をリサイズするとかなり容量が落ちるのでなんだかんだで載ると思いますが・・・w)

他に便利なモジュールなどがあれば教えてほしいです ><