minor revision
Now the GPT will try to (1) use more aligned formatting of commas, (2) a bit more dynamic in word choice, (3) avoid duplication like “detail, detail, detail, detail”
This commit is contained in:
parent
5dc1221c65
commit
01bfa11fd1
|
|
@ -1 +1 @@
|
|||
version = '2.1.771'
|
||||
version = '2.1.772'
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import torch
|
|||
import math
|
||||
import fcbh.model_management as model_management
|
||||
|
||||
from transformers.generation.logits_process import LogitsProcessorList
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
||||
from modules.path import fooocus_expansion_path
|
||||
from fcbh.model_patcher import ModelPatcher
|
||||
|
|
@ -10,6 +11,7 @@ from fcbh.model_patcher import ModelPatcher
|
|||
|
||||
# limitation of np.random.seed(), called from transformers.set_seed()
|
||||
SEED_LIMIT_NUMPY = 2**32
|
||||
neg_inf = - 8192.0
|
||||
|
||||
|
||||
def safe_str(x):
|
||||
|
|
@ -31,20 +33,17 @@ class FooocusExpansion:
|
|||
|
||||
positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'),
|
||||
encoding='utf-8').read().splitlines()
|
||||
positive_words = [x.lower() for x in positive_words if x != '']
|
||||
positive_words = ['Ġ' + x for x in positive_words if x != '']
|
||||
|
||||
# new_content = '\n'.join(sorted(list(set(positive_words))))
|
||||
# eos = self.tokenizer.eos_token_id
|
||||
self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf
|
||||
|
||||
symbols = '-+,.;?!!!'
|
||||
|
||||
self.bad_words_ids = []
|
||||
for k, v in self.tokenizer.vocab.items():
|
||||
if k.replace('Ġ', '').lower() not in positive_words and k not in symbols:
|
||||
self.bad_words_ids.append([v])
|
||||
else:
|
||||
# print(k)
|
||||
pass
|
||||
if k in positive_words:
|
||||
self.logits_bias[0, v] = 0
|
||||
|
||||
# t11 = self.tokenizer(',', return_tensors="np")
|
||||
# t198 = self.tokenizer('\n', return_tensors="np")
|
||||
# eos = self.tokenizer.eos_token_id
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
|
||||
self.model.eval()
|
||||
|
|
@ -65,6 +64,17 @@ class FooocusExpansion:
|
|||
self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
|
||||
print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def logits_processor(self, input_ids, scores):
|
||||
assert scores.ndim == 2 and scores.shape[0] == 1
|
||||
bias = self.logits_bias.to(scores).clone()
|
||||
bias[0, input_ids[0].to(bias.device).long()] = neg_inf
|
||||
bias[0, 11] = 0
|
||||
return scores + bias
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def __call__(self, prompt, seed):
|
||||
if prompt == '':
|
||||
return ''
|
||||
|
|
@ -91,7 +101,7 @@ class FooocusExpansion:
|
|||
top_k=100,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
bad_words_ids=self.bad_words_ids)
|
||||
logits_processor=LogitsProcessorList([self.logits_processor]))
|
||||
|
||||
response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
|
||||
result = safe_str(response[0])
|
||||
|
|
|
|||
Loading…
Reference in New Issue