minor revision

Now the GPT will try to (1) use more aligned formatting of commas, (2) a bit more dynamic in word choice, (3) avoid duplication like “detail, detail, detail, detail”
This commit is contained in:
lllyasviel 2023-11-01 12:50:12 -07:00
parent 5dc1221c65
commit 01bfa11fd1
2 changed files with 23 additions and 13 deletions

View File

@ -1 +1 @@
version = '2.1.771'
version = '2.1.772'

View File

@ -3,6 +3,7 @@ import torch
import math
import fcbh.model_management as model_management
from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from modules.path import fooocus_expansion_path
from fcbh.model_patcher import ModelPatcher
@ -10,6 +11,7 @@ from fcbh.model_patcher import ModelPatcher
# limitation of np.random.seed(), called from transformers.set_seed()
SEED_LIMIT_NUMPY = 2**32
neg_inf = - 8192.0
def safe_str(x):
@ -31,20 +33,17 @@ class FooocusExpansion:
positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'),
encoding='utf-8').read().splitlines()
positive_words = [x.lower() for x in positive_words if x != '']
positive_words = ['Ġ' + x for x in positive_words if x != '']
# new_content = '\n'.join(sorted(list(set(positive_words))))
# eos = self.tokenizer.eos_token_id
self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf
symbols = '-+,.;?!!!'
self.bad_words_ids = []
for k, v in self.tokenizer.vocab.items():
if k.replace('Ġ', '').lower() not in positive_words and k not in symbols:
self.bad_words_ids.append([v])
else:
# print(k)
pass
if k in positive_words:
self.logits_bias[0, v] = 0
# t11 = self.tokenizer(',', return_tensors="np")
# t198 = self.tokenizer('\n', return_tensors="np")
# eos = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
self.model.eval()
@ -65,6 +64,17 @@ class FooocusExpansion:
self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')
@torch.no_grad()
@torch.inference_mode()
def logits_processor(self, input_ids, scores):
assert scores.ndim == 2 and scores.shape[0] == 1
bias = self.logits_bias.to(scores).clone()
bias[0, input_ids[0].to(bias.device).long()] = neg_inf
bias[0, 11] = 0
return scores + bias
@torch.no_grad()
@torch.inference_mode()
def __call__(self, prompt, seed):
if prompt == '':
return ''
@ -91,7 +101,7 @@ class FooocusExpansion:
top_k=100,
max_new_tokens=max_new_tokens,
do_sample=True,
bad_words_ids=self.bad_words_ids)
logits_processor=LogitsProcessorList([self.logits_processor]))
response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
result = safe_str(response[0])