mirror of https://github.com/google/gemma.cpp.git
Batch inference fixes: set pos during prefill, fix assert
PiperOrigin-RevId: 772458760
parent d342e4e7d4
commit f2adbfbcab
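
In detail: (a) PrefillQBatch now advances each query's stored position while
streaming prompt tokens and leaves it at the last prompt token; (b) GenerateT's
sequence-length assert and max_gen_steps truncation use the longest prompt
(max_prompt_size) rather than the summed prefill token count; (c) MakeMatMulEnv
gains an InferenceArgs parameter so that large decode batch sizes limit MatMul
to a single package. All call sites are updated for the new signature.
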
@@ -49,7 +49,8 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {

 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
+    : env_(MakeMatMulEnv(threading, inference)),
+      gemma_(loader, inference, env_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference));

@@ -51,7 +51,7 @@ int main(int argc, char** argv) {
   }

   // Instantiate model and KV Cache
-  gcpp::MatMulEnv env(MakeMatMulEnv(threading));
+  gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
   gcpp::Gemma gemma(loader, inference, env);
   gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
   size_t generated = 0;

@@ -35,7 +35,7 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : env_(MakeMatMulEnv(threading)),
+      : env_(MakeMatMulEnv(threading, inference)),
        gemma_(loader, inference, env_),
        kv_cache_(gemma_.GetModelConfig(), inference) {
     // Initialize random number generator

@@ -101,7 +101,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
                            int max_generated_tokens)
     : inference_args(inference_args),
       threading_args(threading_args),
-      matmul_env(MakeMatMulEnv(threading_args)),
+      matmul_env(MakeMatMulEnv(threading_args, inference_args)),
       active_conversation_name("default"),
       model(loader, inference_args, matmul_env) {
   std::stringstream ss;

@@ -28,7 +28,6 @@
 namespace gcpp {

 static constexpr size_t kVocabSize = 256000;
-static constexpr size_t kMaxSeqLen = 4096;

 static ModelConfig ConfigNoSSM() {
   ModelConfig config;

@@ -344,8 +344,9 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
       token = qbatch.Prompt(qi)[pos_in_prompt];
       // Ignore StreamToken return value because requesting to stop does not
       // make sense during prefill.
-      (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+      (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
                                        token, 0.0f);
+      qbatch.MutablePos(qi) = pos_in_prompt;
     }

     qbatch.PrevToken(qi) = token;

@@ -356,6 +357,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
     // probabilities, which are not required for the prompt tokens.
     Transformer(config, runtime_config, weights, activations, qbatch, env);
   }
+
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1;
+  }
 }

 // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent

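To make the new position invariant concrete, here is a minimal standalone
sketch. ToyQuery is an invented stand-in for the real QBatch, which exposes
this state via MutablePos(qi): during prefill the stored position tracks
pos_in_prompt, and after the prefill loop it is pinned to prompt.size() - 1,
the last prompt token, which becomes the first input token for generation.

#include <cstddef>
#include <vector>

// Invented stand-in for gcpp::QBatch, reduced to the state this commit fixes.
struct ToyQuery {
  std::vector<int> prompt;
  size_t pos = 0;  // position within the sequence, as seen by StreamToken
};

void PrefillPositions(std::vector<ToyQuery>& batch) {
  for (ToyQuery& q : batch) {
    // Prefill consumes tokens [0, prompt.size() - 1); pos follows along,
    // mirroring the new qbatch.MutablePos(qi) = pos_in_prompt above.
    for (size_t pos_in_prompt = 0; pos_in_prompt + 1 < q.prompt.size();
         ++pos_in_prompt) {
      q.pos = pos_in_prompt;
    }
    // Mirrors the new loop after Transformer(): generation starts at the
    // last prompt token. (Prompts are asserted non-empty elsewhere.)
    q.pos = q.prompt.size() - 1;
  }
}
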
@@ -457,7 +462,7 @@ static void GenerateT(const ModelConfig& config,

   size_t max_prompt_size = 0;
   bool all_prefix_end_are_zero = true;
-  size_t prefill_tokens = 0;  // only for timing.
+  size_t total_prefill_tokens = 0;  // only for throughput stats.
   const size_t seq_len = qbatch.KV(0).SeqLen();
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const PromptTokens& prompt = qbatch.Prompt(qi);

@@ -465,7 +470,7 @@ static void GenerateT(const ModelConfig& config,

     // Prefill stops before size - 1 because the last prompt token is the
     // first input token for generation.
-    prefill_tokens += prompt.size() - 1;
+    total_prefill_tokens += prompt.size() - 1;

     // Sanity check: prompts should not be empty, nor start with EOS.
     HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);

@@ -475,7 +480,7 @@ static void GenerateT(const ModelConfig& config,
     // We use a single divisor, so all sequence lengths must be the same.
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
-  HWY_ASSERT(prefill_tokens < seq_len);
+  HWY_ASSERT(max_prompt_size < seq_len);
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);

   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have

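This is the assert fix from the commit title: the old
HWY_ASSERT(prefill_tokens < seq_len) compared the batch-wide total of prefill
tokens against a single query's sequence length, so it could fire for large
batches even though each query has its own KV cache and every prompt fits
individually. Requiring only the longest prompt to fit matches what the
per-query caches actually hold.
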
@@ -494,7 +499,7 @@ static void GenerateT(const ModelConfig& config,
     activations.SetBatchSize(qbatch.Size());  // Restore after PrefillTBatch.
   }
   HWY_DASSERT(non_eos.Count() == qbatch.Size());
-  timing_info.NotifyPrefill(prefill_tokens);
+  timing_info.NotifyPrefill(total_prefill_tokens);
   // queries_pos have been incremented by Prefill.

   // Stream the last prompt token from each query, fill activations.gen_tokens.

@@ -505,10 +510,10 @@ static void GenerateT(const ModelConfig& config,
   }

   size_t max_gen_steps = runtime_config.max_generated_tokens;
-  if (prefill_tokens + max_gen_steps > seq_len) {
+  if (max_prompt_size + max_gen_steps > seq_len) {
     HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
-             prefill_tokens, max_gen_steps, seq_len);
-    max_gen_steps = seq_len - prefill_tokens;
+             max_prompt_size, max_gen_steps, seq_len);
+    max_gen_steps = seq_len - max_prompt_size;
   }

   const SampleFunc sample_token = ChooseSampleFunc(runtime_config);

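The truncation fix follows the same logic. A made-up worked example: with
seq_len = 4096 and two queries whose prompts are 2500 tokens each, the old code
computed prefill_tokens = 2 * 2499 = 4998; the assert above would already abort
for this batch, and even without it, seq_len - prefill_tokens would wrap around
as unsigned arithmetic, leaving max_gen_steps enormous rather than truncated.
With max_prompt_size = 2500, max_gen_steps is correctly clamped to
4096 - 2500 = 1596.
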
@@ -588,8 +593,16 @@ HWY_EXPORT(GenerateSingleT);
 HWY_EXPORT(GenerateBatchT);
 HWY_EXPORT(GenerateImageTokensT);

-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
-  ThreadingContext::SetArgs(threading_args);
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
+                        const InferenceArgs& inference_args) {
+  if (inference_args.decode_qbatch_size >= 256) {
+    ThreadingArgs copy = threading_args;
+    copy.max_packages = 1;
+    ThreadingContext::SetArgs(copy);
+  } else {
+    ThreadingContext::SetArgs(threading_args);
+  }
+
   return MatMulEnv(ThreadingContext::Get());
 }

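For reference, a hedged sketch of a call site under the new two-argument
signature; the include path is an assumption, and the 256 threshold for
decode_qbatch_size comes from the hunk above.

#include "gemma/gemma.h"  // assumed header declaring MakeMatMulEnv

void ExampleSetup() {
  gcpp::ThreadingArgs threading;       // defaults: 0 = use all sockets/CCXs
  gcpp::InferenceArgs inference;
  inference.decode_qbatch_size = 512;  // >= 256: MatMul kept on one package

  // As at the call sites above; env is then passed to the Gemma constructor.
  gcpp::MatMulEnv env(gcpp::MakeMatMulEnv(threading, inference));
  (void)env;
}
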
@@ -226,7 +226,8 @@ struct TimingInfo {
 };

 // Returns the `MatMulEnv` after calling `SetArgs`.
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
+                        const InferenceArgs& inference_args);

 class Gemma {
  public:

@@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
          const InferenceArgs& inference) {
   PROFILER_ZONE("Run.misc");

-  MatMulEnv env(MakeMatMulEnv(threading));
+  MatMulEnv env(MakeMatMulEnv(threading, inference));
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, env);
   KVCache kv_cache(gemma.GetModelConfig(), inference);

@@ -61,20 +61,20 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
     visitor(skip_packages, "skip_packages", size_t{0},
             "Index of the first socket to use; default 0 = unlimited.", 2);
     visitor(max_packages, "max_packages", size_t{0},
-            "Maximum number of sockets to use; default 0 = unlimited.", 2);
+            "Max sockets to use; default 0 = all unless large batch size.", 2);
     visitor(skip_clusters, "skip_clusters", size_t{0},
             "Index of the first CCX to use; default 0 = unlimited.", 2);
     visitor(max_clusters, "max_clusters", size_t{0},
-            "Maximum number of CCXs to use; default 0 = unlimited.", 2);
+            "Max CCXs to use; default 0 = unlimited.", 2);
     // These are only used when CPU topology is unknown.
     visitor(skip_lps, "skip_lps", size_t{0},
             "Index of the first LP to use; default 0 = unlimited.", 2);
     visitor(max_lps, "max_lps", size_t{0},
-            "Maximum number of LPs to use; default 0 = unlimited.", 2);
+            "Max LPs to use; default 0 = unlimited.", 2);

     // The exact meaning is more subtle: see the comment at NestedPools ctor.
     visitor(max_threads, "num_threads", size_t{0},
-            "Maximum number of threads to use; default 0 = unlimited.", 2);
+            "Max threads to use; default 0 = unlimited.", 2);
     visitor(pin, "pin", Tristate::kDefault,
             "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
     visitor(spin, "spin", Tristate::kDefault,