feat: add hash_cache

This commit is contained in:
Manuel Schmid 2024-07-01 17:20:20 +02:00
parent 9178aa8ebb
commit 33b1c5cb87
No known key found for this signature in database
GPG Key ID: 32C4F7569B40B84B
4 changed files with 52 additions and 14 deletions

1
.gitignore vendored
View File

@ -10,6 +10,7 @@ __pycache__
*.partial
*.onnx
sorted_styles.json
hash_cache.json
/input
/cache
/language/default.json

View File

@ -7,6 +7,7 @@ import args_manager
import tempfile
import modules.flags
import modules.sdxl_styles
from modules.hash_cache import load_cache_from_file, save_cache_to_file
from modules.model_loader import load_file_from_url
from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_eval_env_var
@ -755,3 +756,6 @@ def downloading_safety_checker_model():
update_files()
load_cache_from_file()
# write cache to file again for cleanup of invalid cache entries
save_cache_to_file()

41
modules/hash_cache.py Normal file
View File

@ -0,0 +1,41 @@
from modules.util import sha256, HASH_SHA256_LENGTH
import os
import json
hash_cache_filename = 'hash_cache.json'
hash_cache = {}
def sha256_from_cache(filepath):
global hash_cache
if filepath not in hash_cache:
hash_cache[filepath] = sha256(filepath)
save_cache_to_file()
return hash_cache[filepath]
def load_cache_from_file():
global hash_cache
try:
if os.path.exists(hash_cache_filename):
with open(hash_cache_filename, 'rt', encoding='utf-8') as fp:
for filepath, hash in json.load(fp).items():
if not os.path.exists(filepath) or not isinstance(hash, str) and len(hash) != HASH_SHA256_LENGTH:
print(f'[Cache] Skipping invalid cache entry: {filepath}')
continue
hash_cache[filepath] = hash
print(f'[Cache] Warmed cache from file')
except Exception as e:
print(f'[Cache] Warming failed: {e}')
def save_cache_to_file():
global hash_cache
try:
with open(hash_cache_filename, 'wt', encoding='utf-8') as fp:
json.dump(hash_cache, fp, indent=4)
print(f'[Cache] Updated cache file')
except Exception as e:
print(f'[Cache] Saving failed: {e}')

View File

@ -9,16 +9,16 @@ from PIL import Image
import fooocus_version
import modules.config
import modules.sdxl_styles
from modules import hash_cache
from modules.flags import MetadataScheme, Performance, Steps
from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, sha256
from modules.hash_cache import sha256_from_cache
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
hash_cache = {}
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
loaded_parameter_dict = raw_metadata
@ -215,14 +215,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, p
results.append(1)
def get_sha256(filepath):
global hash_cache
if filepath not in hash_cache:
hash_cache[filepath] = sha256(filepath)
return hash_cache[filepath]
def parse_meta_from_preset(preset_content):
assert isinstance(preset_content, dict)
preset_prepared = {}
@ -290,18 +282,18 @@ class MetadataParser(ABC):
self.base_model_name = Path(base_model_name).stem
base_model_path = get_file_from_folder_list(base_model_name, modules.config.paths_checkpoints)
self.base_model_hash = get_sha256(base_model_path)
self.base_model_hash = sha256_from_cache(base_model_path)
if refiner_model_name not in ['', 'None']:
self.refiner_model_name = Path(refiner_model_name).stem
refiner_model_path = get_file_from_folder_list(refiner_model_name, modules.config.paths_checkpoints)
self.refiner_model_hash = get_sha256(refiner_model_path)
self.refiner_model_hash = sha256_from_cache(refiner_model_path)
self.loras = []
for (lora_name, lora_weight) in loras:
if lora_name != 'None':
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
lora_hash = get_sha256(lora_path)
lora_hash = sha256_from_cache(lora_path)
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
self.vae_name = Path(vae_name).stem