Batch inference fixes: set pos during prefill, fix assert

PiperOrigin-RevId: 772458760
This commit is contained in:
Jan Wassenberg 2025-06-17 07:09:00 -07:00 committed by Copybara-Service
parent d342e4e7d4
commit f2adbfbcab
9 changed files with 35 additions and 21 deletions

View File

@ -49,7 +49,8 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
: env_(MakeMatMulEnv(threading, inference)),
gemma_(loader, inference, env_) {
const ModelConfig& config = gemma_.GetModelConfig();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference));

View File

@ -51,7 +51,7 @@ int main(int argc, char** argv) {
}
// Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
gcpp::Gemma gemma(loader, inference, env);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
size_t generated = 0;

View File

@ -35,7 +35,7 @@ class SimplifiedGemma {
SimplifiedGemma(const gcpp::LoaderArgs& loader,
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: env_(MakeMatMulEnv(threading)),
: env_(MakeMatMulEnv(threading, inference)),
gemma_(loader, inference, env_),
kv_cache_(gemma_.GetModelConfig(), inference) {
// Initialize random number generator

View File

@ -101,7 +101,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
int max_generated_tokens)
: inference_args(inference_args),
threading_args(threading_args),
matmul_env(MakeMatMulEnv(threading_args)),
matmul_env(MakeMatMulEnv(threading_args, inference_args)),
active_conversation_name("default"),
model(loader, inference_args, matmul_env) {
std::stringstream ss;

View File

@ -28,7 +28,6 @@
namespace gcpp {
static constexpr size_t kVocabSize = 256000;
static constexpr size_t kMaxSeqLen = 4096;
static ModelConfig ConfigNoSSM() {
ModelConfig config;

View File

@ -344,8 +344,9 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
token = qbatch.Prompt(qi)[pos_in_prompt];
// Ignore StreamToken return value because requesting to stop does not
// make sense during prefill.
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
token, 0.0f);
qbatch.MutablePos(qi) = pos_in_prompt;
}
qbatch.PrevToken(qi) = token;
@ -356,6 +357,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
// probabilities, which are not required for the prompt tokens.
Transformer(config, runtime_config, weights, activations, qbatch, env);
}
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1;
}
}
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
@ -457,7 +462,7 @@ static void GenerateT(const ModelConfig& config,
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t prefill_tokens = 0; // only for timing.
size_t total_prefill_tokens = 0; // only for throughput stats.
const size_t seq_len = qbatch.KV(0).SeqLen();
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const PromptTokens& prompt = qbatch.Prompt(qi);
@ -465,7 +470,7 @@ static void GenerateT(const ModelConfig& config,
// Prefill stops before size - 1 because the last prompt token is the
// first input token for generation.
prefill_tokens += prompt.size() - 1;
total_prefill_tokens += prompt.size() - 1;
// Sanity check: prompts should not be empty, nor start with EOS.
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
@ -475,7 +480,7 @@ static void GenerateT(const ModelConfig& config,
// We use a single divisor, so all sequence lengths must be the same.
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
}
HWY_ASSERT(prefill_tokens < seq_len);
HWY_ASSERT(max_prompt_size < seq_len);
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
@ -494,7 +499,7 @@ static void GenerateT(const ModelConfig& config,
activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch.
}
HWY_DASSERT(non_eos.Count() == qbatch.Size());
timing_info.NotifyPrefill(prefill_tokens);
timing_info.NotifyPrefill(total_prefill_tokens);
// queries_pos have been incremented by Prefill.
// Stream the last prompt token from each query, fill activations.gen_tokens.
@ -505,10 +510,10 @@ static void GenerateT(const ModelConfig& config,
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
if (prefill_tokens + max_gen_steps > seq_len) {
if (max_prompt_size + max_gen_steps > seq_len) {
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
prefill_tokens, max_gen_steps, seq_len);
max_gen_steps = seq_len - prefill_tokens;
max_prompt_size, max_gen_steps, seq_len);
max_gen_steps = seq_len - max_prompt_size;
}
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
@ -588,8 +593,16 @@ HWY_EXPORT(GenerateSingleT);
HWY_EXPORT(GenerateBatchT);
HWY_EXPORT(GenerateImageTokensT);
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
const InferenceArgs& inference_args) {
if (inference_args.decode_qbatch_size >= 256) {
ThreadingArgs copy = threading_args;
copy.max_packages = 1;
ThreadingContext::SetArgs(copy);
} else {
ThreadingContext::SetArgs(threading_args);
}
return MatMulEnv(ThreadingContext::Get());
}

View File

@ -226,7 +226,8 @@ struct TimingInfo {
};
// Returns the `MatMulEnv` after calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
const InferenceArgs& inference_args);
class Gemma {
public:

View File

@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
PROFILER_ZONE("Run.misc");
MatMulEnv env(MakeMatMulEnv(threading));
MatMulEnv env(MakeMatMulEnv(threading, inference));
if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, inference, env);
KVCache kv_cache(gemma.GetModelConfig(), inference);

View File

@ -61,20 +61,20 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{0},
"Maximum number of sockets to use; default 0 = unlimited.", 2);
"Max sockets to use; default 0 = all unless large batch size.", 2);
visitor(skip_clusters, "skip_clusters", size_t{0},
"Index of the first CCX to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
"Maximum number of CCXs to use; default 0 = unlimited.", 2);
"Max CCXs to use; default 0 = unlimited.", 2);
// These are only used when CPU topology is unknown.
visitor(skip_lps, "skip_lps", size_t{0},
"Index of the first LP to use; default 0 = unlimited.", 2);
visitor(max_lps, "max_lps", size_t{0},
"Maximum number of LPs to use; default 0 = unlimited.", 2);
"Max LPs to use; default 0 = unlimited.", 2);
// The exact meaning is more subtle: see the comment at NestedPools ctor.
visitor(max_threads, "num_threads", size_t{0},
"Maximum number of threads to use; default 0 = unlimited.", 2);
"Max threads to use; default 0 = unlimited.", 2);
visitor(pin, "pin", Tristate::kDefault,
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(spin, "spin", Tristate::kDefault,