Implement the Griffin model.

Also implement support for some model variations:

- Local (sliding-window) attention.
- Biases in the feed-forward and attention output layers.
- RoPE applied to only half of each query/key vector.
- A non-interleaved ordering of the Q, K and V weight matrices (see the sketch below).
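These variations appear as compile-time constants on the per-model config structs; a minimal sketch of the relevant flags, with values taken from the ConfigGriffin2B added in this change (the struct name here is only illustrative):

```cpp
// Sketch only; the authoritative definitions live in configs.h below.
struct GriffinVariationFlags {
  static constexpr bool kUseLocalAttention = true;       // sliding-window attention
  static constexpr bool kFFBiases = true;                 // biases on the FFW projections
  static constexpr bool kSoftmaxAttnOutputBiases = true;  // bias on the attention output
  static constexpr bool kUseHalfRope = true;              // RoPE on half of each Q/K vector
  static constexpr bool kInterleaveQKV = false;           // Q, K, V stored as separate blocks
};
```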

Co-authored-by: Andrey Mikhaylov <amik@google.com>
Co-authored-by: Martin Bruse <zondolfin@gmail.com>
Co-authored-by: Zoltan Szabadka <szabadka@google.com>
Luca Versari 2024-04-04 15:01:53 +02:00
parent 4326249d0a
commit 9c3f969405
8 changed files with 639 additions and 136 deletions


@ -241,6 +241,24 @@ Example invocation for the following configuration:
--model 2b-it
```
### RecurrentGemma
This repository includes a version of Gemma based on Griffin
([paper](https://arxiv.org/abs/2402.19427),
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
combines recurrent layers with local attention, so it is more efficient for
long sequences and has a smaller memory footprint than standard Gemma. We
provide a C++ implementation of this model here, based on the paper.
To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
Step 1, and run the binary as follows:
`./gemma --tokenizer tokenizer.spm --model gr2b-it --compressed_weights 2b-it-sfp.sbs`
### Troubleshooting and FAQs
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
@ -478,4 +496,9 @@ gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.
This is not an officially supported Google product.


@ -10,14 +10,14 @@
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" #include "gemma.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
using json = nlohmann::json; using json = nlohmann::json;
@ -259,6 +259,13 @@ int main(int argc, char** argv) {
gcpp::AppArgs app(argc, argv); gcpp::AppArgs app(argc, argv);
BenchmarkArgs benchmark_args(argc, argv); BenchmarkArgs benchmark_args(argc, argv);
if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}
hwy::ThreadPool inner_pool(0); hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads); hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps. // For many-core, pinning threads to cores helps.
@ -275,7 +282,7 @@ int main(int argc, char** argv) {
if (!benchmark_args.goldens.path.empty()) { if (!benchmark_args.goldens.path.empty()) {
const std::string golden_path = const std::string golden_path =
benchmark_args.goldens.path + "/" + loader.model_type + ".txt"; benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool, return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
golden_path); golden_path);
} else if (!benchmark_args.summarize_text.path.empty()) { } else if (!benchmark_args.summarize_text.path.empty()) {


@ -44,35 +44,14 @@ struct Args : public ArgsBase<Args> {
ChooseNumThreads(); ChooseNumThreads();
} }
static std::string ToLower(const std::string& text) { gcpp::Model ModelType() const { return model_type; }
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}
gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc.substr(0, 2) == "2b") {
return gcpp::Model::GEMMA_2B;
} else if (model_type_lc.substr(0, 2) == "7b") {
return gcpp::Model::GEMMA_7B;
} else {
HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
}
}
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char* Validate() const { const char* Validate() {
const std::string model_type_lc = ToLower(model_type); ModelTraining model_training;
if (model_type.empty()) { const char* parse_result =
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " ParseModelTypeAndTraining(model_type_str, model_type, model_training);
"2b-it, 7b-it."; if (parse_result) return parse_result;
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
}
if (weights.path.empty()) { if (weights.path.empty()) {
return "Missing --weights flag, a file for the uncompressed model."; return "Missing --weights flag, a file for the uncompressed model.";
} }
@ -88,7 +67,8 @@ struct Args : public ArgsBase<Args> {
Path weights; // uncompressed weights file location Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location Path compressed_weights; // compressed weights file location
std::string model_type; std::string model_type_str;
Model model_type;
size_t num_threads; size_t num_threads;
template <class Visitor> template <class Visitor>
@ -96,10 +76,12 @@ struct Args : public ArgsBase<Args> {
visitor(weights, "weights", Path(), visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n" "Path name of model weights (.sbs) file.\n"
" Required argument."); " Required argument.");
visitor(model_type, "model", std::string(), visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n " "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument."); " Required argument.");
visitor(compressed_weights, "compressed_weights", Path(), visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights file will be written.\n" "Path name where compressed weights file will be written.\n"
@ -115,7 +97,7 @@ struct Args : public ArgsBase<Args> {
void ShowHelp(gcpp::Args& args) { void ShowHelp(gcpp::Args& args) {
std::cerr std::cerr
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> " << "Usage:\n./compress_weights --weights <path to uncompressed weights> "
" --model <model type> --compressed_weights <output path>\n"; " --model <model type> --compressed_weights <output path>\n";
std::cerr << "\n*Arguments*\n\n"; std::cerr << "\n*Arguments*\n\n";
args.Help(); args.Help();
std::cerr << "\n"; std::cerr << "\n";


@ -30,6 +30,8 @@
#include <stddef.h> #include <stddef.h>
#include <array>
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/sfp.h" #include "compression/sfp.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
@ -45,16 +47,41 @@ namespace gcpp {
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK; static constexpr size_t kTopK = GEMMA_TOPK;
enum class LayerAttentionType {
kGemma,
kGriffinRecurrentBlock,
};
template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
LayerAttentionType type) {
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
}
struct ConfigGemma7B { struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000; static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 28; static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 3072; static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16; static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0; static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T; using WeightT = GEMMA_WEIGHT_T;
}; };
@ -62,17 +89,79 @@ struct ConfigGemma7B {
struct ConfigGemma2B { struct ConfigGemma2B {
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000; static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 18; static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 2048; static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8; static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1; static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0; static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T; using WeightT = GEMMA_WEIGHT_T;
}; };
struct ConfigGriffin2B {
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
using WeightT = GEMMA_WEIGHT_T;
};
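The layer list above encodes a repeating pattern of two Griffin recurrent blocks followed by one attention block. A hypothetical constexpr helper (not part of this file) that would generate the same 26-entry configuration:

```cpp
// Every third layer (indices 2, 5, 8, ...) is attention; the rest are
// Griffin recurrent blocks. GriffinLayerPattern<26>() equals kLayerConfig.
template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> GriffinLayerPattern() {
  std::array<LayerAttentionType, kNum> config = {};
  for (size_t i = 0; i < kNum; ++i) {
    config[i] = (i % 3 == 2) ? LayerAttentionType::kGemma
                             : LayerAttentionType::kGriffinRecurrentBlock;
  }
  return config;
}
```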
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ #endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_

gemma.cc (519 changed lines)

@ -25,12 +25,12 @@
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "ops.h" #include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece // copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h" #include "src/sentencepiece_processor.h"
// copybara:end // copybara:end
@ -43,11 +43,12 @@
#include <math.h> // sqrtf #include <math.h> // sqrtf
#include <stddef.h> #include <stddef.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <cmath> #include <cmath>
#include <cstdlib>
#include <filesystem> // NOLINT #include <filesystem> // NOLINT
#include <iostream> #include <iostream>
#include <memory> #include <memory>
@ -74,8 +75,35 @@ constexpr bool kDryRunFread = false;
// Setting this to false will load and use uncompressed weights. // Setting this to false will load and use uncompressed weights.
constexpr bool kWeightsAreCompressed = true; constexpr bool kWeightsAreCompressed = true;
// Set this to true to debug tokenizer tokens.
constexpr bool kShowTokenization = false;
namespace gcpp { namespace gcpp {
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}
template <typename TConfig>
constexpr size_t NumGemmaLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGemma, TConfig::kLayers);
}
template <typename TConfig>
constexpr size_t NumGriffinLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
TConfig::kLayers);
}
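Applied to the configs above, these helpers split a model's layer count by type; for example (a sketch assuming the layer layouts defined in configs.h):

```cpp
static_assert(NumGemmaLayers<ConfigGriffin2B>() == 8);    // attention layers
static_assert(NumGriffinLayers<ConfigGriffin2B>() == 18); // recurrent layers
static_assert(NumGemmaLayers<ConfigGemma7B>() == 28);     // all-attention model
```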
template <class TConfig> template <class TConfig>
struct Layer { struct Layer {
Layer() = default; Layer() = default;
@ -96,6 +124,25 @@ struct Layer {
std::array<float, kModelDim * kFFHiddenDim> linear_w; std::array<float, kModelDim * kFFHiddenDim> linear_w;
std::array<float, kModelDim> pre_attention_norm_scale; std::array<float, kModelDim> pre_attention_norm_scale;
std::array<float, kModelDim> pre_ffw_norm_scale; std::array<float, kModelDim> pre_ffw_norm_scale;
// These fields are only used by Griffin, and do not affect loading of the
// model as it is done per-member.
// TODO(veluca): pull weights that are never used at the same time into a
// union or otherwise reduce the memory usage.
std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
std::array<float, kModelDim> ffw_output_biases;
std::array<float, kModelDim> attention_output_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
std::array<float, kModelDim> griffin_linear_y_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
std::array<float, kModelDim> griffin_linear_x_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
std::array<float, kModelDim> griffin_linear_out_biases;
std::array<float, kModelDim> griffin_conv_biases;
std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w;
std::array<float, kModelDim * 2> griffin_gate_biases;
std::array<float, kModelDim> griffin_a;
std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
}; };
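For a sense of scale, the Griffin-only tensors above, instantiated with ConfigGriffin2B (kModelDim = 2560, kHeads = 10, kConv1dWidth = 4), have the following per-layer element counts:

```cpp
constexpr size_t kGriffinLinearW = 2560 * 2560;           // 6,553,600 each for x, y, out
constexpr size_t kGriffinGateW   = 2560 * 2560 / 10 * 2;  // 1,310,720
constexpr size_t kGriffinConvW   = 4 * 2560;              //    10,240
constexpr size_t kGriffinGateB   = 2 * 2560;              //     5,120; other biases are 2560 each
```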
float ScaleWeights(float* data, size_t len) { float ScaleWeights(float* data, size_t len) {
@ -196,6 +243,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale", do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale)); sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
Layer<TConfig>* layer_view = weights->GetLayer(layer); Layer<TConfig>* layer_view = weights->GetLayer(layer);
#define READ_WEIGHTS(name) \ #define READ_WEIGHTS(name) \
@ -212,16 +260,42 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
} while (0) } while (0)
// Make sure we don't have uninitialized memory. // Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view)); hwy::ZeroBytes(layer_view, sizeof(*layer_view));
READ_WEIGHTS(attn_vec_einsum_w); if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(qkv_einsum_w); READ_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w); READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w); SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin_linear_x_w);
READ_WEIGHTS(griffin_linear_x_biases);
READ_WEIGHTS(griffin_linear_y_w);
READ_WEIGHTS(griffin_linear_y_biases);
READ_WEIGHTS(griffin_linear_out_w);
READ_WEIGHTS(griffin_linear_out_biases);
READ_WEIGHTS(griffin_conv_w);
READ_WEIGHTS(griffin_conv_biases);
READ_WEIGHTS(griffin_gate_w);
READ_WEIGHTS(griffin_gate_biases);
READ_WEIGHTS(griffin_a);
SCALE_WEIGHTS(griffin_linear_x_w);
SCALE_WEIGHTS(griffin_linear_y_w);
SCALE_WEIGHTS(griffin_linear_out_w);
SCALE_WEIGHTS(griffin_gate_w);
}
READ_WEIGHTS(gating_einsum_w); READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w); READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w); SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w); SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale); READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale); READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
#undef READ_WEIGHTS #undef READ_WEIGHTS
} }
if (!ok) { if (!ok) {
@ -253,14 +327,30 @@ struct CompressedLayer {
// We don't yet have an RMSNorm that accepts all WeightT. // We don't yet have an RMSNorm that accepts all WeightT.
CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale; CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale; CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
CompressedArray<float, kModelDim> ffw_output_biases;
CompressedArray<float, kModelDim> attention_output_biases;
CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w; CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w; CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w; CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w; CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
griffin_gate_w;
CompressedArray<float, kModelDim> griffin_a;
CompressedArray<float, kModelDim> griffin_linear_y_biases;
CompressedArray<float, kModelDim> griffin_linear_x_biases;
CompressedArray<float, kModelDim> griffin_linear_out_biases;
CompressedArray<float, kModelDim> griffin_conv_biases;
CompressedArray<float, kModelDim * 2> griffin_gate_biases;
CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
}; };
// Array instead of single large allocation for parallel mem init. Split out of // Array instead of single large allocation for parallel mem init. Split out
// CompressedWeights so that only these pointers are initialized, not the // of CompressedWeights so that only these pointers are initialized, not the
// CompressedArray. // CompressedArray.
template <class TConfig> template <class TConfig>
struct CompressedLayerPointers { struct CompressedLayerPointers {
@ -307,7 +397,8 @@ struct Activations {
static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kCachePosSize = TConfig::kLayers * kKVHeads * kQKVDim; static constexpr size_t kCachePosSize =
NumGemmaLayers<TConfig>() * kKVHeads * kQKVDim;
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim; static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
std::array<float, kBatchSize * kModelDim> x; // input std::array<float, kBatchSize * kModelDim> x; // input
@ -327,6 +418,12 @@ struct Activations {
// bf_ffw_hidden; // bf_ffw_hidden;
std::array<float, kBatchSize * kModelDim> ffw_out; std::array<float, kBatchSize * kModelDim> ffw_out;
std::array<float, kBatchSize * TConfig::kVocabSize> logits; std::array<float, kBatchSize * TConfig::kVocabSize> logits;
// Griffin layer internal activations
std::array<float, kBatchSize * kModelDim> griffin_x;
std::array<float, kBatchSize * kModelDim> griffin_y;
std::array<float, kBatchSize * kModelDim> griffin_gate_x;
std::array<float, kBatchSize * kModelDim> griffin_multiplier;
}; };
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
@ -353,8 +450,13 @@ struct GemmaInterface {
template <class Config> template <class Config>
KVCache CreateKVCache() { KVCache CreateKVCache() {
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, constexpr size_t kConv1dWidth = Config::kConv1dWidth;
Config::kSeqLen); return CreateKVCache(
NumGemmaLayers<Config>() * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen,
NumGriffinLayers<Config>() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
Config::kModelDim,
NumGriffinLayers<Config>() * Config::kModelDim);
} }
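As a worked example, for ConfigGriffin2B (8 attention and 18 recurrent layers, kSeqLen = 2048, kKVHeads = 1, kQKVDim = 256, kModelDim = 2560, kConv1dWidth = 4) this allocates, in floats:

```cpp
constexpr size_t kKVEntries   = 2048 * (8 * 1 * 256);  // 4,194,304 for key_cache and value_cache each
constexpr size_t kConv1dCache = 18 * (4 - 1) * 2560;   //   138,240
constexpr size_t kRglruCache  = 18 * 2560;             //    46,080
```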
KVCache CreateKVCache(Model type) { KVCache CreateKVCache(Model type) {
@ -363,6 +465,8 @@ KVCache CreateKVCache(Model type) {
return CreateKVCache<ConfigGemma2B>(); return CreateKVCache<ConfigGemma2B>();
case Model::GEMMA_7B: case Model::GEMMA_7B:
return CreateKVCache<ConfigGemma7B>(); return CreateKVCache<ConfigGemma7B>();
case Model::GRIFFIN_2B:
return CreateKVCache<ConfigGriffin2B>();
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(type)); HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
} }
@ -379,7 +483,15 @@ class GemmaTokenizerImpl : public GemmaTokenizer {
} }
bool Encode(const std::string& input, bool Encode(const std::string& input,
std::vector<int>* pieces) const override { std::vector<int>* pieces) const override {
return impl_->Encode(input, pieces).ok(); if constexpr (kShowTokenization) {
bool is_ok = impl_->Encode(input, pieces).ok();
for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
}
return is_ok;
} else {
return impl_->Encode(input, pieces).ok();
}
} }
// Given a sequence of ids, decodes it into a detokenized output. // Given a sequence of ids, decodes it into a detokenized output.
bool Decode(const std::vector<int>& ids, bool Decode(const std::vector<int>& ids,
@ -442,6 +554,119 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
HWY_DASSERT(batch_idx < kBatchSize);
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
// X / Y linear layers.
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
TwoMatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset,
/*add0=*/layer_weights->griffin_linear_x_biases.data(),
/*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
Gelu(y, kModelDim);
// Conv1D.
{
HWY_FULL(float) df;
HWY_DASSERT(kModelDim % Lanes(df) == 0);
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)];
cache[0] = x;
for (size_t i = 1; i < kConv1dWidth; i++) {
cache[i] =
kv_cache.conv1d_cache.get() + layer_offset +
((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim;
}
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
auto accum1 = hn::Zero(df);
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i);
}
}
// RGLRU
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * kModelDim;
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
constexpr size_t kHeadDim = kModelDim / kHeads;
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
layer_weights->griffin_gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin_a.data() + head_offset, fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
// Final linear layer.
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_out_w, 0, x,
layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
}
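The per-head loop above implements the RG-LRU recurrence from the Griffin paper. A scalar sketch of one channel of the scan, with illustrative names (log_a is the griffin_a entry already scaled by the sigmoid of the recurrence gate, gated_x is x after the input gate, y is the GeLU branch, h is the state held in rglru_cache):

```cpp
#include <cmath>

float RGLRUStep(float log_a, float gated_x, float y, float& h, size_t pos) {
  const float a = std::exp(log_a);                                // decay in (0, 1)
  const float x_mul = pos == 0 ? 1.0f : std::sqrt(1.0f - a * a);  // input scaling
  h = a * h + x_mul * gated_x;                                    // recurrent state update
  return y * h;                                                   // join the two branches
}
```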
template <size_t kBatchSize, typename LayerT, class TConfig> template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, kBatchSize>& activations, Activations<TConfig, kBatchSize>& activations,
@ -462,6 +687,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
static const float kQueryScale = static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim))); static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
size_t cache_pos = pos;
size_t cache_num = pos + 1;
if constexpr (TConfig::kUseLocalAttention) {
cache_pos %= TConfig::kSeqLen;
cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
}
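// Local-attention example (ConfigGriffin2B, kSeqLen = 2048): at pos = 5000,
// the new KV entry is written to ring-buffer slot 5000 % 2048 = 904 and the
// attention below spans the most recent min(5001, 2048) = 2048 positions.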
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim; float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR { auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
@ -480,7 +712,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset, TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
v_offset, x, k, v); v_offset, x, k, v);
Rope(k, kQKVDim, pos); Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}; };
auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR { auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
@ -491,24 +723,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
head * TConfig::kSeqLen + head * TConfig::kSeqLen +
batch_idx * kHeads * kQKVDim; batch_idx * kHeads * kQKVDim;
Rope(q, kQKVDim, pos); Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim); MulByConst(kQueryScale, q, kQKVDim);
// Compute Q dot K scores // Compute Q dot K scores
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset = const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
const float score = Dot(q, k2, kQKVDim); const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score; head_att[pos2] = score;
} }
Softmax(head_att, pos + 1); Softmax(head_att, cache_num);
// Weighted summation // Weighted summation
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim + float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
batch_idx * kHeads * kQKVDim; batch_idx * kHeads * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset = const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
@ -520,22 +752,34 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
head == 0 head == 0
? activations.att_post2.data() + batch_idx * kModelDim ? activations.att_post2.data() + batch_idx * kModelDim
: activations.att_post1.data() + head * kBatchSize * kModelDim; : activations.att_post1.data() + head * kBatchSize * kModelDim;
MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w, if (head == 0) {
head * kModelDim * kQKVDim, att_out, MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
head_out); layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
layer_weights->attention_output_biases.data(), head_out);
} else {
MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
head * kModelDim * kQKVDim, att_out,
head_out);
}
}; };
if constexpr (kHeads == kKVHeads) { if constexpr (kHeads == kKVHeads) {
// Multi-Head Attention // Multi-Head Attention
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t head_offset = head * 3 * kQKVDim * kModelDim; // linear projections to QKV
const size_t head_offset = TConfig::kInterleaveQKV
? 3 * kQKVDim * kModelDim
: kQKVDim * kModelDim;
const size_t mat_offset =
TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
const size_t q_offset = head * head_offset + 0 * mat_offset;
const size_t k_offset = head * head_offset + 1 * mat_offset;
const size_t v_offset = head * head_offset + 2 * mat_offset;
ProjQ(head, head_offset); ProjQ(head, q_offset);
const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
const size_t kv_offset = const size_t kv_offset =
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
ProjKV(k_offset, v_offset, kv_offset); ProjKV(k_offset, v_offset, kv_offset);
@ -546,7 +790,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;
ProjKV(k_offset, v_offset, kv_offset); ProjKV(k_offset, v_offset, kv_offset);
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
@ -581,13 +827,13 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
// Same matrix, first and second half of rows. Could fuse into one MatVec, // Same matrix, first and second half of rows. Could fuse into one MatVec,
// but separating them could help on NUMA e.g. multiple sockets. // but separating them could help on NUMA e.g. multiple sockets.
MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w, MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
kFFHiddenDim * kModelDim, vec, out_mul, layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
pool); layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool);
// Gate, will go through the nonlinearity. // Gate, will go through the nonlinearity.
MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w, 0, vec, out, MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
pool); layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), out, pool);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
@ -598,8 +844,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
} }
PROFILER_ZONE("Gen.FFW\\GatedGELU"); PROFILER_ZONE("Gen.FFW\\GatedGELU");
MatVec<kModelDim, kFFHiddenDim>( MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset, layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(),
activations.ffw_out.data() + batch_idx * kModelDim, pool); activations.ffw_out.data() + batch_idx * kModelDim, pool);
} }
@ -639,15 +886,23 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
}); });
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer); const auto* layer_weights = weights.GetLayer(layer);
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations.x.data() + token_idx * kModelDim, RMSNorm(activations.x.data() + token_idx * kModelDim,
layer_weights->pre_attention_norm_scale.data(), layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data() + token_idx * kModelDim, activations.pre_att_rms_out.data() + token_idx * kModelDim,
kModelDim); kModelDim);
Attention<kBatchSize>(pos, token_idx, layer, activations, layer_weights, if (type == LayerAttentionType::kGemma) {
kv_cache, pool); Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
} else {
GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
}
} }
// TODO: sink the loop into these functions, i.e. make them matmuls. // TODO: sink the loop into these functions, i.e. make them matmuls.
@ -678,9 +933,7 @@ template <typename WeightArrayT, class TConfig>
void Transformer(int token, size_t pos, const WeightArrayT& weights, void Transformer(int token, size_t pos, const WeightArrayT& weights,
Activations<TConfig, 1>& activations, KVCache& kv_cache, Activations<TConfig, 1>& activations, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) { hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
static constexpr size_t kLayers = TConfig::kLayers;
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
Decompress(weights.embedder_input_embedding, token * kModelDim, Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim); activations.x.data(), kModelDim);
@ -688,12 +941,21 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
EmbeddingScaling<TConfig>(); EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim); MulByConst(kEmbScaling, activations.x.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer); const auto* layer_weights = weights.GetLayer(layer);
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
RMSNorm(activations.x.data(), RMSNorm(activations.x.data(),
layer_weights->pre_attention_norm_scale.data(), layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), kModelDim); activations.pre_att_rms_out.data(), kModelDim);
Attention<1>(pos, 0, layer, activations, layer_weights, kv_cache, pool); if (type == LayerAttentionType::kGemma) {
Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
pool);
} else {
GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
kv_cache, pool);
}
AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim); AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(), RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim); activations.bf_pre_ffw_rms_out.data(), kModelDim);
@ -707,10 +969,12 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
template <class TConfig> template <class TConfig>
void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
size_t& prompt_size) { size_t& prompt_size) {
if (max_tokens > TConfig::kSeqLen) { if (!TConfig::kUseLocalAttention) {
fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n", if (max_tokens > TConfig::kSeqLen) {
max_tokens, TConfig::kSeqLen); fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
max_tokens = static_cast<size_t>(TConfig::kSeqLen); max_tokens, TConfig::kSeqLen);
max_tokens = static_cast<size_t>(TConfig::kSeqLen);
}
} }
if (max_generated_tokens > max_tokens) { if (max_generated_tokens > max_tokens) {
@ -720,12 +984,14 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
max_generated_tokens = max_tokens - 1; max_generated_tokens = max_tokens - 1;
} }
if (prompt_size + max_generated_tokens > max_tokens) { if (!TConfig::kUseLocalAttention) {
fprintf(stderr, if (prompt_size + max_generated_tokens > max_tokens) {
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen " fprintf(stderr,
"%d, truncating.\n", "WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
prompt_size, max_generated_tokens, TConfig::kSeqLen); "%d, truncating.\n",
prompt_size = max_tokens - max_generated_tokens; prompt_size, max_generated_tokens, TConfig::kSeqLen);
prompt_size = max_tokens - max_generated_tokens;
}
} }
} }
@ -935,6 +1201,19 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
accept_token, gen, verbosity); accept_token, gen, verbosity);
} }
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens, float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
@ -951,6 +1230,15 @@ float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
inner_pool, verbosity); inner_pool, verbosity);
} }
float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, int verbosity) {
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
inner_pool, verbosity);
}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
// if weights = null, which happens during the first call where we attempt to // if weights = null, which happens during the first call where we attempt to
// load from cache. // load from cache.
@ -967,6 +1255,7 @@ void ForEachTensor(const Weights<TConfig>* weights,
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr; const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx); CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
@ -978,9 +1267,33 @@ void ForEachTensor(const Weights<TConfig>* weights,
CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale); CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
CALL_FUNC("gating_ein", gating_einsum_w); CALL_FUNC("gating_ein", gating_einsum_w);
CALL_FUNC("linear_w", linear_w); CALL_FUNC("linear_w", linear_w);
CALL_FUNC("qkv_ein", qkv_einsum_w); if (type == LayerAttentionType::kGemma) {
CALL_FUNC("att_ein", attn_vec_einsum_w); CALL_FUNC("qkv_ein", qkv_einsum_w);
CALL_FUNC("att_ein", attn_vec_einsum_w);
} else {
CALL_FUNC("gr_lin_x_w", griffin_linear_x_w);
CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases);
CALL_FUNC("gr_lin_y_w", griffin_linear_y_w);
CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases);
CALL_FUNC("gr_lin_out_w", griffin_linear_out_w);
CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases);
CALL_FUNC("gr_conv_w", griffin_conv_w);
CALL_FUNC("gr_conv_b", griffin_conv_biases);
CALL_FUNC("gr_gate_w", griffin_gate_w);
CALL_FUNC("gr_gate_b", griffin_gate_biases);
CALL_FUNC("gr_a", griffin_a);
}
CALL_FUNC("pre_att_ns", pre_attention_norm_scale); CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
if (TConfig::kFFBiases) {
CALL_FUNC("ffw_gat_b", ffw_gating_biases);
CALL_FUNC("ffw_out_b", ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
CALL_FUNC("attn_ob", attention_output_biases);
}
#undef CALL_FUNC #undef CALL_FUNC
} }
} }
@ -1011,10 +1324,18 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
if (TConfig::kNumTensorScales > 0) { if (TConfig::kNumTensorScales > 0) {
size_t scale_pos = 0; size_t scale_pos = 0;
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx); CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]); if (type == LayerAttentionType::kGemma) {
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]); layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
} else {
layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]);
}
layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]); layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->linear_w.set_scale(scales[scale_pos++]); layer_weights->linear_w.set_scale(scales[scale_pos++]);
} }
@ -1031,6 +1352,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
return LoadCompressedWeights<ConfigGemma2B>(weights, pool); return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B: case Model::GEMMA_7B:
return LoadCompressedWeights<ConfigGemma7B>(weights, pool); return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
} }
@ -1044,6 +1367,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
return LoadWeights<ConfigGemma2B>(weights, pool); return LoadWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B: case Model::GEMMA_7B:
return LoadWeights<ConfigGemma7B>(weights, pool); return LoadWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadWeights<ConfigGriffin2B>(weights, pool);
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
} }
@ -1089,6 +1414,9 @@ void CompressWeightsT(gcpp::Model model, const Path& weights,
case Model::GEMMA_7B: case Model::GEMMA_7B:
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool); CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
break; break;
case Model::GRIFFIN_2B:
CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
break;
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
} }
@ -1106,13 +1434,29 @@ HWY_EXPORT(LoadWeightsT);
HWY_EXPORT(CompressWeightsT); HWY_EXPORT(CompressWeightsT);
HWY_EXPORT(Generate2B); HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B); HWY_EXPORT(Generate7B);
HWY_EXPORT(GenerateGriffin2B);
HWY_EXPORT(ComputeCrossEntropy2B); HWY_EXPORT(ComputeCrossEntropy2B);
HWY_EXPORT(ComputeCrossEntropy7B); HWY_EXPORT(ComputeCrossEntropy7B);
HWY_EXPORT(ComputeCrossEntropyGriffin2B);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) { KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv_cache_size, size_t rglru_cache_size) {
KVCache kv_cache = {}; KVCache kv_cache = {};
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos); if (size_cache_pos != 0) {
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos); kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache =
hwy::AllocateAligned<float>(seq_len * size_cache_pos);
}
if (conv_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv_cache_size);
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
conv_cache_size * sizeof(kv_cache.conv1d_cache[0]));
}
if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
}
return kv_cache; return kv_cache;
} }
@ -1136,6 +1480,7 @@ void GemmaImpl<ConfigGemma2B>::Generate(
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
} }
template <> template <>
void GemmaImpl<ConfigGemma7B>::Generate( void GemmaImpl<ConfigGemma7B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature, size_t max_tokens, size_t max_generated_tokens, float temperature,
@ -1148,6 +1493,18 @@ void GemmaImpl<ConfigGemma7B>::Generate(
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
} }
template <>
void GemmaImpl<ConfigGriffin2B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
template <> template <>
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy( float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache, size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
@ -1164,6 +1521,14 @@ float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity); *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
} }
template <>
float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
}
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer; std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
@ -1190,6 +1555,9 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
case Model::GEMMA_7B: case Model::GEMMA_7B:
impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool)); impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
break; break;
case Model::GRIFFIN_2B:
impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
break;
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type)); HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
} }
@ -1240,5 +1608,42 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
return result; return result;
} }
namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt",
"2b-it", "7b-it", "gr2b-it"};
constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B,
Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT};
} // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024];
kErrorMessageBuffer[0] = 0;
strcat(kErrorMessageBuffer,
"Invalid or missing model flag, need to specify one of ");
for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kModelFlags[i]);
strcat(kErrorMessageBuffer, ", ");
}
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]);
strcat(kErrorMessageBuffer, ".");
std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
return nullptr;
}
}
return kErrorMessageBuffer;
}
} // namespace gcpp } // namespace gcpp
#endif // HWY_ONCE #endif // HWY_ONCE
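ParseModelTypeAndTraining centralizes model-flag parsing for all binaries; a usage sketch (a hypothetical call site, mirroring Args::Validate() in compress_weights.cc above):

```cpp
gcpp::Model model;
gcpp::ModelTraining training;
if (const char* error =
        gcpp::ParseModelTypeAndTraining("gr2b-it", model, training)) {
  HWY_ABORT("\n%s", error);
}
// On success: model == gcpp::Model::GRIFFIN_2B,
//             training == gcpp::ModelTraining::GEMMA_IT.
```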

gemma.h (18 changed lines)

@ -42,15 +42,24 @@ constexpr bool kSystemPrompt = false;
struct KVCache { struct KVCache {
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kNumGriffinLayers
}; };
// Model variants: see configs.h for details. // Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B }; enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT }; enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
struct RuntimeConfig { struct RuntimeConfig {
size_t max_tokens; size_t max_tokens;
size_t max_generated_tokens; size_t max_generated_tokens;
@ -79,7 +88,8 @@ struct Gemma {
}; };
KVCache CreateKVCache(Model type); // convenient workaround for now KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size);
// StreamFunc is called with (token, probability). For prompt tokens, // StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. // probability is 0.0f.
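The two new cache members above are sized per Griffin layer, and the second `CreateKVCache` overload takes those sizes explicitly. Below is a rough sketch of how the four arguments could be derived from a model config; the `TConfig` constants are assumptions that mirror the comment names above, not the actual `configs.h` definitions:

```
// Sketch only: derives the sizes passed to the new CreateKVCache overload.
template <class TConfig>
gcpp::KVCache AllocateCacheFor() {
  // Per-position KV entries (key_cache/value_cache are kSeqLen times this).
  constexpr size_t kCachePosSize =
      TConfig::kNumGemmaLayers * TConfig::kKVHeads * TConfig::kQKVDim;
  // Rolling conv1d state kept per Griffin layer.
  constexpr size_t kConv1dCacheSize =
      TConfig::kNumGriffinLayers * (TConfig::kConv1dWidth - 1) *
      TConfig::kModelDim;
  // RG-LRU recurrent state kept per Griffin layer.
  constexpr size_t kRGLRUCacheSize =
      TConfig::kNumGriffinLayers * TConfig::kModelDim;
  return gcpp::CreateKVCache(kCachePosSize, TConfig::kSeqLen,
                             kConv1dCacheSize, kRGLRUCacheSize);
}
```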
run.cc
View File
@ -27,8 +27,6 @@
#include "compression/compress.h" #include "compression/compress.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma #include "gemma.h" // Gemma
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -36,8 +34,12 @@
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp #include "util/args.h" // HasHelp
static constexpr bool kVerboseLogTokens = false;
namespace gcpp { namespace gcpp {
static constexpr std::string_view kAsciiArtBanner = R""( static constexpr std::string_view kAsciiArtBanner = R""(
@ -203,6 +205,12 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
std::cerr << "\n" std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush; << "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
const double time_start = hwy::platform::Now(); const double time_start = hwy::platform::Now();
GenerateGemma(model, args.max_tokens, args.max_generated_tokens, GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
View File
@ -125,46 +125,21 @@ class AppArgs : public ArgsBase<AppArgs> {
struct LoaderArgs : public ArgsBase<LoaderArgs> { struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
static std::string ToLower(const std::string& text) { gcpp::Model ModelType() const { return model_type; }
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}
gcpp::Model ModelType() const { gcpp::ModelTraining ModelTraining() const { return model_training; }
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
return gcpp::Model::GEMMA_2B;
} else {
return gcpp::Model::GEMMA_7B;
}
}
gcpp::ModelTraining ModelTraining() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
return gcpp::ModelTraining::GEMMA_PT;
} else {
return gcpp::ModelTraining::GEMMA_IT;
}
}
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char* Validate() { const char* Validate() {
const std::string model_type_lc = ToLower(model_type); const char* parse_result =
if (model_type.empty()) { ParseModelTypeAndTraining(model_type_str, model_type, model_training);
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " if (parse_result) return parse_result;
"2b-it, or 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
"7b-it.";
}
if (tokenizer.path.empty()) { if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required."; return "Missing --tokenizer flag, a file for the tokenizer is required.";
} }
if (!tokenizer.exists()) {
return "Can't open file specified with --tokenizer flag.";
}
if (!compressed_weights.path.empty()) { if (!compressed_weights.path.empty()) {
if (weights.path.empty()) { if (weights.path.empty()) {
weights = compressed_weights; weights = compressed_weights;
@ -186,7 +161,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path tokenizer; Path tokenizer;
Path weights; // weights file location Path weights; // weights file location
Path compressed_weights; Path compressed_weights;
std::string model_type; std::string model_type_str;
Model model_type;
enum ModelTraining model_training;
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(const Visitor& visitor) {
@ -196,10 +173,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"Path name of model weights (.sbs) file.\n Required argument."); "Path name of model weights (.sbs) file.\n Required argument.");
visitor(compressed_weights, "compressed_weights", Path(), visitor(compressed_weights, "compressed_weights", Path(),
"Alias for --weights."); "Alias for --weights.");
visitor(model_type, "model", std::string(), visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n " "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument."); " Required argument.");
} }
}; };
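With the parsing moved into `ParseModelTypeAndTraining`, `LoaderArgs` now just caches the parsed enums behind its accessors. A sketch of the intended call pattern follows; the bare-bones `main` and variable names are illustrative, not part of the commit:

```
#include <cstdio>

// copybara:import_next_line:gemma_cpp
#include "gemma.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"

int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  if (const char* error = loader.Validate()) {
    std::fprintf(stderr, "%s\n", error);
    return 1;
  }
  // Validate() has already resolved --model, so these are cheap accessors.
  const gcpp::Model model = loader.ModelType();
  gcpp::KVCache kv_cache = gcpp::CreateKVCache(model);
  // loader.ModelTraining() likewise reports instruction-tuned vs. pretrained.
  (void)kv_cache;
  return 0;
}
```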