mirror of https://github.com/google/gemma.cpp.git
Bounds-checks for large prompts. Refs #99
Also remove init placeholder and move Sqrt to ops.h. PiperOrigin-RevId: 619529202
This commit is contained in:
parent
bbf4df4584
commit
bb767d788d
107
gemma.cc
107
gemma.cc
|
|
@ -52,8 +52,6 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
|
|
@ -466,18 +464,6 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
activations.ffw_out.data() + batch_idx * kModelDim, pool);
|
||||
}
|
||||
|
||||
// __builtin_sqrt is not constexpr as of Clang 17.
|
||||
#if HWY_COMPILER_GCC_ACTUAL && defined(HWY_HAVE_SCALAR_BF16_OPERATORS) && \
|
||||
HWY_HAVE_SCALAR_BF16_OPERATORS
|
||||
#define GEMMA_CONSTEXPR_SQRT constexpr
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
|
||||
return __builtin_sqrt(x);
|
||||
}
|
||||
#else
|
||||
#define GEMMA_CONSTEXPR_SQRT
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||
#endif
|
||||
|
||||
template <typename TConfig>
|
||||
GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
|
||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||
|
|
@ -569,6 +555,31 @@ void Transformer(int token, size_t pos,
|
|||
kModelDim);
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
||||
size_t& prompt_size) {
|
||||
if (max_tokens > TConfig::kSeqLen) {
|
||||
fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
|
||||
max_tokens, TConfig::kSeqLen);
|
||||
max_tokens = static_cast<size_t>(TConfig::kSeqLen);
|
||||
}
|
||||
|
||||
if (max_generated_tokens > max_tokens) {
|
||||
fprintf(stderr,
|
||||
"WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n",
|
||||
max_generated_tokens, max_tokens);
|
||||
max_generated_tokens = max_tokens - 1;
|
||||
}
|
||||
|
||||
if (prompt_size + max_generated_tokens > max_tokens) {
|
||||
fprintf(stderr,
|
||||
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
|
||||
"%d, truncating.\n",
|
||||
prompt_size, max_generated_tokens, TConfig::kSeqLen);
|
||||
prompt_size = max_tokens - max_generated_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
|
|
@ -577,16 +588,21 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
static constexpr size_t kTopK = TConfig::kTopK;
|
||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
|
||||
*gemma.prefill.get();
|
||||
const CompressedWeights<TConfig>& c_weights =
|
||||
*reinterpret_cast<CompressedWeights<TConfig>*>(
|
||||
gemma.compressed_weights.get());
|
||||
int token;
|
||||
|
||||
size_t prompt_size = prompt.size();
|
||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
|
||||
if (pos >= max_tokens) {
|
||||
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
|
||||
max_tokens);
|
||||
return;
|
||||
}
|
||||
|
||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||
//
|
||||
|
|
@ -602,21 +618,22 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
size_t pos_offset = 0; // offset relative to pos
|
||||
const double prefill_start = hwy::platform::Now();
|
||||
|
||||
// Prefill stops before prompt.size() - 1 since the last prompt token is the
|
||||
// Prefill stops before prompt_size - 1 since the last prompt token is the
|
||||
// first input token for generation.
|
||||
while (pos_offset < prompt.size() - 1) {
|
||||
const size_t end_offset =
|
||||
std::min(kPrefillBatchSize, prompt.size() - 1 - pos_offset);
|
||||
HWY_DASSERT(end_offset < prompt.size());
|
||||
while (pos_offset < prompt_size - 1) {
|
||||
const size_t batch_size =
|
||||
std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
|
||||
HWY_DASSERT(batch_size <= kPrefillBatchSize);
|
||||
HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
|
||||
const int* batch_tokens = prompt.data() + pos_offset;
|
||||
Prefill<TConfig, kPrefillBatchSize>(batch_tokens, end_offset, pos,
|
||||
Prefill<TConfig, kPrefillBatchSize>(batch_tokens, batch_size, pos,
|
||||
c_weights, prefill_activations,
|
||||
kv_cache, pool, inner_pool);
|
||||
for (size_t idx = 0; idx < end_offset; ++idx) {
|
||||
stream_token(batch_tokens[idx], 0.0);
|
||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||
stream_token(batch_tokens[idx], 0.0f);
|
||||
}
|
||||
pos += end_offset;
|
||||
pos_offset += end_offset;
|
||||
pos += batch_size;
|
||||
pos_offset += batch_size;
|
||||
}
|
||||
|
||||
if (verbosity >= 2) {
|
||||
|
|
@ -630,39 +647,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
|
||||
const double gen_start = hwy::platform::Now();
|
||||
|
||||
HWY_DASSERT(pos_offset == prompt.size() - 1);
|
||||
HWY_DASSERT(pos_offset == prompt_size - 1);
|
||||
|
||||
if (verbosity >= 2) {
|
||||
// Provide usage warnings if max_new_tokens is out of range.
|
||||
if (max_generated_tokens > max_tokens) {
|
||||
std::cout << "Warning: max_new_tokens should be <= max_tokens"
|
||||
<< std::endl;
|
||||
} else if ((prompt.size() + max_generated_tokens) > max_tokens) {
|
||||
std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
|
||||
<< std::endl;
|
||||
} else if (pos >= max_tokens) {
|
||||
std::cout << "Warning: pos exceeds max_tokens."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
auto pos_gen_start = pos_offset;
|
||||
token = prompt.at(pos_offset);
|
||||
size_t generate_pos = 0;
|
||||
for (; pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
size_t pos_gen_start = pos_offset;
|
||||
int token = prompt.at(pos_offset);
|
||||
for (size_t generate_pos = 0;
|
||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
++pos, ++pos_offset, ++generate_pos) {
|
||||
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
||||
float* final_activation = activations.x.data();
|
||||
if (pos_offset >= prompt.size()) {
|
||||
if (pos_offset >= prompt_size) {
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Generation phase
|
||||
MatVec<kVocabSize, kModelDim>(c_weights.c_embedder_input_embedding, 0,
|
||||
final_activation, activations.logits.data(),
|
||||
pool);
|
||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||
c_weights.c_embedder_input_embedding, 0, final_activation,
|
||||
activations.logits.data(), pool);
|
||||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
|
||||
temperature, accept_token);
|
||||
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
|
||||
gen, temperature, accept_token);
|
||||
}
|
||||
if (!stream_token(token, activations.logits[token])) {
|
||||
token = EOS_ID;
|
||||
|
|
|
|||
16
ops.h
16
ops.h
|
|
@ -22,6 +22,7 @@
|
|||
#include <array>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <type_traits> // std::enable_if_t
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h"
|
||||
|
|
@ -29,6 +30,21 @@
|
|||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// __builtin_sqrt is not constexpr as of Clang 17.
|
||||
#if HWY_COMPILER_GCC_ACTUAL
|
||||
#define GEMMA_CONSTEXPR_SQRT constexpr
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
|
||||
return __builtin_sqrt(x);
|
||||
}
|
||||
#else
|
||||
#define GEMMA_CONSTEXPR_SQRT
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||
#endif
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
|
|
|
|||
Loading…
Reference in New Issue