From 01bfa11fd1730c63ee79473ba6a6033220117dd2 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Wed, 1 Nov 2023 12:50:12 -0700
Subject: [PATCH] minor revision
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Now the GPT will try to (1) use more aligned formatting of commas,
(2) be a bit more dynamic in word choice, and (3) avoid duplication
like “detail, detail, detail, detail”.
---
 fooocus_version.py   |  2 +-
 modules/expansion.py | 34 ++++++++++++++++++++++------------
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/fooocus_version.py b/fooocus_version.py
index cb54aeb5..200610cb 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.771'
+version = '2.1.772'
diff --git a/modules/expansion.py b/modules/expansion.py
index 2478410c..a536fad0 100644
--- a/modules/expansion.py
+++ b/modules/expansion.py
@@ -3,6 +3,7 @@ import torch
 import math
 import fcbh.model_management as model_management
 
+from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 from modules.path import fooocus_expansion_path
 from fcbh.model_patcher import ModelPatcher
@@ -10,6 +11,7 @@ from fcbh.model_patcher import ModelPatcher
 
 # limitation of np.random.seed(), called from transformers.set_seed()
 SEED_LIMIT_NUMPY = 2**32
+neg_inf = - 8192.0
 
 
 def safe_str(x):
@@ -31,20 +33,17 @@ class FooocusExpansion:
 
         positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'),
                               encoding='utf-8').read().splitlines()
-        positive_words = [x.lower() for x in positive_words if x != '']
+        positive_words = ['Ġ' + x for x in positive_words if x != '']
 
-        # new_content = '\n'.join(sorted(list(set(positive_words))))
-        # eos = self.tokenizer.eos_token_id
+        self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf
 
-        symbols = '-+,.;?!!!'
-
-        self.bad_words_ids = []
         for k, v in self.tokenizer.vocab.items():
-            if k.replace('Ġ', '').lower() not in positive_words and k not in symbols:
-                self.bad_words_ids.append([v])
-            else:
-                # print(k)
-                pass
+            if k in positive_words:
+                self.logits_bias[0, v] = 0
+
+        # t11 = self.tokenizer(',', return_tensors="np")
+        # t198 = self.tokenizer('\n', return_tensors="np")
+        # eos = self.tokenizer.eos_token_id
 
         self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
         self.model.eval()
@@ -65,6 +64,17 @@
         self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
         print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')
 
+    @torch.no_grad()
+    @torch.inference_mode()
+    def logits_processor(self, input_ids, scores):
+        assert scores.ndim == 2 and scores.shape[0] == 1
+        bias = self.logits_bias.to(scores).clone()
+        bias[0, input_ids[0].to(bias.device).long()] = neg_inf
+        bias[0, 11] = 0
+        return scores + bias
+
+    @torch.no_grad()
+    @torch.inference_mode()
     def __call__(self, prompt, seed):
         if prompt == '':
             return ''
@@ -91,7 +101,7 @@
                                        top_k=100,
                                        max_new_tokens=max_new_tokens,
                                        do_sample=True,
-                                       bad_words_ids=self.bad_words_ids)
+                                       logits_processor=LogitsProcessorList([self.logits_processor]))
 
         response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
         result = safe_str(response[0])
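
For reference, the mechanism this patch switches to works as follows: instead of handing generate() a bad_words_ids blacklist built from every vocabulary entry outside positive.txt, it precomputes a single bias row of -8192.0 over the whole vocabulary, zeroes the whitelisted entries, and adds that row to the scores on every decoding step through a custom logits processor. Tokens already present in the sequence are re-penalized so the same word is not sampled twice, while token 11 (the comma, per the t11 comment in the diff) is always allowed back in; together these produce the aligned commas and reduced duplication described in the commit message. The finite -8192.0 rather than -inf plausibly keeps the addition NaN-free under fp16 while still driving the sampling probability of blocked tokens to zero.

Below is a minimal self-contained sketch of the same technique against a stock GPT-2 checkpoint. The model name 'gpt2', the four-word whitelist, and the prompt are illustrative stand-ins, not part of the patch; Fooocus loads its own expansion checkpoint and reads the whitelist from positive.txt.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList

neg_inf = -8192.0


class WhitelistBias:
    # Hypothetical helper for illustration; the patch does the same work
    # inline in FooocusExpansion.__init__ and logits_processor.
    def __init__(self, tokenizer, words):
        # GPT-2 spells a token that follows a space with a leading 'Ġ',
        # so 'Ġ' + word matches the whole-word form of each entry.
        whitelist = ['Ġ' + w for w in words]
        self.bias = torch.zeros((1, len(tokenizer.get_vocab())), dtype=torch.float32) + neg_inf
        for token, token_id in tokenizer.get_vocab().items():
            if token in whitelist:
                self.bias[0, token_id] = 0
        self.comma_id = tokenizer(',')['input_ids'][0]  # id 11 in the GPT-2 vocab

    @torch.no_grad()
    def __call__(self, input_ids, scores):
        bias = self.bias.to(scores).clone()
        # Re-block every token already in the sequence so no word is
        # repeated, but always keep the comma separator available.
        bias[0, input_ids[0].to(bias.device).long()] = neg_inf
        bias[0, self.comma_id] = 0
        return scores + bias


tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
model.eval()

processor = WhitelistBias(tokenizer, ['detailed', 'sharp', 'colorful', 'cinematic'])
tokens = tokenizer('a photo of', return_tensors='pt')
features = model.generate(**tokens,
                          max_new_tokens=16,
                          do_sample=True,
                          pad_token_id=tokenizer.eos_token_id,
                          logits_processor=LogitsProcessorList([processor]))
print(tokenizer.batch_decode(features, skip_special_tokens=True)[0])

Because the bias also buries the end-of-sequence token, sampling runs for the full max_new_tokens budget, which matches how the expansion engine pads prompts to a fixed target length.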