mirror of https://github.com/google/gemma.cpp.git
[WIP] quality tweaks - for constants, defer float cast and use double for intermediate computations, add `model` to EOT token
This commit is contained in:
parent
5b9d8a9936
commit
0f6a4b49d5
6
gemma.cc
6
gemma.cc
|
|
@@ -295,7 +295,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
static constexpr size_t kModelDim =
|
static constexpr size_t kModelDim =
|
||||||
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
|
static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||||
|
|
||||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
// linear projections to QKV
|
// linear projections to QKV
|
||||||
|
|
@@ -418,7 +418,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
hwy::ThreadPool& inner_pool) {
|
hwy::ThreadPool& inner_pool) {
|
||||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
|
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||||
|
|
||||||
pool.Run(
|
pool.Run(
|
||||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||||
|
|
@@ -473,7 +473,7 @@ void Transformer(int token, size_t pos,
|
||||||
static constexpr size_t kLayers = TConfig::kLayers;
|
static constexpr size_t kLayers = TConfig::kLayers;
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
|
||||||
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
|
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||||
|
|
||||||
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
||||||
activations.x.data(), kModelDim);
|
activations.x.data(), kModelDim);
|
||||||
|
|
|
||||||
2
run.cc
2
run.cc
|
|
@@ -186,7 +186,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
if (abs_pos > 0) {
|
if (abs_pos > 0) {
|
||||||
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
||||||
// continuation.
|
// continuation.
|
||||||
prompt_string = "<end_of_turn>\n" + prompt_string;
|
prompt_string = "<end_of_turn>model\n" + prompt_string;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue