adjust context length (#795)

* adjust context length
* Update sdxl_styles_fooocus.json
parent 823fa924d3
commit a16b451fd7
@@ -1 +1 @@
-version = '2.1.747'
+version = '2.1.748'
@@ -1,4 +1,5 @@
 import torch
+import math
 import fcbh.model_management as model_management
 
 from transformers.generation.logits_process import LogitsProcessorList
@@ -91,7 +92,8 @@ class FooocusExpansion:
         tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
 
         current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
-        max_token_length = 77 + 77 * int(float(current_token_length) / 77.0)
+        max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
+        max_new_tokens = max_token_length - current_token_length
 
         logits_processor = LogitsProcessorList([self.logits_processor])
 
@@ -99,7 +101,7 @@ class FooocusExpansion:
         # https://huggingface.co/docs/transformers/generation_strategies
         features = self.model.generate(**tokenized_kwargs,
                                        num_beams=1,
-                                       max_new_tokens=max_token_length - current_token_length,
+                                       max_new_tokens=max_new_tokens,
                                        do_sample=True,
                                        logits_processor=logits_processor)
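For reference, the effect of the new rounding rule: the old code set the expansion budget to 77 + 77 * floor(n / 77) tokens, which always reserves at least one extra full block beyond the current length n, while the new code rounds n up to the nearest multiple of 75 (presumably 75 because CLIP's 77-token window includes two special tokens). A minimal standalone sketch of both formulas follows; the helper names old_budget and new_budget are illustrative and not part of the diff.

import math

def old_budget(current_token_length: int) -> int:
    # Pre-change rule: 77 + 77 * floor(n / 77); always adds a full extra block.
    return 77 + 77 * int(float(current_token_length) / 77.0)

def new_budget(current_token_length: int) -> int:
    # Post-change rule: round n up to the nearest multiple of 75.
    # (Assumption: 75 = CLIP's 77-token context minus the two special tokens.)
    return 75 * int(math.ceil(float(current_token_length) / 75.0))

for n in (10, 75, 76, 150):
    # max_new_tokens under each rule, i.e. the budget minus the current length
    print(n, old_budget(n) - n, new_budget(n) - n)

One consequence worth noting: when the prompt already sits exactly on a 75-token boundary (n = 75, 150, ...), the new rule yields max_new_tokens = 0 and generate() adds nothing, whereas the old rule always left room for at least a few new tokens.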