Implement the Griffin model.

Also implement support for some model variations:

- Local attention.
- Add support for biases.
- Use RoPE only on half vectors.
- Support different order of QKV weights.

Co-authored-by: Andrey Mikhaylov <amik@google.com>
Co-authored-by: Martin Bruse <zondolfin@gmail.com>
Co-authored-by: Zoltan Szabadka <szabadka@google.com>
This commit is contained in:
Luca Versari 2024-04-04 15:01:53 +02:00
parent 4326249d0a
commit 9c3f969405
8 changed files with 639 additions and 136 deletions

View File

@ -241,6 +241,24 @@ Example invocation for the following configuration:
--model 2b-it
```
### RecurrentGemma
This repository includes a version of Gemma based on Griffin
([paper](https://arxiv.org/abs/2402.19427),
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
includes both recurrent layers and local attention, thus it is more efficient
for longer sequences and has a smaller memory footprint than standard Gemma. We
here provide a C++ implementation of this model based on the paper.
To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
Step 1, and run the binary as follows:
`./gemma --tokenizer tokenizer.spm --model gr2b-it --compressed_weights 2b-it-sfp.sbs`
### Troubleshooting and FAQs
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
@ -478,4 +496,9 @@ gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.
This is not an officially supported Google product.

View File

@ -10,14 +10,14 @@
#include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
using json = nlohmann::json;
@ -259,6 +259,13 @@ int main(int argc, char** argv) {
gcpp::AppArgs app(argc, argv);
BenchmarkArgs benchmark_args(argc, argv);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
@ -275,7 +282,7 @@ int main(int argc, char** argv) {
if (!benchmark_args.goldens.path.empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" + loader.model_type + ".txt";
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
golden_path);
} else if (!benchmark_args.summarize_text.path.empty()) {

View File

@ -44,35 +44,14 @@ struct Args : public ArgsBase<Args> {
ChooseNumThreads();
}
static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}
gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc.substr(0, 2) == "2b") {
return gcpp::Model::GEMMA_2B;
} else if (model_type_lc.substr(0, 2) == "7b") {
return gcpp::Model::GEMMA_7B;
} else {
HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
}
}
gcpp::Model ModelType() const { return model_type; }
// Returns error string or nullptr if OK.
const char* Validate() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
}
const char* Validate() {
ModelTraining model_training;
const char* parse_result =
ParseModelTypeAndTraining(model_type_str, model_type, model_training);
if (parse_result) return parse_result;
if (weights.path.empty()) {
return "Missing --weights flag, a file for the uncompressed model.";
}
@ -88,7 +67,8 @@ struct Args : public ArgsBase<Args> {
Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type;
std::string model_type_str;
Model model_type;
size_t num_threads;
template <class Visitor>
@ -96,10 +76,12 @@ struct Args : public ArgsBase<Args> {
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n"
" Required argument.");
visitor(model_type, "model", std::string(),
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights file will be written.\n"
@ -115,7 +97,7 @@ struct Args : public ArgsBase<Args> {
void ShowHelp(gcpp::Args& args) {
std::cerr
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
" --model <model type> --compressed_weights <output path>\n";
" --model <model type> --compressed_weights <output path>\n";
std::cerr << "\n*Arguments*\n\n";
args.Help();
std::cerr << "\n";

View File

@ -30,6 +30,8 @@
#include <stddef.h>
#include <array>
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h" // hwy::bfloat16_t
@ -45,16 +47,41 @@ namespace gcpp {
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;
enum class LayerAttentionType {
kGemma,
kGriffinRecurrentBlock,
};
template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
LayerAttentionType type) {
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
}
struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 28;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};
@ -62,17 +89,79 @@ struct ConfigGemma7B {
struct ConfigGemma2B {
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 18;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};
struct ConfigGriffin2B {
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
using WeightT = GEMMA_WEIGHT_T;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_

519
gemma.cc
View File

@ -25,12 +25,12 @@
#include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp
#include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// copybara:end
@ -43,11 +43,12 @@
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <filesystem> // NOLINT
#include <iostream>
#include <memory>
@ -74,8 +75,35 @@ constexpr bool kDryRunFread = false;
// Setting this to false will load and use uncompressed weights.
constexpr bool kWeightsAreCompressed = true;
// Set this to true to debug tokenizer tokens.
constexpr bool kShowTokenization = false;
namespace gcpp {
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}
template <typename TConfig>
constexpr size_t NumGemmaLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGemma, TConfig::kLayers);
}
template <typename TConfig>
constexpr size_t NumGriffinLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
TConfig::kLayers);
}
template <class TConfig>
struct Layer {
Layer() = default;
@ -96,6 +124,25 @@ struct Layer {
std::array<float, kModelDim * kFFHiddenDim> linear_w;
std::array<float, kModelDim> pre_attention_norm_scale;
std::array<float, kModelDim> pre_ffw_norm_scale;
// These fields are only used by Griffin, and do not affect loading of the
// model as it is done per-member.
// TODO(veluca): pull weights that are never used at the same time into a
// union or otherwise reduce the memory usage.
std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
std::array<float, kModelDim> ffw_output_biases;
std::array<float, kModelDim> attention_output_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
std::array<float, kModelDim> griffin_linear_y_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
std::array<float, kModelDim> griffin_linear_x_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
std::array<float, kModelDim> griffin_linear_out_biases;
std::array<float, kModelDim> griffin_conv_biases;
std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w;
std::array<float, kModelDim * 2> griffin_gate_biases;
std::array<float, kModelDim> griffin_a;
std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
};
float ScaleWeights(float* data, size_t len) {
@ -196,6 +243,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
Layer<TConfig>* layer_view = weights->GetLayer(layer);
#define READ_WEIGHTS(name) \
@ -212,16 +260,42 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
} while (0)
// Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin_linear_x_w);
READ_WEIGHTS(griffin_linear_x_biases);
READ_WEIGHTS(griffin_linear_y_w);
READ_WEIGHTS(griffin_linear_y_biases);
READ_WEIGHTS(griffin_linear_out_w);
READ_WEIGHTS(griffin_linear_out_biases);
READ_WEIGHTS(griffin_conv_w);
READ_WEIGHTS(griffin_conv_biases);
READ_WEIGHTS(griffin_gate_w);
READ_WEIGHTS(griffin_gate_biases);
READ_WEIGHTS(griffin_a);
SCALE_WEIGHTS(griffin_linear_x_w);
SCALE_WEIGHTS(griffin_linear_y_w);
SCALE_WEIGHTS(griffin_linear_out_w);
SCALE_WEIGHTS(griffin_gate_w);
}
READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
#undef READ_WEIGHTS
}
if (!ok) {
@ -253,14 +327,30 @@ struct CompressedLayer {
// We don't yet have an RMSNorm that accepts all WeightT.
CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
CompressedArray<float, kModelDim> ffw_output_biases;
CompressedArray<float, kModelDim> attention_output_biases;
CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
griffin_gate_w;
CompressedArray<float, kModelDim> griffin_a;
CompressedArray<float, kModelDim> griffin_linear_y_biases;
CompressedArray<float, kModelDim> griffin_linear_x_biases;
CompressedArray<float, kModelDim> griffin_linear_out_biases;
CompressedArray<float, kModelDim> griffin_conv_biases;
CompressedArray<float, kModelDim * 2> griffin_gate_biases;
CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
};
// Array instead of single large allocation for parallel mem init. Split out of
// CompressedWeights so that only these pointers are initialized, not the
// Array instead of single large allocation for parallel mem init. Split out
// of CompressedWeights so that only these pointers are initialized, not the
// CompressedArray.
template <class TConfig>
struct CompressedLayerPointers {
@ -307,7 +397,8 @@ struct Activations {
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kCachePosSize = TConfig::kLayers * kKVHeads * kQKVDim;
static constexpr size_t kCachePosSize =
NumGemmaLayers<TConfig>() * kKVHeads * kQKVDim;
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
std::array<float, kBatchSize * kModelDim> x; // input
@ -327,6 +418,12 @@ struct Activations {
// bf_ffw_hidden;
std::array<float, kBatchSize * kModelDim> ffw_out;
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
// Griffin layer internal activations
std::array<float, kBatchSize * kModelDim> griffin_x;
std::array<float, kBatchSize * kModelDim> griffin_y;
std::array<float, kBatchSize * kModelDim> griffin_gate_x;
std::array<float, kBatchSize * kModelDim> griffin_multiplier;
};
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
@ -353,8 +450,13 @@ struct GemmaInterface {
template <class Config>
KVCache CreateKVCache() {
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen);
constexpr size_t kConv1dWidth = Config::kConv1dWidth;
return CreateKVCache(
NumGemmaLayers<Config>() * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen,
NumGriffinLayers<Config>() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
Config::kModelDim,
NumGriffinLayers<Config>() * Config::kModelDim);
}
KVCache CreateKVCache(Model type) {
@ -363,6 +465,8 @@ KVCache CreateKVCache(Model type) {
return CreateKVCache<ConfigGemma2B>();
case Model::GEMMA_7B:
return CreateKVCache<ConfigGemma7B>();
case Model::GRIFFIN_2B:
return CreateKVCache<ConfigGriffin2B>();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}
@ -379,7 +483,15 @@ class GemmaTokenizerImpl : public GemmaTokenizer {
}
bool Encode(const std::string& input,
std::vector<int>* pieces) const override {
return impl_->Encode(input, pieces).ok();
if constexpr (kShowTokenization) {
bool is_ok = impl_->Encode(input, pieces).ok();
for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
}
return is_ok;
} else {
return impl_->Encode(input, pieces).ok();
}
}
// Given a sequence of ids, decodes it into a detokenized output.
bool Decode(const std::vector<int>& ids,
@ -442,6 +554,119 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
HWY_DASSERT(batch_idx < kBatchSize);
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
// X / Y linear layers.
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
TwoMatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset,
/*add0=*/layer_weights->griffin_linear_x_biases.data(),
/*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
Gelu(y, kModelDim);
// Conv1D.
{
HWY_FULL(float) df;
HWY_DASSERT(kModelDim % Lanes(df) == 0);
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)];
cache[0] = x;
for (size_t i = 1; i < kConv1dWidth; i++) {
cache[i] =
kv_cache.conv1d_cache.get() + layer_offset +
((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim;
}
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
auto accum1 = hn::Zero(df);
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i);
}
}
// RGLRU
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * kModelDim;
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
constexpr size_t kHeadDim = kModelDim / kHeads;
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
layer_weights->griffin_gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin_a.data() + head_offset, fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
// Final linear layer.
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_out_w, 0, x,
layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
}
template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, kBatchSize>& activations,
@ -462,6 +687,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
size_t cache_pos = pos;
size_t cache_num = pos + 1;
if constexpr (TConfig::kUseLocalAttention) {
cache_pos %= TConfig::kSeqLen;
cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
}
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
@ -480,7 +712,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
v_offset, x, k, v);
Rope(k, kQKVDim, pos);
Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
};
auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
@ -491,24 +723,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
head * TConfig::kSeqLen +
batch_idx * kHeads * kQKVDim;
Rope(q, kQKVDim, pos);
Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim);
// Compute Q dot K scores
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
}
Softmax(head_att, pos + 1);
Softmax(head_att, cache_num);
// Weighted summation
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
batch_idx * kHeads * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
@ -520,22 +752,34 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
head == 0
? activations.att_post2.data() + batch_idx * kModelDim
: activations.att_post1.data() + head * kBatchSize * kModelDim;
MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
head * kModelDim * kQKVDim, att_out,
head_out);
if (head == 0) {
MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
layer_weights->attention_output_biases.data(), head_out);
} else {
MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
head * kModelDim * kQKVDim, att_out,
head_out);
}
};
if constexpr (kHeads == kKVHeads) {
// Multi-Head Attention
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t head_offset = head * 3 * kQKVDim * kModelDim;
// linear projections to QKV
const size_t head_offset = TConfig::kInterleaveQKV
? 3 * kQKVDim * kModelDim
: kQKVDim * kModelDim;
const size_t mat_offset =
TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
const size_t q_offset = head * head_offset + 0 * mat_offset;
const size_t k_offset = head * head_offset + 1 * mat_offset;
const size_t v_offset = head * head_offset + 2 * mat_offset;
ProjQ(head, head_offset);
ProjQ(head, q_offset);
const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
const size_t kv_offset =
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
ProjKV(k_offset, v_offset, kv_offset);
@ -546,7 +790,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;
ProjKV(k_offset, v_offset, kv_offset);
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
@ -581,13 +827,13 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
// Same matrix, first and second half of rows. Could fuse into one MatVec,
// but separating them could help on NUMA e.g. multiple sockets.
MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w,
kFFHiddenDim * kModelDim, vec, out_mul,
pool);
MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool);
// Gate, will go through the nonlinearity.
MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w, 0, vec, out,
pool);
MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), out, pool);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
@ -598,8 +844,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
}
PROFILER_ZONE("Gen.FFW\\GatedGELU");
MatVec<kModelDim, kFFHiddenDim>(
MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(),
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}
@ -639,15 +886,23 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
});
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer);
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations.x.data() + token_idx * kModelDim,
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data() + token_idx * kModelDim,
kModelDim);
Attention<kBatchSize>(pos, token_idx, layer, activations, layer_weights,
kv_cache, pool);
if (type == LayerAttentionType::kGemma) {
Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
} else {
GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
}
}
// TODO: sink the loop into these functions, i.e. make them matmuls.
@ -678,9 +933,7 @@ template <typename WeightArrayT, class TConfig>
void Transformer(int token, size_t pos, const WeightArrayT& weights,
Activations<TConfig, 1>& activations, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
static constexpr size_t kLayers = TConfig::kLayers;
static constexpr size_t kModelDim = TConfig::kModelDim;
Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim);
@ -688,12 +941,21 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer);
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
RMSNorm(activations.x.data(),
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), kModelDim);
Attention<1>(pos, 0, layer, activations, layer_weights, kv_cache, pool);
if (type == LayerAttentionType::kGemma) {
Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
pool);
} else {
GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
kv_cache, pool);
}
AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim);
@ -707,10 +969,12 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
template <class TConfig>
void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
size_t& prompt_size) {
if (max_tokens > TConfig::kSeqLen) {
fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
max_tokens, TConfig::kSeqLen);
max_tokens = static_cast<size_t>(TConfig::kSeqLen);
if (!TConfig::kUseLocalAttention) {
if (max_tokens > TConfig::kSeqLen) {
fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
max_tokens, TConfig::kSeqLen);
max_tokens = static_cast<size_t>(TConfig::kSeqLen);
}
}
if (max_generated_tokens > max_tokens) {
@ -720,12 +984,14 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
max_generated_tokens = max_tokens - 1;
}
if (prompt_size + max_generated_tokens > max_tokens) {
fprintf(stderr,
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
"%d, truncating.\n",
prompt_size, max_generated_tokens, TConfig::kSeqLen);
prompt_size = max_tokens - max_generated_tokens;
if (!TConfig::kUseLocalAttention) {
if (prompt_size + max_generated_tokens > max_tokens) {
fprintf(stderr,
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
"%d, truncating.\n",
prompt_size, max_generated_tokens, TConfig::kSeqLen);
prompt_size = max_tokens - max_generated_tokens;
}
}
}
@ -935,6 +1201,19 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
accept_token, gen, verbosity);
}
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
@ -951,6 +1230,15 @@ float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
inner_pool, verbosity);
}
float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, int verbosity) {
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
inner_pool, verbosity);
}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
// if weights = null, which happens during the first call where we attempt to
// load from cache.
@ -967,6 +1255,7 @@ void ForEachTensor(const Weights<TConfig>* weights,
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
@ -978,9 +1267,33 @@ void ForEachTensor(const Weights<TConfig>* weights,
CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
CALL_FUNC("gating_ein", gating_einsum_w);
CALL_FUNC("linear_w", linear_w);
CALL_FUNC("qkv_ein", qkv_einsum_w);
CALL_FUNC("att_ein", attn_vec_einsum_w);
if (type == LayerAttentionType::kGemma) {
CALL_FUNC("qkv_ein", qkv_einsum_w);
CALL_FUNC("att_ein", attn_vec_einsum_w);
} else {
CALL_FUNC("gr_lin_x_w", griffin_linear_x_w);
CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases);
CALL_FUNC("gr_lin_y_w", griffin_linear_y_w);
CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases);
CALL_FUNC("gr_lin_out_w", griffin_linear_out_w);
CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases);
CALL_FUNC("gr_conv_w", griffin_conv_w);
CALL_FUNC("gr_conv_b", griffin_conv_biases);
CALL_FUNC("gr_gate_w", griffin_gate_w);
CALL_FUNC("gr_gate_b", griffin_gate_biases);
CALL_FUNC("gr_a", griffin_a);
}
CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
if (TConfig::kFFBiases) {
CALL_FUNC("ffw_gat_b", ffw_gating_biases);
CALL_FUNC("ffw_out_b", ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
CALL_FUNC("attn_ob", attention_output_biases);
}
#undef CALL_FUNC
}
}
@ -1011,10 +1324,18 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
if (TConfig::kNumTensorScales > 0) {
size_t scale_pos = 0;
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
if (type == LayerAttentionType::kGemma) {
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
} else {
layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]);
}
layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->linear_w.set_scale(scales[scale_pos++]);
}
@ -1031,6 +1352,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
@ -1044,6 +1367,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
return LoadWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadWeights<ConfigGriffin2B>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
@ -1089,6 +1414,9 @@ void CompressWeightsT(gcpp::Model model, const Path& weights,
case Model::GEMMA_7B:
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
break;
case Model::GRIFFIN_2B:
CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
@ -1106,13 +1434,29 @@ HWY_EXPORT(LoadWeightsT);
HWY_EXPORT(CompressWeightsT);
HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B);
HWY_EXPORT(GenerateGriffin2B);
HWY_EXPORT(ComputeCrossEntropy2B);
HWY_EXPORT(ComputeCrossEntropy7B);
HWY_EXPORT(ComputeCrossEntropyGriffin2B);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv_cache_size, size_t rglru_cache_size) {
KVCache kv_cache = {};
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
if (size_cache_pos != 0) {
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache =
hwy::AllocateAligned<float>(seq_len * size_cache_pos);
}
if (conv_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv_cache_size);
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
conv_cache_size * sizeof(kv_cache.conv1d_cache[0]));
}
if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
}
return kv_cache;
}
@ -1136,6 +1480,7 @@ void GemmaImpl<ConfigGemma2B>::Generate(
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
template <>
void GemmaImpl<ConfigGemma7B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
@ -1148,6 +1493,18 @@ void GemmaImpl<ConfigGemma7B>::Generate(
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
template <>
void GemmaImpl<ConfigGriffin2B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
template <>
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
@ -1164,6 +1521,14 @@ float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
}
template <>
float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
}
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
@ -1190,6 +1555,9 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
case Model::GEMMA_7B:
impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
break;
case Model::GRIFFIN_2B:
impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
}
@ -1240,5 +1608,42 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
return result;
}
namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt",
"2b-it", "7b-it", "gr2b-it"};
constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B,
Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT};
} // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024];
kErrorMessageBuffer[0] = 0;
strcat(kErrorMessageBuffer,
"Invalid or missing model flag, need to specify one of ");
for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kModelFlags[i]);
strcat(kErrorMessageBuffer, ", ");
}
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]);
strcat(kErrorMessageBuffer, ".");
std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
return nullptr;
}
}
return kErrorMessageBuffer;
}
} // namespace gcpp
#endif // HWY_ONCE

18
gemma.h
View File

@ -42,15 +42,24 @@ constexpr bool kSystemPrompt = false;
struct KVCache {
hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kNumGriffinLayers
};
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B };
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
struct RuntimeConfig {
size_t max_tokens;
size_t max_generated_tokens;
@ -79,7 +88,8 @@ struct Gemma {
};
KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size);
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.

12
run.cc
View File

@ -27,8 +27,6 @@
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -36,8 +34,12 @@
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp
static constexpr bool kVerboseLogTokens = false;
namespace gcpp {
static constexpr std::string_view kAsciiArtBanner = R""(
@ -203,6 +205,12 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
const double time_start = hwy::platform::Now();
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,

View File

@ -125,46 +125,21 @@ class AppArgs : public ArgsBase<AppArgs> {
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}
gcpp::Model ModelType() const { return model_type; }
gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
return gcpp::Model::GEMMA_2B;
} else {
return gcpp::Model::GEMMA_7B;
}
}
gcpp::ModelTraining ModelTraining() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
return gcpp::ModelTraining::GEMMA_PT;
} else {
return gcpp::ModelTraining::GEMMA_IT;
}
}
gcpp::ModelTraining ModelTraining() const { return model_training; }
// Returns error string or nullptr if OK.
const char* Validate() {
const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, or 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
"7b-it.";
}
const char* parse_result =
ParseModelTypeAndTraining(model_type_str, model_type, model_training);
if (parse_result) return parse_result;
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
if (!tokenizer.exists()) {
return "Can't open file specified with --tokenizer flag.";
}
if (!compressed_weights.path.empty()) {
if (weights.path.empty()) {
weights = compressed_weights;
@ -186,7 +161,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path tokenizer;
Path weights; // weights file location
Path compressed_weights;
std::string model_type;
std::string model_type_str;
Model model_type;
enum ModelTraining model_training;
template <class Visitor>
void ForEach(const Visitor& visitor) {
@ -196,10 +173,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"Path name of model weights (.sbs) file.\n Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Alias for --weights.");
visitor(model_type, "model", std::string(),
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
}
};