import gc

import torch


def empty_memory() -> None:
    """Run Python garbage collection, then release PyTorch's cached CUDA memory.

    Call this after dropping references to large tensors/models when you want
    the memory to be returned rather than kept by the caching allocator.
    """
    # Collect first: tensors that are only reachable through reference cycles
    # are freed by the collector, and only then can their CUDA blocks be
    # released by empty_cache().
    gc.collect()
    # Releases unoccupied blocks held by PyTorch's CUDA caching allocator;
    # memory still owned by live tensors is not affected.
    torch.cuda.empty_cache()


def tune_for_low_memory(dtype: torch.dtype = torch.float16) -> torch.dtype:
    """Set PyTorch's default floating-point dtype to a smaller type.

    Args:
        dtype: Floating-point dtype to use as the new process-wide default.
            Defaults to ``torch.float16``, matching the original behaviour.

    Returns:
        The default dtype that was in effect before this call. Pass it back to
        this function, or to ``torch.set_default_dtype``, to restore the
        previous setting.

    Raises:
        TypeError: Propagated from ``torch.set_default_dtype`` if ``dtype`` is
            not a supported floating-point dtype.

    Notes:
        This is a global setting that only applies to floating-point tensors
        created afterwards without an explicit dtype. Existing tensors and
        model parameters are not converted.
        NOTE(review): float16 has limited kernel coverage on CPU. Verify that
        your workload supports it before relying on this outside CUDA.
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    return previous