mirror of https://github.com/google/gemma.cpp.git
Implement the Griffin model.
Also implement support for some model variations:
- Local attention.
- Add support for biases.
- Use RoPE only on half vectors.
- Support different order of QKV weights.

Co-authored-by: Andrey Mikhaylov <amik@google.com>
Co-authored-by: Martin Bruse <zondolfin@gmail.com>
Co-authored-by: Zoltan Szabadka <szabadka@google.com>
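The "RoPE only on half vectors" variation appears in the diff below as `Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos)`. A minimal scalar sketch of the idea follows; it is not the SIMD `Rope` from ops.h, and the 10000.0f base and the pairing of dimension i with i + rope_dims / 2 are assumptions made for illustration only.

```cpp
#include <cmath>
#include <cstddef>

// Rotates only the first `rope_dims` elements of a head vector in place;
// the remaining elements pass through unrotated (the "half RoPE" case).
void RopeSketch(float* x, size_t rope_dims, size_t pos) {
  const size_t half = rope_dims / 2;
  for (size_t i = 0; i < half; ++i) {
    const float freq = std::pow(10000.0f, -2.0f * i / rope_dims);  // assumed base
    const float theta = pos * freq;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}

// Usage sketch: with kUseHalfRope (Griffin), only the first kQKVDim / 2
// dimensions are rotated, e.g. RopeSketch(q, 256 / 2, pos).
```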
This commit is contained in:
parent 4326249d0a
commit 9c3f969405

README.md | 23
@@ -241,6 +241,24 @@ Example invocation for the following configuration:
--model 2b-it
```

### RecurrentGemma

This repository includes a version of Gemma based on Griffin
([paper](https://arxiv.org/abs/2402.19427),
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
includes both recurrent layers and local attention, thus it is more efficient
for longer sequences and has a smaller memory footprint than standard Gemma. We
here provide a C++ implementation of this model based on the paper.

To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
Step 1, and run the binary as follows:

`./gemma --tokenizer tokenizer.spm --model gr2b-it --compressed_weights 2b-it-sfp.sbs`


### Troubleshooting and FAQs

**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**

@@ -478,4 +496,9 @@ gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.

Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.

This is not an officially supported Google product.
benchmark.cc | 17
@@ -10,14 +10,14 @@
#include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"

using json = nlohmann::json;

@@ -259,6 +259,13 @@ int main(int argc, char** argv) {
gcpp::AppArgs app(argc, argv);
BenchmarkArgs benchmark_args(argc, argv);

if (const char* error = loader.Validate()) {
HWY_ABORT("\nInvalid loader args: %s", error);
}
if (const char* error = args.Validate()) {
HWY_ABORT("\nInvalid inference args: %s", error);
}

hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.

@@ -275,7 +282,7 @@ int main(int argc, char** argv) {

if (!benchmark_args.goldens.path.empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" + loader.model_type + ".txt";
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
golden_path);
} else if (!benchmark_args.summarize_text.path.empty()) {
compress_weights.cc
@@ -44,35 +44,14 @@ struct Args : public ArgsBase<Args> {
ChooseNumThreads();
}

static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}

gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc.substr(0, 2) == "2b") {
return gcpp::Model::GEMMA_2B;
} else if (model_type_lc.substr(0, 2) == "7b") {
return gcpp::Model::GEMMA_7B;
} else {
HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
}
}
gcpp::Model ModelType() const { return model_type; }

// Returns error string or nullptr if OK.
const char* Validate() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
}
const char* Validate() {
ModelTraining model_training;
const char* parse_result =
ParseModelTypeAndTraining(model_type_str, model_type, model_training);
if (parse_result) return parse_result;
if (weights.path.empty()) {
return "Missing --weights flag, a file for the uncompressed model.";
}

@@ -88,7 +67,8 @@ struct Args : public ArgsBase<Args> {

Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type;
std::string model_type_str;
Model model_type;
size_t num_threads;

template <class Visitor>

@@ -96,10 +76,12 @@ struct Args : public ArgsBase<Args> {
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n"
" Required argument.");
visitor(model_type, "model", std::string(),
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights file will be written.\n"

@@ -115,7 +97,7 @@ struct Args : public ArgsBase<Args> {
void ShowHelp(gcpp::Args& args) {
std::cerr
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
" --model <model type> --compressed_weights <output path>\n";
" --model <model type> --compressed_weights <output path>\n";
std::cerr << "\n*Arguments*\n\n";
args.Help();
std::cerr << "\n";
configs.h | 93
@@ -30,6 +30,8 @@

#include <stddef.h>

#include <array>

// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
#include "hwy/base.h" // hwy::bfloat16_t

@@ -45,16 +47,41 @@ namespace gcpp {
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;

enum class LayerAttentionType {
kGemma,
kGriffinRecurrentBlock,
};

template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
LayerAttentionType type) {
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
}

struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 28;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;

// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};

@@ -62,17 +89,79 @@ struct ConfigGemma7B {
struct ConfigGemma2B {
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr int kLayers = 18;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;

// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};

struct ConfigGriffin2B {
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;

// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
using WeightT = GEMMA_WEIGHT_T;
};

} // namespace gcpp

#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
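ConfigGriffin2B spells out its 26-entry kLayerConfig explicitly: two kGriffinRecurrentBlock layers followed by one kGemma (local attention) layer, repeated, which gives 18 recurrent layers and 8 attention layers. A hypothetical constexpr generator producing the same pattern is sketched below; the repository itself uses the explicit array, so the helper name and approach here are illustrative only.

```cpp
#include <array>
#include <cstddef>

enum class LayerAttentionType { kGemma, kGriffinRecurrentBlock };  // mirrors configs.h

// Hypothetical helper: every third layer (i % 3 == 2) is an attention block,
// all others are recurrent blocks, reproducing the 26-entry Griffin pattern.
template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> GriffinLayerConfig() {
  std::array<LayerAttentionType, kNum> config = {};
  for (size_t i = 0; i < kNum; ++i) {
    config[i] = (i % 3 == 2) ? LayerAttentionType::kGemma
                             : LayerAttentionType::kGriffinRecurrentBlock;
  }
  return config;
}

static_assert(GriffinLayerConfig<26>()[2] == LayerAttentionType::kGemma,
              "positions 2, 5, 8, ... are attention layers");
```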
gemma.cc | 519
@@ -25,12 +25,12 @@
#include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp
#include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// copybara:end

@@ -43,11 +43,12 @@
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <filesystem> // NOLINT
#include <iostream>
#include <memory>

@@ -74,8 +75,35 @@ constexpr bool kDryRunFread = false;
// Setting this to false will load and use uncompressed weights.
constexpr bool kWeightsAreCompressed = true;

// Set this to true to debug tokenizer tokens.
constexpr bool kShowTokenization = false;

namespace gcpp {

template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}

template <typename TConfig>
constexpr size_t NumGemmaLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGemma, TConfig::kLayers);
}

template <typename TConfig>
constexpr size_t NumGriffinLayers() {
return NumLayersOfTypeBefore(TConfig::kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
TConfig::kLayers);
}

template <class TConfig>
struct Layer {
Layer() = default;

@@ -96,6 +124,25 @@ struct Layer {
std::array<float, kModelDim * kFFHiddenDim> linear_w;
std::array<float, kModelDim> pre_attention_norm_scale;
std::array<float, kModelDim> pre_ffw_norm_scale;
// These fields are only used by Griffin, and do not affect loading of the
// model as it is done per-member.
// TODO(veluca): pull weights that are never used at the same time into a
// union or otherwise reduce the memory usage.
std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
std::array<float, kModelDim> ffw_output_biases;
std::array<float, kModelDim> attention_output_biases;

std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
std::array<float, kModelDim> griffin_linear_y_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
std::array<float, kModelDim> griffin_linear_x_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
std::array<float, kModelDim> griffin_linear_out_biases;
std::array<float, kModelDim> griffin_conv_biases;
std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w;
std::array<float, kModelDim * 2> griffin_gate_biases;
std::array<float, kModelDim> griffin_a;
std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
};

float ScaleWeights(float* data, size_t len) {

@@ -196,6 +243,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
Layer<TConfig>* layer_view = weights->GetLayer(layer);

#define READ_WEIGHTS(name) \

@@ -212,16 +260,42 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
} while (0)
// Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin_linear_x_w);
READ_WEIGHTS(griffin_linear_x_biases);
READ_WEIGHTS(griffin_linear_y_w);
READ_WEIGHTS(griffin_linear_y_biases);
READ_WEIGHTS(griffin_linear_out_w);
READ_WEIGHTS(griffin_linear_out_biases);
READ_WEIGHTS(griffin_conv_w);
READ_WEIGHTS(griffin_conv_biases);
READ_WEIGHTS(griffin_gate_w);
READ_WEIGHTS(griffin_gate_biases);
READ_WEIGHTS(griffin_a);
SCALE_WEIGHTS(griffin_linear_x_w);
SCALE_WEIGHTS(griffin_linear_y_w);
SCALE_WEIGHTS(griffin_linear_out_w);
SCALE_WEIGHTS(griffin_gate_w);
}
READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
#undef READ_WEIGHTS
}
if (!ok) {
@@ -253,14 +327,30 @@ struct CompressedLayer {
// We don't yet have an RMSNorm that accepts all WeightT.
CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
CompressedArray<float, kModelDim> ffw_output_biases;
CompressedArray<float, kModelDim> attention_output_biases;
CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;

CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
griffin_gate_w;
CompressedArray<float, kModelDim> griffin_a;
CompressedArray<float, kModelDim> griffin_linear_y_biases;
CompressedArray<float, kModelDim> griffin_linear_x_biases;
CompressedArray<float, kModelDim> griffin_linear_out_biases;
CompressedArray<float, kModelDim> griffin_conv_biases;
CompressedArray<float, kModelDim * 2> griffin_gate_biases;
CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
};

// Array instead of single large allocation for parallel mem init. Split out of
// CompressedWeights so that only these pointers are initialized, not the
// Array instead of single large allocation for parallel mem init. Split out
// of CompressedWeights so that only these pointers are initialized, not the
// CompressedArray.
template <class TConfig>
struct CompressedLayerPointers {

@@ -307,7 +397,8 @@ struct Activations {
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kCachePosSize = TConfig::kLayers * kKVHeads * kQKVDim;
static constexpr size_t kCachePosSize =
NumGemmaLayers<TConfig>() * kKVHeads * kQKVDim;
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;

std::array<float, kBatchSize * kModelDim> x; // input

@@ -327,6 +418,12 @@ struct Activations {
// bf_ffw_hidden;
std::array<float, kBatchSize * kModelDim> ffw_out;
std::array<float, kBatchSize * TConfig::kVocabSize> logits;

// Griffin layer internal activations
std::array<float, kBatchSize * kModelDim> griffin_x;
std::array<float, kBatchSize * kModelDim> griffin_y;
std::array<float, kBatchSize * kModelDim> griffin_gate_x;
std::array<float, kBatchSize * kModelDim> griffin_multiplier;
};

// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we

@@ -353,8 +450,13 @@ struct GemmaInterface {

template <class Config>
KVCache CreateKVCache() {
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen);
constexpr size_t kConv1dWidth = Config::kConv1dWidth;
return CreateKVCache(
NumGemmaLayers<Config>() * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen,
NumGriffinLayers<Config>() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
Config::kModelDim,
NumGriffinLayers<Config>() * Config::kModelDim);
}

KVCache CreateKVCache(Model type) {

@@ -363,6 +465,8 @@ KVCache CreateKVCache(Model type) {
return CreateKVCache<ConfigGemma2B>();
case Model::GEMMA_7B:
return CreateKVCache<ConfigGemma7B>();
case Model::GRIFFIN_2B:
return CreateKVCache<ConfigGriffin2B>();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}

@@ -379,7 +483,15 @@ class GemmaTokenizerImpl : public GemmaTokenizer {
}
bool Encode(const std::string& input,
std::vector<int>* pieces) const override {
return impl_->Encode(input, pieces).ok();
if constexpr (kShowTokenization) {
bool is_ok = impl_->Encode(input, pieces).ok();
for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
}
return is_ok;
} else {
return impl_->Encode(input, pieces).ok();
}
}
// Given a sequence of ids, decodes it into a detokenized output.
bool Decode(const std::vector<int>& ids,
@@ -442,6 +554,119 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
HWY_DASSERT(batch_idx < kBatchSize);
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;

// X / Y linear layers.
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
TwoMatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset,
/*add0=*/layer_weights->griffin_linear_x_biases.data(),
/*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
Gelu(y, kModelDim);

// Conv1D.
{
HWY_FULL(float) df;
HWY_DASSERT(kModelDim % Lanes(df) == 0);
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);

// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)];
cache[0] = x;
for (size_t i = 1; i < kConv1dWidth; i++) {
cache[i] =
kv_cache.conv1d_cache.get() + layer_offset +
((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim;
}
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
auto accum1 = hn::Zero(df);
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i);
}
}

// RGLRU
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * kModelDim;

pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
constexpr size_t kHeadDim = kModelDim / kHeads;
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
layer_weights->griffin_gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin_a.data() + head_offset, fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);

// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});

// Final linear layer.
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_out_w, 0, x,
layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
}

template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, kBatchSize>& activations,
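Loosely in the notation of the Griffin paper, the per-channel recurrence that the SIMD scan above implements can be written as below, where r_t and i_t are the two sigmoid gates computed from `griffin_gate_w`, Λ is the stored `griffin_a` parameter, x_t is the convolved branch, y_t the GELU branch, and ⊙ is elementwise multiplication (treating the paper's gate-scaling constant as folded into Λ is an assumption here):

$$
\begin{aligned}
\log a_t &= r_t \odot \Lambda, \qquad a_t = \exp(\log a_t),\\
h_t &= a_t \odot h_{t-1} + \sqrt{1 - a_t^2} \odot (i_t \odot x_t),\\
\mathrm{out}_t &= y_t \odot h_t,
\end{aligned}
$$

with the $\sqrt{1 - a_t^2}$ factor replaced by 1 at position 0, matching the `pos == 0` branch in the loop.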
@@ -462,6 +687,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));

size_t cache_pos = pos;
size_t cache_num = pos + 1;
if constexpr (TConfig::kUseLocalAttention) {
cache_pos %= TConfig::kSeqLen;
cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
}

float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;

auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
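With local attention enabled, the KV cache behaves as a ring buffer of kSeqLen positions. A small worked example using the ConfigGriffin2B window of 2048 (the absolute position 5000 is hypothetical):

```cpp
#include <algorithm>
#include <cstddef>

// Position 5000 writes K/V into ring-buffer slot 904 and attends over only
// the most recent 2048 positions.
constexpr size_t kSeqLen = 2048;
constexpr size_t pos = 5000;
constexpr size_t cache_pos = pos % kSeqLen;               // 904
constexpr size_t cache_num = std::min(pos + 1, kSeqLen);  // 2048
static_assert(cache_pos == 904 && cache_num == kSeqLen, "ring-buffer math");
```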
@@ -480,7 +712,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
v_offset, x, k, v);

Rope(k, kQKVDim, pos);
Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
};

auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {

@@ -491,24 +723,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
head * TConfig::kSeqLen +
batch_idx * kHeads * kQKVDim;

Rope(q, kQKVDim, pos);
Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim);

// Compute Q dot K scores
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
}
Softmax(head_att, pos + 1);
Softmax(head_att, cache_num);

// Weighted summation
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
batch_idx * kHeads * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;

@@ -520,22 +752,34 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
head == 0
? activations.att_post2.data() + batch_idx * kModelDim
: activations.att_post1.data() + head * kBatchSize * kModelDim;
MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
head * kModelDim * kQKVDim, att_out,
head_out);
if (head == 0) {
MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
layer_weights->attention_output_biases.data(), head_out);
} else {
MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
head * kModelDim * kQKVDim, att_out,
head_out);
}
};

if constexpr (kHeads == kKVHeads) {
// Multi-Head Attention
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t head_offset = head * 3 * kQKVDim * kModelDim;
// linear projections to QKV
const size_t head_offset = TConfig::kInterleaveQKV
? 3 * kQKVDim * kModelDim
: kQKVDim * kModelDim;
const size_t mat_offset =
TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
const size_t q_offset = head * head_offset + 0 * mat_offset;
const size_t k_offset = head * head_offset + 1 * mat_offset;
const size_t v_offset = head * head_offset + 2 * mat_offset;

ProjQ(head, head_offset);
ProjQ(head, q_offset);

const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
const size_t kv_offset =
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;

ProjKV(k_offset, v_offset, kv_offset);
@@ -546,7 +790,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize;

ProjKV(k_offset, v_offset, kv_offset);

pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {

@@ -581,13 +827,13 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,

// Same matrix, first and second half of rows. Could fuse into one MatVec,
// but separating them could help on NUMA e.g. multiple sockets.
MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w,
kFFHiddenDim * kModelDim, vec, out_mul,
pool);

MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool);
// Gate, will go through the nonlinearity.
MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w, 0, vec, out,
pool);
MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), out, pool);

namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;

@@ -598,8 +844,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
}

PROFILER_ZONE("Gen.FFW\\GatedGELU");
MatVec<kModelDim, kFFHiddenDim>(
MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(),
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}

@@ -639,15 +886,23 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
});

for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer);
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);

for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations.x.data() + token_idx * kModelDim,
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data() + token_idx * kModelDim,
kModelDim);
Attention<kBatchSize>(pos, token_idx, layer, activations, layer_weights,
kv_cache, pool);
if (type == LayerAttentionType::kGemma) {
Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
} else {
GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
}
}

// TODO: sink the loop into these functions, i.e. make them matmuls.

@@ -678,9 +933,7 @@ template <typename WeightArrayT, class TConfig>
void Transformer(int token, size_t pos, const WeightArrayT& weights,
Activations<TConfig, 1>& activations, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
static constexpr size_t kLayers = TConfig::kLayers;
static constexpr size_t kModelDim = TConfig::kModelDim;

Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim);

@@ -688,12 +941,21 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim);

for (size_t layer = 0; layer < kLayers; ++layer) {
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer);
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
RMSNorm(activations.x.data(),
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), kModelDim);
Attention<1>(pos, 0, layer, activations, layer_weights, kv_cache, pool);
if (type == LayerAttentionType::kGemma) {
Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
pool);
} else {
GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
kv_cache, pool);
}
AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim);

@@ -707,10 +969,12 @@
template <class TConfig>
void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
size_t& prompt_size) {
if (max_tokens > TConfig::kSeqLen) {
fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
max_tokens, TConfig::kSeqLen);
max_tokens = static_cast<size_t>(TConfig::kSeqLen);
if (!TConfig::kUseLocalAttention) {
if (max_tokens > TConfig::kSeqLen) {
fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
max_tokens, TConfig::kSeqLen);
max_tokens = static_cast<size_t>(TConfig::kSeqLen);
}
}

if (max_generated_tokens > max_tokens) {

@@ -720,12 +984,14 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
max_generated_tokens = max_tokens - 1;
}

if (prompt_size + max_generated_tokens > max_tokens) {
fprintf(stderr,
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
"%d, truncating.\n",
prompt_size, max_generated_tokens, TConfig::kSeqLen);
prompt_size = max_tokens - max_generated_tokens;
if (!TConfig::kUseLocalAttention) {
if (prompt_size + max_generated_tokens > max_tokens) {
fprintf(stderr,
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
"%d, truncating.\n",
prompt_size, max_generated_tokens, TConfig::kSeqLen);
prompt_size = max_tokens - max_generated_tokens;
}
}
}

@@ -935,6 +1201,19 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
accept_token, gen, verbosity);
}

void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}

float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,

@@ -951,6 +1230,15 @@ float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
inner_pool, verbosity);
}

float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, int verbosity) {
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
inner_pool, verbosity);
}

// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
// if weights = null, which happens during the first call where we attempt to
// load from cache.

@@ -967,6 +1255,7 @@ void ForEachTensor(const Weights<TConfig>* weights,

char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);

@@ -978,9 +1267,33 @@ void ForEachTensor(const Weights<TConfig>* weights,
CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
CALL_FUNC("gating_ein", gating_einsum_w);
CALL_FUNC("linear_w", linear_w);
CALL_FUNC("qkv_ein", qkv_einsum_w);
CALL_FUNC("att_ein", attn_vec_einsum_w);
if (type == LayerAttentionType::kGemma) {
CALL_FUNC("qkv_ein", qkv_einsum_w);
CALL_FUNC("att_ein", attn_vec_einsum_w);
} else {
CALL_FUNC("gr_lin_x_w", griffin_linear_x_w);
CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases);
CALL_FUNC("gr_lin_y_w", griffin_linear_y_w);
CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases);
CALL_FUNC("gr_lin_out_w", griffin_linear_out_w);
CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases);
CALL_FUNC("gr_conv_w", griffin_conv_w);
CALL_FUNC("gr_conv_b", griffin_conv_biases);
CALL_FUNC("gr_gate_w", griffin_gate_w);
CALL_FUNC("gr_gate_b", griffin_gate_biases);
CALL_FUNC("gr_a", griffin_a);
}
CALL_FUNC("pre_att_ns", pre_attention_norm_scale);

if (TConfig::kFFBiases) {
CALL_FUNC("ffw_gat_b", ffw_gating_biases);
CALL_FUNC("ffw_out_b", ffw_output_biases);
}

if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
CALL_FUNC("attn_ob", attention_output_biases);
}
#undef CALL_FUNC
}
}
@@ -1011,10 +1324,18 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
if (TConfig::kNumTensorScales > 0) {
size_t scale_pos = 0;
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
if (type == LayerAttentionType::kGemma) {
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
} else {
layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]);
}
layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->linear_w.set_scale(scales[scale_pos++]);
}

@@ -1031,6 +1352,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}

@@ -1044,6 +1367,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
return LoadWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadWeights<ConfigGriffin2B>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}

@@ -1089,6 +1414,9 @@ void CompressWeightsT(gcpp::Model model, const Path& weights,
case Model::GEMMA_7B:
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
break;
case Model::GRIFFIN_2B:
CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}

@@ -1106,13 +1434,29 @@ HWY_EXPORT(LoadWeightsT);
HWY_EXPORT(CompressWeightsT);
HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B);
HWY_EXPORT(GenerateGriffin2B);
HWY_EXPORT(ComputeCrossEntropy2B);
HWY_EXPORT(ComputeCrossEntropy7B);
HWY_EXPORT(ComputeCrossEntropyGriffin2B);

KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv_cache_size, size_t rglru_cache_size) {
KVCache kv_cache = {};
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
if (size_cache_pos != 0) {
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache =
hwy::AllocateAligned<float>(seq_len * size_cache_pos);
}
if (conv_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv_cache_size);
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
conv_cache_size * sizeof(kv_cache.conv1d_cache[0]));
}
if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
}
return kv_cache;
}

@@ -1136,6 +1480,7 @@ void GemmaImpl<ConfigGemma2B>::Generate(
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}

template <>
void GemmaImpl<ConfigGemma7B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,

@@ -1148,6 +1493,18 @@ void GemmaImpl<ConfigGemma7B>::Generate(
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}

template <>
void GemmaImpl<ConfigGriffin2B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}

template <>
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,

@@ -1164,6 +1521,14 @@ float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
}

template <>
float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
}

Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;

@@ -1190,6 +1555,9 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
case Model::GEMMA_7B:
impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
break;
case Model::GRIFFIN_2B:
impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
}

@@ -1240,5 +1608,42 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
return result;
}

namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt",
"2b-it", "7b-it", "gr2b-it"};
constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B,
Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT};
} // namespace

const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024];
kErrorMessageBuffer[0] = 0;
strcat(kErrorMessageBuffer,
"Invalid or missing model flag, need to specify one of ");
for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kModelFlags[i]);
strcat(kErrorMessageBuffer, ", ");
}
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]);
strcat(kErrorMessageBuffer, ".");
std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
return nullptr;
}
}
return kErrorMessageBuffer;
}

} // namespace gcpp
#endif // HWY_ONCE
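For concreteness, the buffer sizes that CreateKVCache<ConfigGriffin2B>() requests can be worked out from the constants in the configs.h hunk above (8 attention layers and 18 recurrent layers out of 26); this is only arithmetic over values shown in this commit, and counts are in floats, not bytes.

```cpp
#include <cstddef>

// ConfigGriffin2B constants from configs.h plus the 8/18 layer split.
constexpr size_t kGemmaLayers = 8, kGriffinLayers = 18;
constexpr size_t kKVHeads = 1, kQKVDim = 256, kModelDim = 2560;
constexpr size_t kConv1dWidth = 4, kSeqLen = 2048;

constexpr size_t kCachePos = kGemmaLayers * kKVHeads * kQKVDim;                  // 2048
constexpr size_t kKeyOrValue = kSeqLen * kCachePos;                              // 4,194,304
constexpr size_t kConvCache = kGriffinLayers * (kConv1dWidth - 1) * kModelDim;   // 138,240
constexpr size_t kRglruCache = kGriffinLayers * kModelDim;                       // 46,080
static_assert(kKeyOrValue == 4194304 && kConvCache == 138240 && kRglruCache == 46080,
              "sizes derived from ConfigGriffin2B");
```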
gemma.h | 18
@@ -42,15 +42,24 @@ constexpr bool kSystemPrompt = false;

struct KVCache {
hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kNumGriffinLayers
};

// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B };
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };

// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);

struct RuntimeConfig {
size_t max_tokens;
size_t max_generated_tokens;

@@ -79,7 +88,8 @@ struct Gemma {
};

KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size);

// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
run.cc | 12
@@ -27,8 +27,6 @@
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"

@@ -36,8 +34,12 @@
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp

static constexpr bool kVerboseLogTokens = false;

namespace gcpp {

static constexpr std::string_view kAsciiArtBanner = R""(

@@ -203,6 +205,12 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush;

if constexpr (kVerboseLogTokens) {
for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}

const double time_start = hwy::platform::Now();
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
util/app.h | 51
@@ -125,46 +125,21 @@ class AppArgs : public ArgsBase<AppArgs> {
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}
gcpp::Model ModelType() const { return model_type; }

gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
return gcpp::Model::GEMMA_2B;
} else {
return gcpp::Model::GEMMA_7B;
}
}

gcpp::ModelTraining ModelTraining() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
return gcpp::ModelTraining::GEMMA_PT;
} else {
return gcpp::ModelTraining::GEMMA_IT;
}
}
gcpp::ModelTraining ModelTraining() const { return model_training; }

// Returns error string or nullptr if OK.
const char* Validate() {
const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, or 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
"7b-it.";
}
const char* parse_result =
ParseModelTypeAndTraining(model_type_str, model_type, model_training);
if (parse_result) return parse_result;
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
if (!tokenizer.exists()) {
return "Can't open file specified with --tokenizer flag.";
}
if (!compressed_weights.path.empty()) {
if (weights.path.empty()) {
weights = compressed_weights;

@@ -186,7 +161,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path tokenizer;
Path weights; // weights file location
Path compressed_weights;
std::string model_type;
std::string model_type_str;
Model model_type;
enum ModelTraining model_training;

template <class Visitor>
void ForEach(const Visitor& visitor) {

@@ -196,10 +173,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"Path name of model weights (.sbs) file.\n Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Alias for --weights.");
visitor(model_type, "model", std::string(),
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
}
};