From bb767d788d458199dfdb50ca6890ad03a4e3adce Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Wed, 27 Mar 2024 07:48:30 -0700
Subject: [PATCH] Bounds-checks for large prompts. Refs #99

Also remove init placeholder and move Sqrt to ops.h.

PiperOrigin-RevId: 619529202
---
 gemma.cc | 107 ++++++++++++++++++++++++++++---------------------------
 ops.h    |  16 +++++++++
 2 files changed, 71 insertions(+), 52 deletions(-)

diff --git a/gemma.cc b/gemma.cc
index 27e8921..23fc31f 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -52,8 +52,6 @@
 #include <string>
 #include <vector>
 
-// Placeholder for internal header, do not modify.
-
 // copybara:import_next_line:gemma_cpp
 #include "compression/compress.h"
 // copybara:import_next_line:gemma_cpp
@@ -466,18 +464,6 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                           activations.ffw_out.data() + batch_idx * kModelDim,
                           pool);
 }
-// __builtin_sqrt is not constexpr as of Clang 17.
-#if HWY_COMPILER_GCC_ACTUAL && defined(HWY_HAVE_SCALAR_BF16_OPERATORS) && \
-    HWY_HAVE_SCALAR_BF16_OPERATORS
-#define GEMMA_CONSTEXPR_SQRT constexpr
-static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
-  return __builtin_sqrt(x);
-}
-#else
-#define GEMMA_CONSTEXPR_SQRT
-static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
-#endif
-
 template <typename TConfig>
 GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
   // Round to bf16 to match Gemma's Embedder, which casts before mul.
@@ -569,6 +555,31 @@ void Transformer(int token, size_t pos,
               kModelDim);
 }
 
+template <class TConfig>
+void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
+                 size_t& prompt_size) {
+  if (max_tokens > TConfig::kSeqLen) {
+    fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
+            max_tokens, TConfig::kSeqLen);
+    max_tokens = static_cast<size_t>(TConfig::kSeqLen);
+  }
+
+  if (max_generated_tokens > max_tokens) {
+    fprintf(stderr,
+            "WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n",
+            max_generated_tokens, max_tokens);
+    max_generated_tokens = max_tokens - 1;
+  }
+
+  if (prompt_size + max_generated_tokens > max_tokens) {
+    fprintf(stderr,
+            "WARNING: prompt_size %zu + max_generated_tokens %zu > max_tokens "
+            "%zu, truncating.\n",
+            prompt_size, max_generated_tokens, max_tokens);
+    prompt_size = max_tokens - max_generated_tokens;
+  }
+}
+
 template <class TConfig>
 void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
@@ -577,16 +588,21 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token,
                   std::mt19937& gen, int verbosity) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kTopK = TConfig::kTopK;
   Activations<TConfig, 1>& activations = *gemma.state.get();
   Activations<TConfig, kPrefillBatchSize>& prefill_activations =
       *gemma.prefill.get();
   const CompressedWeights<TConfig>& c_weights =
       *reinterpret_cast<CompressedWeights<TConfig>*>(
          gemma.compressed_weights.get());
-  int token;
+
+  size_t prompt_size = prompt.size();
+  RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
+  if (pos >= max_tokens) {
+    fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
+            max_tokens);
+    return;
+  }
 
   // pos indexes the KV cache. In the first turn of a chat, pos = 0.
   //
@@ -602,21 +618,22 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   size_t pos_offset = 0;  // offset relative to pos
   const double prefill_start = hwy::platform::Now();
 
-  // Prefill stops before prompt.size() - 1 since the last prompt token is the
+  // Prefill stops before prompt_size - 1 since the last prompt token is the
   // first input token for generation.
-  while (pos_offset < prompt.size() - 1) {
-    const size_t end_offset =
-        std::min(kPrefillBatchSize, prompt.size() - 1 - pos_offset);
-    HWY_DASSERT(end_offset < prompt.size());
+  while (pos_offset < prompt_size - 1) {
+    const size_t batch_size =
+        std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
+    HWY_DASSERT(batch_size <= kPrefillBatchSize);
+    HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
     const int* batch_tokens = prompt.data() + pos_offset;
-    Prefill(batch_tokens, end_offset, pos,
-            c_weights, prefill_activations, kv_cache, pool, inner_pool);
-    for (size_t idx = 0; idx < end_offset; ++idx) {
-      stream_token(batch_tokens[idx], 0.0);
+    Prefill(batch_tokens, batch_size, pos, c_weights, prefill_activations,
+            kv_cache, pool, inner_pool);
+    for (size_t idx = 0; idx < batch_size; ++idx) {
+      stream_token(batch_tokens[idx], 0.0f);
     }
-    pos += end_offset;
-    pos_offset += end_offset;
+    pos += batch_size;
+    pos_offset += batch_size;
   }
 
   if (verbosity >= 2) {
@@ -630,39 +647,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
 
   const double gen_start = hwy::platform::Now();
 
-  HWY_DASSERT(pos_offset == prompt.size() - 1);
+  HWY_DASSERT(pos_offset == prompt_size - 1);
 
-  if (verbosity >= 2) {
-    // Provide usage warnings if max_new_tokens is out of range.
-    if (max_generated_tokens > max_tokens) {
-      std::cout << "Warning: max_new_tokens should be <= max_tokens"
-                << std::endl;
-    } else if ((prompt.size() + max_generated_tokens) > max_tokens) {
-      std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
-                << std::endl;
-    } else if (pos >= max_tokens) {
-      std::cout << "Warning: pos exceeds max_tokens."
-                << std::endl;
-    }
-  }
-
-  auto pos_gen_start = pos_offset;
-  token = prompt.at(pos_offset);
-  size_t generate_pos = 0;
-  for (; pos < max_tokens && generate_pos < max_generated_tokens;
+  size_t pos_gen_start = pos_offset;
+  int token = prompt.at(pos_offset);
+  for (size_t generate_pos = 0;
+       pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     Transformer(token, pos, c_weights, activations, kv_cache, pool,
                 inner_pool);
     float* final_activation = activations.x.data();
-    if (pos_offset >= prompt.size()) {
+    if (pos_offset >= prompt_size) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
-      MatVec<kVocabSize, kModelDim>(c_weights.c_embedder_input_embedding, 0,
-                                    final_activation,
-                                    activations.logits.data(), pool);
+      MatVec<kVocabSize, TConfig::kModelDim>(
+          c_weights.c_embedder_input_embedding, 0, final_activation,
+          activations.logits.data(), pool);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
-                                temperature, accept_token);
+      token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
+                                         gen, temperature, accept_token);
     }
     if (!stream_token(token, activations.logits[token])) {
       token = EOS_ID;
diff --git a/ops.h b/ops.h
index 1a65b49..0959c02 100644
--- a/ops.h
+++ b/ops.h
@@ -22,6 +22,7 @@
 #include <array>
 #include <cmath>
 #include <random>
+#include <type_traits>  // std::enable_if_t
 
 // copybara:import_next_line:gemma_cpp
 #include "compression/compress.h"
@@ -29,6 +30,21 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
 
+namespace gcpp {
+
+// __builtin_sqrt is not constexpr as of Clang 17.
+#if HWY_COMPILER_GCC_ACTUAL
+#define GEMMA_CONSTEXPR_SQRT constexpr
+static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
+  return __builtin_sqrt(x);
+}
+#else
+#define GEMMA_CONSTEXPR_SQRT
+static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
+#endif
+
+}  // namespace gcpp
+
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_H_
 // Include guard for (potentially) SIMD code.
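
The new RangeChecks clamps in a deliberate order: max_tokens is first capped at
kSeqLen (the KV-cache capacity), max_generated_tokens is then capped below the
already-clamped max_tokens, and finally the prompt is truncated so that prompt
plus generation still fits. The standalone sketch below traces that order with
hypothetical values; the kSeqLen constant, the RangeChecksSketch name, and the
inputs in main() are illustrative and not part of the patch:

    #include <cstddef>
    #include <cstdio>

    static constexpr size_t kSeqLen = 4096;  // stand-in for TConfig::kSeqLen

    void RangeChecksSketch(size_t& max_tokens, size_t& max_generated_tokens,
                           size_t& prompt_size) {
      // Cap max_tokens at the KV-cache capacity first; the later checks
      // compare against the clamped value, not the caller's request.
      if (max_tokens > kSeqLen) max_tokens = kSeqLen;
      // Leave at least one slot for the final prompt token.
      if (max_generated_tokens > max_tokens) {
        max_generated_tokens = max_tokens - 1;
      }
      // Whatever remains bounds the prompt itself.
      if (prompt_size + max_generated_tokens > max_tokens) {
        prompt_size = max_tokens - max_generated_tokens;
      }
    }

    int main() {
      size_t max_tokens = 8192, max_generated = 8192, prompt = 6000;
      RangeChecksSketch(max_tokens, max_generated, prompt);
      printf("%zu %zu %zu\n", max_tokens, max_generated, prompt);  // 4096 4095 1
      return 0;
    }

Note that the third clamp can leave prompt_size smaller than the actual prompt;
GenerateImpl then simply never touches the dropped tail, since both the prefill
loop and the generation loop are bounded by prompt_size rather than
prompt.size().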
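The rewritten prefill loop also renames end_offset, which was a batch length
rather than an offset, to batch_size, and replaces the old
HWY_DASSERT(end_offset < prompt.size()) with two assertions that actually
bound the batch. A minimal sketch of the batching arithmetic, with a
hypothetical batch size and prompt length (gemma.cc defines its own
kPrefillBatchSize):

    // All but the last prompt token are consumed in batches of at most
    // kPrefillBatchSize; the last token seeds the generation loop instead.
    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
      constexpr size_t kPrefillBatchSize = 32;  // hypothetical
      const size_t prompt_size = 70;            // hypothetical
      size_t pos_offset = 0;
      while (pos_offset < prompt_size - 1) {
        const size_t batch_size =
            std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
        printf("prefill [%zu, %zu)\n", pos_offset, pos_offset + batch_size);
        pos_offset += batch_size;
      }
      // Prints [0, 32), [32, 64), [64, 69); pos_offset ends at
      // prompt_size - 1 = 69, matching the HWY_DASSERT in the patch.
      printf("pos_offset = %zu\n", pos_offset);
      return 0;
    }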
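Moving Sqrt into ops.h makes GEMMA_CONSTEXPR_SQRT available to every includer:
under GCC (HWY_COMPILER_GCC_ACTUAL), __builtin_sqrt is usable in constant
expressions, so a caller such as EmbeddingScaling can fold to a compile-time
constant; under Clang and MSVC the macro expands to nothing and the same code
compiles as ordinary inline functions. A self-contained sketch of that pattern,
substituting plain compiler macros for Highway's HWY_COMPILER_GCC_ACTUAL and
HWY_INLINE, with a hypothetical kModelDim:

    #include <math.h>

    // Expands to "constexpr" only where __builtin_sqrt is a constant
    // expression (GCC proper, not Clang); empty elsewhere.
    #if defined(__GNUC__) && !defined(__clang__)
    #define CONSTEXPR_SQRT constexpr
    static CONSTEXPR_SQRT inline float Sqrt(float x) {
      return __builtin_sqrt(x);
    }
    #else
    #define CONSTEXPR_SQRT
    static CONSTEXPR_SQRT inline float Sqrt(float x) { return sqrtf(x); }
    #endif

    // A caller marked with the same macro degrades gracefully: constexpr
    // (usable at compile time) on GCC, an ordinary inline function elsewhere.
    CONSTEXPR_SQRT float EmbeddingScalingSketch() {
      constexpr float kModelDim = 2048.0f;  // hypothetical model dimension
      return Sqrt(kModelDim);
    }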