mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into deinterleave-vecs
This commit is contained in:
commit
564937ede6
|
|
@ -46,8 +46,10 @@ cc_test(
|
|||
deps = [
|
||||
":ops",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include "nlohmann/json.hpp"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
|
|
@ -27,8 +28,8 @@ class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
|
|||
|
||||
std::pair<std::string, int> QueryModel(
|
||||
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const std::string& input, gcpp::LayersOutputT* layers_output) {
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
|
||||
gcpp::LayersOutputT* layers_output) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
|
||||
|
|
@ -55,8 +56,7 @@ std::pair<std::string, int> QueryModel(
|
|||
}
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
inner_pool, stream_token, accept_token, gen, app.verbosity,
|
||||
layers_output);
|
||||
stream_token, accept_token, gen, app.verbosity, layers_output);
|
||||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
|
|
@ -92,7 +92,6 @@ int main(int argc, char** argv) {
|
|||
gcpp::LayersOutputT* layers_output =
|
||||
log_layers_output ? &json_logger.layers_output_log_f : nullptr;
|
||||
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
hwy::ThreadPool pool(app.num_threads);
|
||||
// For many-core, pinning threads to cores helps.
|
||||
if (app.num_threads > 10) {
|
||||
|
|
@ -112,7 +111,7 @@ int main(int argc, char** argv) {
|
|||
return EXIT_FAILURE;
|
||||
}
|
||||
const auto [answer, token_count] = QueryModel(
|
||||
model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
|
||||
model, args, app, kv_cache, pool, prompt, layers_output);
|
||||
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
|
||||
|
||||
if (log_layers_output) {
|
||||
|
|
|
|||
|
|
@ -58,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
|
|||
|
||||
std::pair<std::string, int> QueryModel(
|
||||
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const std::string& input) {
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
|
||||
|
|
@ -90,7 +89,7 @@ std::pair<std::string, int> QueryModel(
|
|||
}
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
inner_pool, stream_token, accept_token, gen, app.verbosity);
|
||||
stream_token, accept_token, gen, app.verbosity);
|
||||
if (app.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
|
|
@ -131,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
|
|||
|
||||
int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const std::string& golden_path) {
|
||||
hwy::ThreadPool& pool, const std::string& golden_path) {
|
||||
const std::vector<std::pair<std::string, std::string>> queries_answers =
|
||||
load_goldens(golden_path);
|
||||
int correct_answers = 0;
|
||||
|
|
@ -140,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
|||
const double time_start = hwy::platform::Now();
|
||||
for (const auto& [question, expected_answer] : queries_answers) {
|
||||
const auto [answer, token_count] =
|
||||
QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
|
||||
QueryModel(model, args, app, kv_cache, pool, question);
|
||||
total_tokens += token_count;
|
||||
if (answer.find(expected_answer) != std::string::npos) {
|
||||
correct_answers++;
|
||||
|
|
@ -164,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
|||
|
||||
int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const gcpp::Path& text) {
|
||||
hwy::ThreadPool& pool, const gcpp::Path& text) {
|
||||
std::string prompt("Here is some text to summarize:\n");
|
||||
prompt.append(ReadFile(text));
|
||||
prompt.append("\nSummarize this text.\n");
|
||||
const double time_start = hwy::platform::Now();
|
||||
const auto [answer, token_count] =
|
||||
QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
|
||||
QueryModel(model, args, app, kv_cache, pool, prompt);
|
||||
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
|
||||
LogSpeedStats(time_start, token_count);
|
||||
return EXIT_SUCCESS;
|
||||
|
|
@ -179,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
|||
|
||||
int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
||||
gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const gcpp::Path& text, size_t batch_tokens) {
|
||||
hwy::ThreadPool& pool, const gcpp::Path& text,
|
||||
size_t batch_tokens) {
|
||||
std::string input = ReadFile(text);
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
|
|
@ -197,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
|||
auto kv_cache = CreateKVCache(model_type);
|
||||
float entropy =
|
||||
ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
|
||||
inner_pool, app.verbosity);
|
||||
app.verbosity);
|
||||
total_entropy += entropy;
|
||||
LogSpeedStats(time_start, pos + num_tokens);
|
||||
std::string text_slice;
|
||||
|
|
@ -211,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
|||
|
||||
int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
||||
gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const gcpp::Path& json_file, size_t max_questions) {
|
||||
hwy::ThreadPool& pool, const gcpp::Path& json_file,
|
||||
size_t max_questions) {
|
||||
std::ifstream trivia_file(json_file.path);
|
||||
if (!trivia_file) {
|
||||
std::cout << "Could not load file: " << json_file.path << "\n"
|
||||
|
|
@ -225,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
|
|||
while (std::getline(trivia_file, line)) {
|
||||
json data = json::parse(line);
|
||||
const auto [answer, token_count] = QueryModel(
|
||||
model, args, app, kv_cache, inner_pool, pool, data["question"]);
|
||||
model, args, app, kv_cache, pool, data["question"]);
|
||||
std::cout << answer << "\n";
|
||||
bool correct = false;
|
||||
for (const std::string expected : data["answer"]["aliases"]) {
|
||||
|
|
@ -263,7 +260,6 @@ int main(int argc, char** argv) {
|
|||
HWY_ABORT("\nInvalid inference args: %s", error);
|
||||
}
|
||||
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
hwy::ThreadPool pool(app.num_threads);
|
||||
// For many-core, pinning threads to cores helps.
|
||||
if (app.num_threads > 10) {
|
||||
|
|
@ -280,17 +276,16 @@ int main(int argc, char** argv) {
|
|||
if (!benchmark_args.goldens.path.empty()) {
|
||||
const std::string golden_path =
|
||||
benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
|
||||
return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
|
||||
golden_path);
|
||||
return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
|
||||
} else if (!benchmark_args.summarize_text.path.empty()) {
|
||||
return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
|
||||
return BenchmarkSummary(model, args, app, kv_cache, pool,
|
||||
benchmark_args.summarize_text);
|
||||
} else if (!benchmark_args.cross_entropy.path.empty()) {
|
||||
return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
|
||||
inner_pool, pool, benchmark_args.cross_entropy,
|
||||
pool, benchmark_args.cross_entropy,
|
||||
benchmark_args.batch_tokens);
|
||||
} else if (!benchmark_args.trivia_qa.path.empty()) {
|
||||
return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
|
||||
return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
|
||||
benchmark_args.trivia_qa,
|
||||
benchmark_args.max_questions);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ struct Args : public ArgsBase<Args> {
|
|||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file.\n"
|
||||
"Path to model weights (.bin) file.\n"
|
||||
" Required argument.");
|
||||
visitor(model_type_str, "model", std::string(),
|
||||
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||
|
|
@ -80,7 +80,7 @@ struct Args : public ArgsBase<Args> {
|
|||
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
||||
" Required argument.");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
"Path name where compressed weights file will be written.\n"
|
||||
"Path name where compressed weights (.sbs) file will be written.\n"
|
||||
" Required argument.");
|
||||
visitor(num_threads, "num_threads",
|
||||
kDefaultNumThreads, // see ChooseNumThreads
|
||||
|
|
|
|||
|
|
@ -28,6 +28,11 @@
|
|||
#define GEMMA_TOPK 1
|
||||
#endif // !GEMMA_TOPK
|
||||
|
||||
// Allow changing upper bound on threads as a compiler flag
|
||||
#ifndef GEMMA_MAX_THREADS
|
||||
#define GEMMA_MAX_THREADS 128
|
||||
#endif // !GEMMA_MAX_THREADS
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
|
|
@ -45,6 +50,7 @@ namespace gcpp {
|
|||
|
||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||
static constexpr size_t kTopK = GEMMA_TOPK;
|
||||
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
|
||||
|
||||
enum class LayerAttentionType {
|
||||
kGemma,
|
||||
|
|
|
|||
246
gemma/gemma.cc
246
gemma/gemma.cc
|
|
@ -399,9 +399,9 @@ struct Activations {
|
|||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
|
||||
static constexpr size_t kCachePosSize =
|
||||
TConfig::kGemmaLayers * kKVHeads * kQKVDim;
|
||||
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
|
||||
TConfig::kGemmaLayers * kCacheLayerSize;
|
||||
|
||||
std::array<float, kBatchSize * kModelDim> x; // input
|
||||
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
|
||||
|
|
@ -421,6 +421,10 @@ struct Activations {
|
|||
std::array<float, kBatchSize * kModelDim> ffw_out;
|
||||
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
|
||||
|
||||
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
|
||||
// per-thread storage.
|
||||
std::array<float, kModelDim * kMaxThreads> even_odd;
|
||||
|
||||
// Griffin layer internal activations
|
||||
static constexpr size_t kGriffinDim =
|
||||
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
|
||||
|
|
@ -440,15 +444,13 @@ struct GemmaInterface {
|
|||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) = 0;
|
||||
|
||||
virtual float ComputeCrossEntropy(size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
int verbosity) = 0;
|
||||
};
|
||||
|
||||
|
|
@ -535,13 +537,12 @@ struct GemmaImpl : public GemmaInterface {
|
|||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937&, int verbosity,
|
||||
LayersOutputT* layers_output) override;
|
||||
|
||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
int verbosity) override;
|
||||
|
||||
GemmaTokenizerImpl tokenizer;
|
||||
|
|
@ -578,13 +579,14 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
||||
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr bool kAdd = true;
|
||||
const size_t batch_offset = batch_idx * kModelDim;
|
||||
const size_t pos = batch_start + batch_idx;
|
||||
|
||||
// X / Y linear layers.
|
||||
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
|
||||
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
|
||||
TwoMatVecAdd<true, kModelDim, kModelDim>(
|
||||
TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
|
||||
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
|
||||
|
|
@ -634,7 +636,7 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
constexpr size_t kHeadDim = kModelDim / kHeads;
|
||||
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
|
||||
size_t head_offset = head * kHeadDim;
|
||||
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
|
||||
TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
|
||||
layer_weights->griffin.gate_w, kMatrixSize * head,
|
||||
kMatrixSize * (kHeads + head), x + head_offset,
|
||||
/*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
|
||||
|
|
@ -673,9 +675,10 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
|
||||
// Final linear layer.
|
||||
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
|
||||
MatVecAdd<true, kModelDim, kModelDim>(
|
||||
MatVecAdd<kAdd, kModelDim, kModelDim>(
|
||||
layer_weights->griffin.linear_out_w, 0, x,
|
||||
layer_weights->griffin.linear_out_biases.data(), out_ptr, pool);
|
||||
layer_weights->griffin.linear_out_biases.data(),
|
||||
activations.even_odd.data(), out_ptr, pool);
|
||||
}
|
||||
|
||||
template <size_t kBatchSize, typename LayerT, class TConfig>
|
||||
|
|
@ -707,26 +710,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
|
||||
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
|
||||
|
||||
auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
|
||||
float* HWY_RESTRICT q =
|
||||
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
||||
|
||||
MatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
|
||||
head_offset + 0 * kQKVDim * kModelDim, x, q);
|
||||
};
|
||||
|
||||
auto ProjKV = [&](size_t k_offset, size_t v_offset,
|
||||
size_t kv_offset) HWY_ATTR {
|
||||
float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
|
||||
float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
|
||||
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
|
||||
v_offset, x, k, v);
|
||||
|
||||
Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
};
|
||||
|
||||
auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
|
||||
auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
|
||||
// Calculate scores
|
||||
float* HWY_RESTRICT q =
|
||||
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
||||
|
|
@ -741,7 +725,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
|
||||
const size_t cache_offset =
|
||||
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
||||
const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
|
||||
const float score = Dot(q, k2, kQKVDim);
|
||||
head_att[pos2] = score;
|
||||
}
|
||||
|
|
@ -754,7 +738,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
|
||||
const size_t cache_offset =
|
||||
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
||||
float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
|
||||
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
|
||||
}
|
||||
// linear projection from kQKVDim back to kModelDim, sum projections
|
||||
|
|
@ -763,20 +747,21 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
head == 0
|
||||
? activations.att_post2.data() + batch_idx * kModelDim
|
||||
: activations.att_post1.data() + head * kBatchSize * kModelDim;
|
||||
float* even_odd = activations.even_odd.data() + thread * kQKVDim;
|
||||
if (head == 0) {
|
||||
MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
|
||||
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
|
||||
layer_weights->attention_output_biases.data(), head_out);
|
||||
layer_weights->attention_output_biases.data(), even_odd, head_out);
|
||||
} else {
|
||||
MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
|
||||
head * kModelDim * kQKVDim, att_out,
|
||||
head_out);
|
||||
even_odd, head_out);
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr (kHeads == kKVHeads) {
|
||||
// Multi-Head Attention
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
const size_t head_offset = TConfig::kInterleaveQKV
|
||||
? 3 * kQKVDim * kModelDim
|
||||
|
|
@ -787,28 +772,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
const size_t k_offset = head * head_offset + 1 * mat_offset;
|
||||
const size_t v_offset = head * head_offset + 2 * mat_offset;
|
||||
|
||||
ProjQ(head, q_offset);
|
||||
// ProjQ
|
||||
float* HWY_RESTRICT q =
|
||||
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
||||
MatVecLoop<kQKVDim, kModelDim>(
|
||||
layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
|
||||
activations.even_odd.data() + thread * kModelDim, q);
|
||||
|
||||
const size_t kv_offset =
|
||||
cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
|
||||
// ProjKV
|
||||
const size_t kv_offset = cache_pos * kCachePosSize +
|
||||
layer * kCacheLayerSize + head * kQKVDim * 2;
|
||||
float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
|
||||
float* HWY_RESTRICT v = k + kQKVDim;
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
|
||||
k_offset, v_offset, x, k, v);
|
||||
Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
|
||||
ProjKV(k_offset, v_offset, kv_offset);
|
||||
|
||||
Attn(head, head * kQKVDim);
|
||||
Attn(head, head * kQKVDim * 2, thread);
|
||||
});
|
||||
} else {
|
||||
// Multi-Query Attention
|
||||
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
|
||||
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
|
||||
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
|
||||
const size_t kv_offset =
|
||||
cache_pos * kCachePosSize + layer * kCacheLayerSize;
|
||||
float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
|
||||
MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
|
||||
activations.even_odd.data(), q, pool);
|
||||
|
||||
ProjKV(k_offset, v_offset, kv_offset);
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
|
||||
cache_pos * kCachePosSize +
|
||||
layer * kCacheLayerSize;
|
||||
MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
|
||||
kHeads * kQKVDim * kModelDim, x,
|
||||
activations.even_odd.data(), kv, pool);
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
ProjQ(head, head * kQKVDim * kModelDim);
|
||||
Attn(head, 0);
|
||||
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
|
||||
Attn(head, 0, thread);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -828,6 +826,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
|
||||
float* HWY_RESTRICT even_odd = activations.even_odd.data();
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.FFW.GatedGELU");
|
||||
|
|
@ -836,15 +835,15 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
|
||||
float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
|
||||
|
||||
// Same matrix, first and second half of rows. Could fuse into one MatVec,
|
||||
// but separating them could help on NUMA e.g. multiple sockets.
|
||||
// Same matrix, first and second half of rows. Could fuse into one MatVec.
|
||||
MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
|
||||
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
|
||||
layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool);
|
||||
layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
|
||||
out_mul, pool);
|
||||
// Gate, will go through the nonlinearity.
|
||||
MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
|
||||
layer_weights->gating_einsum_w, 0, vec,
|
||||
layer_weights->ffw_gating_biases.data(), out, pool);
|
||||
layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
|
|
@ -857,7 +856,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
PROFILER_ZONE("Gen.FFW\\GatedGELU");
|
||||
MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
|
||||
layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
|
||||
layer_weights->ffw_output_biases.data(),
|
||||
layer_weights->ffw_output_biases.data(), even_odd,
|
||||
activations.ffw_out.data() + batch_idx * kModelDim, pool);
|
||||
}
|
||||
|
||||
|
|
@ -880,8 +879,7 @@ template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
|
|||
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||
const WeightArrayT& weights,
|
||||
Activations<TConfig, kBatchSize>& activations,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool) {
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||
|
|
@ -924,19 +922,17 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
}
|
||||
|
||||
// TODO: sink the loop into these functions, i.e. make them MatMul.
|
||||
pool.Run(
|
||||
0, num_tokens,
|
||||
[&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
|
||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
AddFrom(activations.att_post2.data() + token_idx * kModelDim,
|
||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||
RMSNorm(activations.x.data() + token_idx * kModelDim,
|
||||
layer_weights->pre_ffw_norm_scale.data(),
|
||||
activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
|
||||
kModelDim);
|
||||
FFW<kBatchSize>(activations, token_idx, layer_weights, inner_pool);
|
||||
FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
|
||||
AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
|
||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||
});
|
||||
}
|
||||
} // foreach layer
|
||||
|
||||
pool.Run(
|
||||
|
|
@ -950,8 +946,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
template <typename WeightArrayT, class TConfig>
|
||||
void Transformer(int token, size_t pos, const WeightArrayT& weights,
|
||||
Activations<TConfig, 1>& activations, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
LayersOutputT* layers_output) {
|
||||
hwy::ThreadPool& pool, LayersOutputT* layers_output) {
|
||||
if (layers_output != nullptr) {
|
||||
float token_f = token;
|
||||
(*layers_output)(pos, "Tokens", &token_f, 1);
|
||||
|
|
@ -1033,8 +1028,7 @@ template <class TConfig>
|
|||
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
|
|
@ -1077,7 +1071,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
|
||||
const int* batch_tokens = prompt.data() + pos_offset;
|
||||
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
|
||||
prefill_activations, kv_cache, pool, inner_pool);
|
||||
prefill_activations, kv_cache, pool);
|
||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||
if (!stream_token(batch_tokens[idx], 0.0f)) return;
|
||||
}
|
||||
|
|
@ -1105,7 +1099,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
++pos, ++pos_offset, ++generate_pos) {
|
||||
const bool is_generating_phase = pos_offset >= prompt_size - 1;
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool,
|
||||
layers_output);
|
||||
float* final_activation = activations.x.data();
|
||||
// The condition below is always true if we are doing Prefill above.
|
||||
|
|
@ -1114,9 +1108,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
if (is_generating_phase) {
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Generation phase
|
||||
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
|
||||
0, final_activation,
|
||||
activations.logits.data(), pool);
|
||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||
weights.embedder_input_embedding, 0, final_activation,
|
||||
activations.even_odd.data(), activations.logits.data(), pool);
|
||||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
|
||||
|
|
@ -1171,8 +1165,7 @@ void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
|
|||
template <class TConfig>
|
||||
float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||
|
|
@ -1196,11 +1189,11 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||
total_entropy / std::log(2.0) / (pos + 1));
|
||||
}
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool,
|
||||
/*layers_output=*/nullptr);
|
||||
MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
|
||||
activations.x.data(),
|
||||
activations.logits.data(), pool);
|
||||
MatVec<kVocabSize, kModelDim>(
|
||||
weights.embedder_input_embedding, 0, activations.x.data(),
|
||||
activations.even_odd.data(), activations.logits.data(), pool);
|
||||
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||
memcpy(logits.data(), activations.logits.data(),
|
||||
kVocabSize * sizeof(logits[0]));
|
||||
|
|
@ -1215,62 +1208,59 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
|||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, layers_output);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, layers_output);
|
||||
}
|
||||
|
||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, layers_output);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity) {
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
inner_pool, verbosity);
|
||||
verbosity);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity) {
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
inner_pool, verbosity);
|
||||
verbosity);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
|
||||
size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
int verbosity) {
|
||||
return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
|
||||
inner_pool, verbosity);
|
||||
verbosity);
|
||||
}
|
||||
|
||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||
|
|
@ -1477,9 +1467,8 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
|||
size_t conv1d_cache_size, size_t rglru_cache_size) {
|
||||
KVCache kv_cache = {};
|
||||
if (size_cache_pos != 0) {
|
||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
kv_cache.value_cache =
|
||||
hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
kv_cache.kv_cache =
|
||||
hwy::AllocateAligned<float>(seq_len * size_cache_pos * 2);
|
||||
}
|
||||
if (conv1d_cache_size != 0) {
|
||||
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
||||
|
|
@ -1507,12 +1496,12 @@ template <>
|
|||
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
|
|
@ -1520,50 +1509,49 @@ template <>
|
|||
void GemmaImpl<ConfigGemma7B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
|
||||
layers_output);
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGriffin2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
||||
}
|
||||
|
||||
template <>
|
||||
float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
|
||||
size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
|
||||
*this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||
*this, max_tokens, prompt, kv_cache, pool, verbosity);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||
|
|
@ -1607,13 +1595,13 @@ const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
|
|||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
start_pos, kv_cache, pool, stream_token, accept_token,
|
||||
gen, verbosity, layers_output);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
|
|
@ -1621,10 +1609,9 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
|||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, std::mt19937& gen) {
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
GenerateGemma(
|
||||
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
|
||||
runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
|
||||
runtime_config.temperature, prompt, start_pos, kv_cache, pool,
|
||||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
|
||||
/*layers_output=*/nullptr);
|
||||
}
|
||||
|
|
@ -1637,11 +1624,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
|
|||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity) {
|
||||
hwy::ThreadPool& pool, int verbosity) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
const float result = gemma.impl_->ComputeCrossEntropy(
|
||||
max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
|
||||
max_tokens, prompt, kv_cache, pool, verbosity);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,9 +44,7 @@ constexpr bool kSystemPrompt = false;
|
|||
|
||||
struct KVCache {
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
key_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
value_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
|
||||
kv_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
|
|
@ -104,13 +102,12 @@ using AcceptFunc = std::function<bool(int)>;
|
|||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, LayersOutputT* layers_output = nullptr);
|
||||
|
||||
// Convenience function for the common case:
|
||||
// - Bundle runtime parameters as RuntimeConfig
|
||||
// - No ThreadPool within ThreadPool (inner_pool = dummy)
|
||||
// - All tokens accepted
|
||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
|
|
@ -122,8 +119,7 @@ void CompressWeights(gcpp::Model model, const Path& weights,
|
|||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
int verbosity);
|
||||
hwy::ThreadPool& pool, int verbosity);
|
||||
|
||||
constexpr int EOS_ID = 1;
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ class GemmaTest : public ::testing::Test {
|
|||
: weights("./2b-it-mqa.sbs"),
|
||||
tokenizer("./tokenizer.spm"),
|
||||
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
|
||||
inner_pool(0),
|
||||
model_type(gcpp::Model::GEMMA_2B),
|
||||
model(tokenizer, weights, model_type, pool) {
|
||||
kv_cache = CreateKVCache(model_type);
|
||||
|
|
@ -60,8 +59,8 @@ class GemmaTest : public ::testing::Test {
|
|||
gcpp::GenerateGemma(
|
||||
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
|
||||
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||
inner_pool, stream_token,
|
||||
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
|
||||
stream_token, /*accept=*/[](int) { return true; }, gen,
|
||||
/*verbosity=*/0);
|
||||
std::string response_text;
|
||||
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
||||
return response_text;
|
||||
|
|
@ -71,8 +70,7 @@ class GemmaTest : public ::testing::Test {
|
|||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
|
||||
kv_cache, pool, inner_pool,
|
||||
/*verbosity=*/0) /
|
||||
kv_cache, pool, /*verbosity=*/0) /
|
||||
prompt_string.size();
|
||||
}
|
||||
|
||||
|
|
@ -89,7 +87,6 @@ class GemmaTest : public ::testing::Test {
|
|||
gcpp::Path tokenizer;
|
||||
gcpp::KVCache kv_cache;
|
||||
hwy::ThreadPool pool;
|
||||
hwy::ThreadPool inner_pool;
|
||||
gcpp::Model model_type = {};
|
||||
gcpp::Gemma model;
|
||||
};
|
||||
|
|
|
|||
58
gemma/ops.h
58
gemma/ops.h
|
|
@ -129,11 +129,13 @@ HWY_INLINE void ToEvenOddF32(
|
|||
}
|
||||
|
||||
// Simple version without tiling nor threading.
|
||||
// even_odd is precomputed for the current thread.
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT>
|
||||
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
const AddT* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
PROFILER_ZONE("MatVecAddLoop");
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
@ -149,7 +151,6 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
|
||||
size_t kCapacity>
|
||||
|
|
@ -157,33 +158,39 @@ HWY_INLINE void MatVecAddLoop(
|
|||
const CompressedArray<hwy::bfloat16_t, kCapacity>& mat,
|
||||
const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
const AddT* HWY_RESTRICT add,
|
||||
const AddT* HWY_RESTRICT add, float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
PROFILER_ZONE("MatVecAddLoop");
|
||||
|
||||
// Sanity check: we can write without race conditions.
|
||||
if (HWY_IS_TSAN) {
|
||||
even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
|
||||
even_odd[kInner - 1] = -even_odd[0];
|
||||
}
|
||||
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
|
||||
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
||||
|
||||
ToEvenOddF32(vec_aligned, kInner, even_odd);
|
||||
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + idx_row * kInner;
|
||||
if constexpr (kAdd) {
|
||||
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
|
||||
Dot<true>(df, mat, row_ofs, vec_dequant.get(), kInner);
|
||||
Dot<true>(df, mat, row_ofs, even_odd, kInner);
|
||||
} else {
|
||||
out[idx_row] = Dot<true>(df, mat, row_ofs, vec_dequant.get(), kInner);
|
||||
out[idx_row] = Dot<true>(df, mat, row_ofs, even_odd, kInner);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// even_odd is precomputed for the current thread.
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
MatVecAddLoop<false, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), out);
|
||||
MatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
|
||||
out);
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs.
|
||||
|
|
@ -221,7 +228,7 @@ HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
|
|||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out0,
|
||||
float* HWY_RESTRICT out1) {
|
||||
TwoOfsMatVecAddLoop<false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
||||
out0, out1);
|
||||
}
|
||||
|
|
@ -307,11 +314,21 @@ template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
|
|||
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
const AddT* HWY_RESTRICT const add,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
const hn::ScalableTag<float> df;
|
||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
|
||||
// Sanity check: each thread can write without race conditions.
|
||||
if (HWY_IS_TSAN) {
|
||||
pool.Run(
|
||||
0, pool.NumWorkers(), [even_odd](uint64_t /*task*/, size_t thread) {
|
||||
even_odd[thread * kInner] = -static_cast<float>(thread);
|
||||
even_odd[thread * kInner + kInner - 1] = static_cast<float>(thread);
|
||||
});
|
||||
}
|
||||
|
||||
// For each entire strip.
|
||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("MatVec.lambda");
|
||||
|
|
@ -340,6 +357,7 @@ template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
|||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
const AddT* HWY_RESTRICT const add,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("MatVecAdd");
|
||||
|
||||
|
|
@ -352,25 +370,25 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
|||
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
|
||||
&& hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
|
||||
) {
|
||||
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
|
||||
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
||||
ToEvenOddF32(vec_aligned, kInner, even_odd);
|
||||
detail::MatVecAddInner<true, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_dequant.get(), add, out, pool);
|
||||
mat, mat_ofs, even_odd, add, even_odd, out, pool);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, add, out, pool);
|
||||
mat, mat_ofs, vec_aligned, add, even_odd, out, pool);
|
||||
}
|
||||
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
MatVecAdd<false, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), out,
|
||||
pool);
|
||||
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
|
||||
hwy::ThreadPool& pool) {
|
||||
MatVecAdd</*kAdd=*/false, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
|
||||
out, pool);
|
||||
}
|
||||
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
|
|
@ -523,7 +541,7 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1,
|
|||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
|
||||
hwy::ThreadPool& pool) {
|
||||
TwoMatVecAdd<false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
TwoMatVecAdd</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
||||
out0, out1, pool);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,11 +17,15 @@
|
|||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
|
|
@ -375,6 +379,7 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
|
|||
template <size_t length>
|
||||
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
|
||||
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
|
||||
HWY_ASSERT(vec);
|
||||
for (size_t idx = 0; idx < length; idx++) {
|
||||
vec[idx] = static_cast<float>(idx + offset);
|
||||
}
|
||||
|
|
@ -388,8 +393,9 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
|
|||
const hwy::AlignedFreeUniquePtr<float[]>& add) {
|
||||
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
|
||||
hwy::AllocateAligned<float>(kOuter * kInner);
|
||||
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
|
||||
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
|
||||
HWY_ASSERT(uncompressed_mat && out);
|
||||
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
|
||||
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
|
||||
out[idx_row] = add[idx_row];
|
||||
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
|
||||
|
|
@ -418,12 +424,15 @@ void TestMatVecAdd() {
|
|||
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> even_odd =
|
||||
hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
|
||||
hwy::AlignedFreeUniquePtr<float[]> expected_out =
|
||||
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_out =
|
||||
hwy::AllocateAligned<float>(kOuter);
|
||||
MatVecAdd<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
|
||||
actual_out.get(), pool);
|
||||
HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
|
||||
MatVecAdd</*kAdd=*/true, kOuter, kInner>(
|
||||
mat, 0, vec.get(), add.get(), even_odd.get(), actual_out.get(), pool);
|
||||
AssertClose<kOuter>(actual_out, expected_out);
|
||||
}
|
||||
|
||||
|
|
@ -433,12 +442,15 @@ void TestMatVecAddLoop() {
|
|||
CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
|
||||
hwy::AlignedFreeUniquePtr<float[]> even_odd =
|
||||
hwy::AllocateAligned<float>(kInner);
|
||||
hwy::AlignedFreeUniquePtr<float[]> expected_out =
|
||||
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_out =
|
||||
hwy::AllocateAligned<float>(kOuter);
|
||||
HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
|
||||
MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
|
||||
actual_out.get());
|
||||
even_odd.get(), actual_out.get());
|
||||
AssertClose<kOuter>(actual_out, expected_out);
|
||||
}
|
||||
|
||||
|
|
@ -459,6 +471,8 @@ void TestTwoMatVecAdd() {
|
|||
hwy::AllocateAligned<float>(kOuter);
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
|
||||
hwy::AllocateAligned<float>(kOuter);
|
||||
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
|
||||
expected_out1 && actual_out1);
|
||||
TwoMatVecAdd<true, kOuter, kInner>(mat0, mat1, 0, vec.get(), add0.get(),
|
||||
add1.get(), actual_out0.get(),
|
||||
actual_out1.get(), pool);
|
||||
|
|
@ -481,6 +495,8 @@ void TestTwoOfsMatVecAddLoop() {
|
|||
hwy::AllocateAligned<float>(kOuter);
|
||||
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
|
||||
hwy::AllocateAligned<float>(kOuter);
|
||||
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
|
||||
expected_out1 && actual_out1);
|
||||
TwoOfsMatVecAddLoop<true, kOuter, kInner>(mat, 0, 0, vec.get(), add0.get(),
|
||||
add1.get(), actual_out0.get(),
|
||||
actual_out1.get());
|
||||
|
|
|
|||
11
gemma/run.cc
11
gemma/run.cc
|
|
@ -94,9 +94,8 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
|||
|
||||
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
const InferenceArgs& args, int verbosity,
|
||||
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
|
|
@ -209,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
|
||||
const double time_start = hwy::platform::Now();
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
|
||||
args.temperature, prompt, abs_pos, kv_cache, pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
const double time_end = hwy::platform::Now();
|
||||
const double tok_sec = current_pos / (time_end - time_start);
|
||||
|
|
@ -229,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
PROFILER_ZONE("Run.misc");
|
||||
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
hwy::ThreadPool pool(app.num_threads);
|
||||
// For many-core, pinning threads to cores helps.
|
||||
if (app.num_threads > 10) {
|
||||
|
|
@ -271,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
}
|
||||
|
||||
ReplGemma(
|
||||
model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
|
||||
app.verbosity,
|
||||
model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
|
||||
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -96,8 +96,9 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
}
|
||||
|
||||
static inline size_t GetSupportedThreadCount() {
|
||||
return static_cast<size_t>(std::clamp(
|
||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||
return static_cast<size_t>(
|
||||
std::clamp(static_cast<int>(std::thread::hardware_concurrency()) - 2, 1,
|
||||
HWY_MIN(static_cast<int>(kMaxThreads), 18)));
|
||||
}
|
||||
|
||||
Path log; // output
|
||||
|
|
|
|||
Loading…
Reference in New Issue