try fix mps (#381)
commit d1b4389098
parent cf7cde08b1
fooocus_version.py
@@ -1 +1 @@
-version = '2.0.16'
+version = '2.0.17'
modules/expansion.py
@@ -1,3 +1,5 @@
+import torch
+
 import comfy.model_management as model_management
 
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
@@ -29,15 +31,20 @@ class FooocusExpansion:
     def __init__(self):
         self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
         self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
-        if model_management.should_use_fp16():
-            self.model.half()
         self.model.eval()
 
         load_device = model_management.text_encoder_device()
+
+        if 'mps' in load_device.type:
+            load_device = torch.device('cpu')
+
+        if 'cpu' not in load_device.type and model_management.should_use_fp16():
+            self.model.half()
+
         offload_device = model_management.text_encoder_offload_device()
         self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
 
-        print(f'Fooocus Expansion engine loaded.')
+        print(f'Fooocus Expansion engine loaded for {load_device}.')
 
     def __call__(self, prompt, seed):
         model_management.load_model_gpu(self.patcher)
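For context, the change routes the prompt-expansion model away from Apple's MPS backend and re-gates the fp16 conversion on the final device. Below is a minimal sketch of that device-selection pattern, assuming only the comfy.model_management calls already used in the diff; pick_expansion_device is a hypothetical helper name, not part of the commit:

import torch
import comfy.model_management as model_management

def pick_expansion_device():
    # Start from the device comfy would normally pick for text encoders.
    load_device = model_management.text_encoder_device()

    # The attempted MPS fix: fall back to CPU when the chosen device is MPS
    # (assumption: the expansion model misbehaves on the MPS backend).
    if 'mps' in load_device.type:
        load_device = torch.device('cpu')

    # Only use half precision when not on CPU; fp16 on CPU is slow or
    # unsupported for many ops, so the old unconditional half() was unsafe
    # once the model could land on CPU.
    use_fp16 = 'cpu' not in load_device.type and model_management.should_use_fp16()
    return load_device, use_fp16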