From 959f965b77ff38d33390e8ccbaa691d9fb6c3736 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Wed, 25 Oct 2023 17:03:15 -0700
Subject: [PATCH] improve gpt2

improve gpt2
---
 fooocus_version.py   |  2 +-
 modules/expansion.py | 33 ++++++++++++++++++++++++++++++---
 update_log.md        |  4 ++++
 3 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/fooocus_version.py b/fooocus_version.py
index 831e5c92..be90b2f6 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.742'
+version = '2.1.743'
diff --git a/modules/expansion.py b/modules/expansion.py
index f65089fa..23660544 100644
--- a/modules/expansion.py
+++ b/modules/expansion.py
@@ -1,7 +1,7 @@
 import torch
-
 import fcbh.model_management as model_management
+from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 from modules.path import fooocus_expansion_path
 from fcbh.model_patcher import ModelPatcher
 
@@ -16,6 +16,14 @@ fooocus_magic_split = [
 ]
 dangrous_patterns = '[]【】()()|::'
 
+black_list = ['art', 'digital', 'Ġpaint', 'painting', 'drawing', 'draw', 'drawn',
+              'concept', 'illustration', 'illustrated', 'illustrate',
+              'face', 'eye', 'eyes', 'hand', 'hands',
+              'monster', 'artistic', 'oil', 'brush',
+              'artwork', 'artworks']
+
+black_list += ['Ġ' + k for k in black_list]
+
 
 def safe_str(x):
     x = str(x)
@@ -33,6 +41,15 @@ def remove_pattern(x, pattern):
 class FooocusExpansion:
     def __init__(self):
         self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
+        self.vocab = self.tokenizer.vocab
+        self.logits_bias = torch.zeros((1, len(self.vocab)), dtype=torch.float32)
+        self.logits_bias[0, self.tokenizer.eos_token_id] = - 16.0
+        # test_198 = self.tokenizer('\n', return_tensors="pt")
+        self.logits_bias[0, 198] = - 1024.0
+        for k, v in self.vocab.items():
+            if k in black_list:
+                self.logits_bias[0, v] = - 1024.0
+
         self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
         self.model.eval()
 
@@ -52,6 +69,10 @@ class FooocusExpansion:
         self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
         print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')
 
+    def logits_processor(self, input_ids, scores):
+        self.logits_bias = self.logits_bias.to(scores)
+        return scores + self.logits_bias
+
     def __call__(self, prompt, seed):
         if self.patcher.current_device != self.patcher.load_device:
             print('Fooocus Expansion loaded by itself.')
@@ -66,12 +87,18 @@ class FooocusExpansion:
         tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
         tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
 
+        current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
+        max_token_length = 77 + 77 * int(float(current_token_length) / 77.0)
+
+        logits_processor = LogitsProcessorList([self.logits_processor])
+
         # https://huggingface.co/blog/introducing-csearch
         # https://huggingface.co/docs/transformers/generation_strategies
         features = self.model.generate(**tokenized_kwargs,
                                        num_beams=1,
-                                       max_new_tokens=256,
-                                       do_sample=True)
+                                       max_new_tokens=max_token_length - current_token_length,
+                                       do_sample=True,
+                                       logits_processor=logits_processor)
 
         response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
         result = response[0][len(origin):]
diff --git a/update_log.md b/update_log.md
index df7e83ad..61e13ebd 100644
--- a/update_log.md
+++ b/update_log.md
@@ -1,3 +1,7 @@
+# 2.1.743
+
+* Improved GPT2 by removing some tokens that may corrupt styles.
+
 # 2.1.741
 
 Style Updates:
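Note (appended after the diff; not part of the applied patch): the heart of the
expansion.py change is a constant per-vocabulary logits bias that generate()
adds at every decoding step through a LogitsProcessorList. The sketch below
shows the same mechanism in isolation. The stock 'gpt2' checkpoint stands in
for the Fooocus expansion weights, the banned-word list is truncated to one
word, and the prompt is made up; the bias values mirror the patch, and a plain
two-argument callable is used just as the patch uses the bound method
self.logits_processor.

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from transformers.generation.logits_process import LogitsProcessorList

    tokenizer = AutoTokenizer.from_pretrained('gpt2')   # stand-in checkpoint
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    model.eval()

    # One bias value per vocabulary id, added to the raw logits at each step.
    bias = torch.zeros((1, len(tokenizer.vocab)), dtype=torch.float32)
    bias[0, tokenizer.eos_token_id] = -16.0  # discourage, but allow, stopping early
    bias[0, 198] = -1024.0                   # id 198 is '\n' in the GPT-2 vocab: banned
    for word in ('art', 'Ġart'):             # 'Ġ' marks a leading space in GPT-2 BPE
        token_id = tokenizer.vocab.get(word)
        if token_id is not None:
            bias[0, token_id] = -1024.0      # effectively remove the token

    def biased_logits(input_ids, scores):
        # LogitsProcessorList calls each entry as f(input_ids, scores) -> scores.
        return scores + bias.to(scores)

    inputs = tokenizer('a photo of a cat', return_tensors='pt')
    features = model.generate(**inputs,
                              num_beams=1,
                              do_sample=True,
                              max_new_tokens=32,
                              logits_processor=LogitsProcessorList([biased_logits]))
    print(tokenizer.batch_decode(features, skip_special_tokens=True)[0])

A bias of -1024 drives a token's post-softmax probability to effectively zero,
while -16 only makes EOS unlikely, so expansions tend to run long rather than
stopping after a few words.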
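Note: the max_new_tokens change replaces the fixed 256-token budget with
77 + 77 * int(float(current_token_length) / 77.0), i.e. the next multiple of 77
strictly above the current prompt length, presumably so the expanded prompt
fills out a whole 77-token CLIP text chunk instead of overshooting it. A quick
check of the arithmetic, with made-up prompt lengths:

    # Budget = next multiple of 77 above the current length, minus the prompt.
    for current in (30, 77, 100):
        target = 77 + 77 * int(float(current) / 77.0)
        print(f'{current} prompt tokens -> up to {target - current} new (total {target})')
    # 30 prompt tokens -> up to 47 new (total 77)
    # 77 prompt tokens -> up to 77 new (total 154)
    # 100 prompt tokens -> up to 54 new (total 154)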