Batch inference fixes: set pos during prefill, fix assert

PiperOrigin-RevId: 772458760
Author: Jan Wassenberg, 2025-06-17 07:09:00 -07:00 (committed by Copybara-Service)
parent d342e4e7d4
commit f2adbfbcab
9 changed files with 35 additions and 21 deletions

@@ -49,7 +49,8 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
+    : env_(MakeMatMulEnv(threading, inference)),
+      gemma_(loader, inference, env_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference));

@@ -51,7 +51,7 @@ int main(int argc, char** argv) {
   }
   // Instantiate model and KV Cache
-  gcpp::MatMulEnv env(MakeMatMulEnv(threading));
+  gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
   gcpp::Gemma gemma(loader, inference, env);
   gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
   size_t generated = 0;

@@ -35,7 +35,7 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : env_(MakeMatMulEnv(threading)),
+      : env_(MakeMatMulEnv(threading, inference)),
         gemma_(loader, inference, env_),
         kv_cache_(gemma_.GetModelConfig(), inference) {
     // Initialize random number generator

@@ -101,7 +101,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
                            int max_generated_tokens)
     : inference_args(inference_args),
       threading_args(threading_args),
-      matmul_env(MakeMatMulEnv(threading_args)),
+      matmul_env(MakeMatMulEnv(threading_args, inference_args)),
       active_conversation_name("default"),
       model(loader, inference_args, matmul_env) {
   std::stringstream ss;

@@ -28,7 +28,6 @@
 namespace gcpp {
 static constexpr size_t kVocabSize = 256000;
-static constexpr size_t kMaxSeqLen = 4096;
 static ModelConfig ConfigNoSSM() {
   ModelConfig config;

@@ -344,8 +344,9 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
       token = qbatch.Prompt(qi)[pos_in_prompt];
       // Ignore StreamToken return value because requesting to stop does not
       // make sense during prefill.
-      (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+      (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
                                        token, 0.0f);
+      qbatch.MutablePos(qi) = pos_in_prompt;
     }
     qbatch.PrevToken(qi) = token;
@@ -356,6 +357,10 @@
     // probabilities, which are not required for the prompt tokens.
     Transformer(config, runtime_config, weights, activations, qbatch, env);
   }
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1;
+  }
 }
 // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
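
The added assignments establish the position invariant that generation relies on: after prefill, each query's position points at its last prompt token, which becomes the first input token for generation. A minimal sketch of that invariant, using only the accessors visible above; the check itself is illustrative and not part of the commit:

  // Holds after PrefillQBatch returns (sketch only, not in the commit).
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
    HWY_DASSERT(qbatch.Pos(qi) == qbatch.Prompt(qi).size() - 1);
  }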
@@ -457,7 +462,7 @@ static void GenerateT(const ModelConfig& config,
   size_t max_prompt_size = 0;
   bool all_prefix_end_are_zero = true;
-  size_t prefill_tokens = 0;  // only for timing.
+  size_t total_prefill_tokens = 0;  // only for throughput stats.
   const size_t seq_len = qbatch.KV(0).SeqLen();
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const PromptTokens& prompt = qbatch.Prompt(qi);
@@ -465,7 +470,7 @@ static void GenerateT(const ModelConfig& config,
     // Prefill stops before size - 1 because the last prompt token is the
     // first input token for generation.
-    prefill_tokens += prompt.size() - 1;
+    total_prefill_tokens += prompt.size() - 1;
     // Sanity check: prompts should not be empty, nor start with EOS.
     HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
@@ -475,7 +480,7 @@ static void GenerateT(const ModelConfig& config,
     // We use a single divisor, so all sequence lengths must be the same.
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
-  HWY_ASSERT(prefill_tokens < seq_len);
+  HWY_ASSERT(max_prompt_size < seq_len);
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
@@ -494,7 +499,7 @@ static void GenerateT(const ModelConfig& config,
     activations.SetBatchSize(qbatch.Size());  // Restore after PrefillTBatch.
   }
   HWY_DASSERT(non_eos.Count() == qbatch.Size());
-  timing_info.NotifyPrefill(prefill_tokens);
+  timing_info.NotifyPrefill(total_prefill_tokens);
   // queries_pos have been incremented by Prefill.
   // Stream the last prompt token from each query, fill activations.gen_tokens.
@@ -505,10 +510,10 @@ static void GenerateT(const ModelConfig& config,
   }
   size_t max_gen_steps = runtime_config.max_generated_tokens;
-  if (prefill_tokens + max_gen_steps > seq_len) {
+  if (max_prompt_size + max_gen_steps > seq_len) {
     HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
-             prefill_tokens, max_gen_steps, seq_len);
-    max_gen_steps = seq_len - prefill_tokens;
+             max_prompt_size, max_gen_steps, seq_len);
+    max_gen_steps = seq_len - max_prompt_size;
   }
   const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
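
The assert and truncation above now budget against the longest prompt rather than the summed prefill token count, since each query's KV cache only has to hold its own prompt plus the generated tokens. A small worked sketch of the clamping; variable names match the hunk, the numbers are made up for illustration:

  const size_t seq_len = 4096;          // per-query KV cache length
  const size_t max_prompt_size = 4000;  // longest prompt in the batch
  size_t max_gen_steps = 256;           // runtime_config.max_generated_tokens
  if (max_prompt_size + max_gen_steps > seq_len) {
    // 4000 + 256 > 4096, so generation is clamped to 4096 - 4000 = 96 steps.
    max_gen_steps = seq_len - max_prompt_size;
  }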
@@ -588,8 +593,16 @@ HWY_EXPORT(GenerateSingleT);
 HWY_EXPORT(GenerateBatchT);
 HWY_EXPORT(GenerateImageTokensT);
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
-  ThreadingContext::SetArgs(threading_args);
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
+                        const InferenceArgs& inference_args) {
+  if (inference_args.decode_qbatch_size >= 256) {
+    ThreadingArgs copy = threading_args;
+    copy.max_packages = 1;
+    ThreadingContext::SetArgs(copy);
+  } else {
+    ThreadingContext::SetArgs(threading_args);
+  }
   return MatMulEnv(ThreadingContext::Get());
 }

@@ -226,7 +226,8 @@ struct TimingInfo {
 };
 // Returns the `MatMulEnv` after calling `SetArgs`.
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
+                        const InferenceArgs& inference_args);
 class Gemma {
  public:
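
Every call site in this commit moves to the two-argument form. A minimal caller sketch, assuming default-constructed args as in SimplifiedGemma and a `loader` built elsewhere; construction order mirrors the updated run.cc and example files:

  gcpp::ThreadingArgs threading;  // defaults suffice for this sketch
  gcpp::InferenceArgs inference;
  gcpp::MatMulEnv env(gcpp::MakeMatMulEnv(threading, inference));
  gcpp::Gemma gemma(loader, inference, env);
  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);

The only behavioral difference from the old one-argument overload is that a decode_qbatch_size of 256 or more forces max_packages to 1 before ThreadingContext::SetArgs is called.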

@@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
          const InferenceArgs& inference) {
   PROFILER_ZONE("Run.misc");
-  MatMulEnv env(MakeMatMulEnv(threading));
+  MatMulEnv env(MakeMatMulEnv(threading, inference));
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, env);
   KVCache kv_cache(gemma.GetModelConfig(), inference);

@@ -61,20 +61,20 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
     visitor(skip_packages, "skip_packages", size_t{0},
             "Index of the first socket to use; default 0 = unlimited.", 2);
     visitor(max_packages, "max_packages", size_t{0},
-            "Maximum number of sockets to use; default 0 = unlimited.", 2);
+            "Max sockets to use; default 0 = all unless large batch size.", 2);
     visitor(skip_clusters, "skip_clusters", size_t{0},
             "Index of the first CCX to use; default 0 = unlimited.", 2);
     visitor(max_clusters, "max_clusters", size_t{0},
-            "Maximum number of CCXs to use; default 0 = unlimited.", 2);
+            "Max CCXs to use; default 0 = unlimited.", 2);
     // These are only used when CPU topology is unknown.
     visitor(skip_lps, "skip_lps", size_t{0},
             "Index of the first LP to use; default 0 = unlimited.", 2);
     visitor(max_lps, "max_lps", size_t{0},
-            "Maximum number of LPs to use; default 0 = unlimited.", 2);
+            "Max LPs to use; default 0 = unlimited.", 2);
     // The exact meaning is more subtle: see the comment at NestedPools ctor.
     visitor(max_threads, "num_threads", size_t{0},
-            "Maximum number of threads to use; default 0 = unlimited.", 2);
+            "Max threads to use; default 0 = unlimited.", 2);
     visitor(pin, "pin", Tristate::kDefault,
             "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
     visitor(spin, "spin", Tristate::kDefault,