From 1964aec7f8ddd87d55b722bf0d7a121d1bcc0f9e Mon Sep 17 00:00:00 2001
From: lvmin
Date: Wed, 13 Sep 2023 18:33:14 -0700
Subject: [PATCH] use SOTA sampling for GPT2

---
 modules/expansion.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/modules/expansion.py b/modules/expansion.py
index 2efe8816..3dc1b8b8 100644
--- a/modules/expansion.py
+++ b/modules/expansion.py
@@ -50,9 +50,11 @@ class FooocusExpansion:
         tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
         tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
 
-        features = self.model.generate(**tokenized_kwargs, num_beams=5, do_sample=True, max_new_tokens=256)
-        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
+        # https://huggingface.co/blog/introducing-csearch
+        # https://huggingface.co/docs/transformers/generation_strategies
+        features = self.model.generate(**tokenized_kwargs, penalty_alpha=0.8, top_k=8, max_new_tokens=256)
+        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
 
         result = response[0][len(origin):]
         result = safe_str(result)
         result = remove_pattern(result, dangrous_patterns)
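
Note: passing penalty_alpha > 0 together with top_k > 1 to generate() is what
selects contrastive search in Transformers, replacing the previous beam-search
plus sampling call. Below is a minimal standalone sketch of the same generation
call; the stock "gpt2" checkpoint and the prompt string are illustrative
assumptions, not the Fooocus expansion weights used in expansion.py:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("a beautiful landscape of", return_tensors="pt")

    # penalty_alpha > 0 with top_k > 1 enables contrastive search: the
    # top-k candidates are re-ranked by model confidence minus a
    # degeneration penalty (max similarity to previous hidden states),
    # which suppresses the repetition loops that sampling or beam
    # search can fall into on GPT-2.
    features = model.generate(
        **inputs,
        penalty_alpha=0.8,
        top_k=8,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
    )
    print(tokenizer.batch_decode(features, skip_special_tokens=True)[0])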