mirror of https://github.com/google/gemma.cpp.git
[WIP] quality tweaks - for constants, defer float cast and use double for intermediate computations, add `model` to EOT token
This commit is contained in:
parent
5b9d8a9936
commit
0f6a4b49d5
6
gemma.cc
6
gemma.cc
|
|
@ -295,7 +295,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
static constexpr size_t kModelDim =
|
||||
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
|
||||
static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
|
|
@ -418,7 +418,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
hwy::ThreadPool& inner_pool) {
|
||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
|
||||
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
|
||||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
|
|
@ -473,7 +473,7 @@ void Transformer(int token, size_t pos,
|
|||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
|
||||
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
|
||||
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
|
||||
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
||||
activations.x.data(), kModelDim);
|
||||
|
|
|
|||
2
run.cc
2
run.cc
|
|
@ -186,7 +186,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
if (abs_pos > 0) {
|
||||
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
||||
// continuation.
|
||||
prompt_string = "<end_of_turn>\n" + prompt_string;
|
||||
prompt_string = "<end_of_turn>model\n" + prompt_string;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue