# stable-diffusion-for-fun/utilities/memory.py
# Memory-management helpers: garbage collection, CUDA cache release, and
# PyTorch default-dtype tuning.

import gc
import torch
def empty_memory():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory.

    A full ``gc.collect()`` pass runs first so that objects kept alive only by
    reference cycles (which may include tensors) are destroyed. After that,
    ``torch.cuda.empty_cache()`` asks PyTorch's caching allocator to hand back
    the blocks it is holding but no longer using.
    """
    # Order matters: cached blocks can only be returned once the tensors that
    # occupied them have actually been freed, which may require the GC pass.
    gc.collect()
    torch.cuda.empty_cache()
def tune_for_low_memory(dtype: torch.dtype = torch.float16) -> torch.dtype:
    """Set PyTorch's process-wide default floating-point dtype to save memory.

    Only floating-point tensors created *after* this call without an explicit
    ``dtype`` (e.g. ``torch.empty(n)`` or ``torch.tensor([1.0])``) use the new
    default. Tensors and modules that already exist are not converted.

    The setting is global to the process, so the previous default is returned.
    Callers can undo the change with ``torch.set_default_dtype(previous)``.

    NOTE(review): many CPU kernels have historically lacked float16
    implementations. Confirm this helper is only used where half precision is
    supported (e.g. GPU inference).

    Args:
        dtype: Floating-point dtype to install as the default. The default is
            ``torch.float16``, which matches this helper's original behaviour.

    Returns:
        The default dtype that was in effect before this call.

    Raises:
        TypeError: propagated from ``torch.set_default_dtype`` when ``dtype``
            is not a floating-point dtype (behaviour of current PyTorch
            releases).
    """
    # Capture the old default first so the global change can be reverted.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    return previous