mirror of https://github.com/google/gemma.cpp.git

Replace mt19937 with new generator to enable parallel sampling

Split it into immutable AesCtrEngine and RngStream. Also add RowSpan and Logits span.

PiperOrigin-RevId: 803336423

This commit is contained in:
parent 5d1693e806
commit 56186193c1
@@ -24,7 +24,6 @@
 #include <algorithm>  // std::shuffle
 #include <array>
-#include <random>
 
 #include "compression/distortion.h"
 #include "util/test_util.h"
@@ -104,8 +103,8 @@ struct TestPlateaus {
       HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
     }
 
-    std::random_device rd;  // NOLINT
-    std::mt19937 rng(rd());
+    AesCtrEngine engine(/*deterministic=*/true);
+    RngStream rng(engine, 0);
     std::shuffle(in.get(), in.get() + kGroupSize, rng);
 
     NuqStream::ClusterBuf buf;
@@ -151,8 +150,8 @@ struct TestRamp {
       HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
     }
 
-    std::random_device rd;  // NOLINT
-    std::mt19937 rng(rd());
+    AesCtrEngine engine(/*deterministic=*/true);
+    RngStream rng(engine, 0);
     std::shuffle(in.get(), in.get() + kGroupSize, rng);
 
     NuqStream::ClusterBuf buf;
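
Note on the test change above: std::shuffle requires its third argument to satisfy the C++ UniformRandomBitGenerator requirements, so RngStream must expose result_type, min(), max(), and operator(). The following self-contained sketch shows the shape such an adapter can take. AesCtrEngine and RngStream are the commit's names, but every member below is an assumption for illustration, and a simple SplitMix64-style mixer stands in for AES-CTR:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Stand-in for the immutable engine: a pure function of (key, stream, counter).
struct AesCtrEngineSketch {
  explicit AesCtrEngineSketch(bool deterministic)
      : key_(deterministic ? 0x12345678u : 0x9E3779B97F4A7C15u) {}
  uint64_t Generate(uint64_t stream, uint64_t counter) const {
    uint64_t x = key_ ^ (stream * 0xBF58476D1CE4E5B9u) ^ counter;
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9u;  // SplitMix64-style mixing
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBu;
    return x ^ (x >> 31);
  }
  uint64_t key_;
};

// Satisfies UniformRandomBitGenerator, so std::shuffle accepts it.
struct RngStreamSketch {
  using result_type = uint64_t;
  static constexpr result_type min() { return 0; }
  static constexpr result_type max() { return ~result_type{0}; }
  result_type operator()() { return engine.Generate(stream, counter++); }
  const AesCtrEngineSketch& engine;  // shared, read-only
  uint64_t stream;                   // stream id, e.g. 0 in the tests
  uint64_t counter = 0;              // the only per-stream mutable state
};

int main() {
  AesCtrEngineSketch engine(/*deterministic=*/true);
  RngStreamSketch rng{engine, /*stream=*/0};
  std::vector<int> v(8);
  std::iota(v.begin(), v.end(), 0);
  std::shuffle(v.begin(), v.end(), rng);  // compiles and runs, as in the test
  for (int x : v) std::printf("%d ", x);
  std::printf("\n");
}
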
@@ -20,7 +20,6 @@
 #include <iostream>
 #include <ostream>
-#include <random>
 #include <string>
 #include <vector>
@@ -37,17 +36,6 @@
 namespace gcpp {
 
-void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
-  if (inference.deterministic) {
-    // Nothing up my sleeve number, at least some upper bits set.
-    gen.seed(0x12345678);
-  } else {
-    // Depending on the library implementation, this may still be deterministic.
-    std::random_device rd;  // NOLINT
-    gen.seed(rd());
-  }
-}
-
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
     : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
@@ -60,12 +48,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                     ctx_);
   }
 
-  InitGenerator(inference, gen_);
-
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
      .temperature = inference.temperature,
-      .gen = &gen_,
      .verbosity = inference.verbosity,
  };
  inference.CopyTo(runtime_config_);
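
The deleted InitGenerator above documents the old seeding policy: a fixed nothing-up-my-sleeve seed when deterministic, else std::random_device (which, as the deleted comment warns, may itself be deterministic depending on the library). In the new design that decision is made once, at engine construction, via the boolean the call sites now pass (visible in the tests above and in Gemma's constructor later in this diff). The key-derivation details are not shown in the diff and remain internal to AesCtrEngine:

AesCtrEngine reproducible(/*deterministic=*/true);   // fixed key every run
AesCtrEngine seeded(/*deterministic=*/false);        // entropy-derived key
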
@@ -18,7 +18,6 @@
 #include <stddef.h>
 
-#include <random>
 #include <string>
 #include <vector>
@@ -32,8 +31,6 @@
 namespace gcpp {
 
-void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
-
 // Return type for query model calls.
 struct QueryResult {
   std::string response;
@@ -107,7 +104,6 @@ class GemmaEnv {
   int Verbosity() const { return runtime_config_.verbosity; }
   RuntimeConfig& MutableConfig() { return runtime_config_; }
-  std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return kv_caches_[0]; }
   MatMulEnv& MutableEnv() { return env_; }
@@ -115,7 +111,6 @@ class GemmaEnv {
   ThreadingContext ctx_;
   MatMulEnv env_;
   Gemma gemma_;
-  std::mt19937 gen_;  // Random number generator.
   std::vector<KVCache> kv_caches_;  // Same number as query batch.
   RuntimeConfig runtime_config_;
 };
@@ -56,11 +56,10 @@ static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
   return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
 }
 
-void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
-             size_t k) {
-  std::vector<std::pair<float, int>> sorted(len);
-  for (size_t i = 0; i < len; ++i) {
-    sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
+void LogTopK(const GemmaTokenizer& tokenizer, Logits logits, size_t k) {
+  std::vector<std::pair<float, int>> sorted(logits.size());
+  for (size_t i = 0; i < logits.size(); ++i) {
+    sorted[i] = std::make_pair(logits[i], static_cast<int>(i));
   }
   std::sort(sorted.begin(), sorted.end(),
             [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
@@ -84,9 +83,8 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size,
-                 hwy::Profiler& p) {
-  Softmax(logits, vocab_size, p, hwy::Profiler::Thread());
+void CallSoftmax(Logits logits, hwy::Profiler& p) {
+  Softmax(logits, p, hwy::Profiler::Thread());
 }
 
 }  // namespace HWY_NAMESPACE
@@ -107,19 +105,19 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
   float cross_entropy = std::log(vocab_size);  // first token; == -log(1/v_s)
   size_t pos = 1;
 
-  const SampleFunc sample_token = [&](float* probs,
-                                      size_t vocab_size) -> TokenAndProb {
+  const SampleFunc sample_token = [&](size_t qi,
+                                      Logits logits) -> TokenAndProb {
     // input is logits, not yet probabilities
-    HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler);
+    HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
     // We are called for each token, but pos starts at 1. Clamping
     // max_generated_tokens to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
     const int token = prompt[pos];
-    const float prob = probs[token];
+    const float prob = logits[token];
     cross_entropy -= std::max(std::log(prob), -64.0f);
 
     if (verbosity >= 4) {
-      LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
+      LogTopK(gemma.Tokenizer(), logits, 10);
     }
     if (verbosity >= 3) {
       printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
@@ -139,7 +137,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
   RuntimeConfig runtime = {
       .max_generated_tokens = max_generated_tokens - 1,
      .temperature = 0.0f,
-      .gen = nullptr,
      .verbosity = verbosity,
      .stream_token = stream_token,
      .sample_func = sample_token,
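
Logits, introduced by this commit, replaces the (float*, size_t) pairs previously threaded through LogTopK, CallSoftmax, and SampleFunc: the length now travels with the pointer. Only operator[], size(), and (pointer, length) construction are visible in this diff; below is a minimal stand-in consistent with that usage — the real type may well differ, e.g. it could be an alias of hwy::Span<float>:

#include <cstddef>

// Minimal non-owning view matching the operations used in this diff.
class Logits {
 public:
  Logits(float* data, size_t size) : data_(data), size_(size) {}
  float& operator[](size_t i) const { return data_[i]; }
  size_t size() const { return size_; }

 private:
  float* data_;  // points into a row of the activations' logits matrix
  size_t size_;  // e.g. vocab_size
};
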
@@ -115,7 +115,6 @@ TEST_F(GemmaTest, Multiturn) {
   RuntimeConfig runtime_config{
       .max_generated_tokens = 64,
       .temperature = 0.0f,
-      .gen = &s_env->MutableGen(),
       .verbosity = 2,
       .batch_stream_token = stream_token,
   };
@@ -126,7 +126,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 30,
       .temperature = 0.0f,
-      .gen = &env.MutableGen(),
       .verbosity = env.Verbosity(),
       .stream_token = stream_token,
   };
@@ -17,8 +17,8 @@
 #include <cstdlib>
 #include <cstring>
+#include <functional>
 #include <iostream>
-#include <random>
 #include <set>
 #include <string>
 #include <vector>
@@ -44,7 +44,7 @@ int main(int argc, char** argv) {
   for (int arg = 0; arg < argc; ++arg) {
     // Find a --reject flag and consume everything after it.
     if (strcmp(argv[arg], "--reject") == 0) {
-      while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
+      while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));  // NOLINT
     }
   }
@@ -55,11 +55,6 @@ int main(int argc, char** argv) {
   gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
   size_t generated = 0;
 
-  // Initialize random number generator
-  std::mt19937 gen;
-  std::random_device rd;  // NOLINT
-  gen.seed(rd());
-
   // Tokenize instructions.
   std::string prompt = "Write a greeting to the world.";
   const std::vector<int> tokens =
@@ -84,7 +79,6 @@ int main(int argc, char** argv) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
-      .gen = &gen,
       .verbosity = 0,
       .stream_token = stream_token,
       .accept_token =
@@ -18,7 +18,6 @@
 #include <cstdlib>
 #include <cstring>
 #include <iostream>
-#include <random>
 #include <set>
 #include <string>
 #include <vector>
@@ -38,11 +37,7 @@ class SimplifiedGemma {
       : ctx_(threading),
         env_(ctx_),
         gemma_(loader, inference, ctx_),
-        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
-    // Initialize random number generator
-    std::random_device rd;
-    gen_.seed(rd());
-  }
+        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
 
   SimplifiedGemma(int argc, char** argv)
       : SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
@@ -76,7 +71,6 @@ class SimplifiedGemma {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = max_generated_tokens,
       .temperature = temperature,
-      .gen = &gen_,
       .verbosity = 0,
       .stream_token = stream_token,
       .accept_token =
@@ -93,6 +87,5 @@ class SimplifiedGemma {
   gcpp::MatMulEnv env_;
   gcpp::Gemma gemma_;
   gcpp::KVCache kv_cache_;
-  std::mt19937 gen_;
   std::string validation_error_;
 };
@@ -60,18 +60,18 @@ struct ServerState {
   std::unique_ptr<Gemma> gemma;
   MatMulEnv* env;
   ThreadingContext* ctx;
 
   // Session-based KV cache storage
   struct Session {
     std::unique_ptr<KVCache> kv_cache;
     size_t abs_pos = 0;
     std::chrono::steady_clock::time_point last_access;
   };
 
   std::unordered_map<std::string, Session> sessions;
   std::mutex sessions_mutex;
   std::mutex inference_mutex;
 
   // Cleanup old sessions after 30 minutes of inactivity
   void CleanupOldSessions() {
     std::lock_guard<std::mutex> lock(sessions_mutex);
@@ -84,7 +84,7 @@ struct ServerState {
       }
     }
   }
 
   // Get or create session with KV cache
   Session& GetOrCreateSession(const std::string& session_id) {
     std::lock_guard<std::mutex> lock(sessions_mutex);
@@ -101,24 +101,25 @@
 std::string GenerateSessionId() {
   static std::atomic<uint64_t> counter{0};
   std::stringstream ss;
-  ss << "session_" << std::hex << std::chrono::steady_clock::now().time_since_epoch().count()
-     << "_" << counter.fetch_add(1);
+  ss << "session_" << std::hex
+     << std::chrono::steady_clock::now().time_since_epoch().count() << "_"
+     << counter.fetch_add(1);
   return ss.str();
 }
 
 // Wraps messages with start_of_turn markers - handles both with and without roles
 std::string WrapMessagesWithTurnMarkers(const json& contents) {
   std::string prompt;
 
   for (const auto& content : contents) {
     if (content.contains("parts")) {
       // Check if role is specified (public API format) or not (local format)
       std::string role = content.value("role", "");
 
       for (const auto& part : content["parts"]) {
         if (part.contains("text")) {
           std::string text = part["text"];
 
           if (role == "user") {
             prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
           } else if (role == "model") {
@@ -131,24 +132,23 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
         }
       }
     }
   }
 
   return prompt;
 }
 
 // Parse generation config
-RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) {
+RuntimeConfig ParseGenerationConfig(const json& request) {
   RuntimeConfig config;
-  config.gen = &gen;
   config.verbosity = 0;
 
   // Set defaults matching public API
   config.temperature = 1.0f;
   config.top_k = 1;
   config.max_generated_tokens = 8192;
 
   if (request.contains("generationConfig")) {
     auto& gen_config = request["generationConfig"];
 
     if (gen_config.contains("temperature")) {
       config.temperature = gen_config["temperature"].get<float>();
     }
@@ -159,7 +159,7 @@ RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) {
       config.max_generated_tokens = gen_config["maxOutputTokens"].get<size_t>();
     }
   }
 
   return config;
 }
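
Net effect for the server handlers below: ParseGenerationConfig is now a pure JSON-to-parameters translation, and no RNG state is created per request. Before/after at the call sites, per this diff:

// Before: each handler owned a generator and wired it into the config.
std::mt19937 gen;  // mutable, unseeded here, not shareable across threads
RuntimeConfig before = ParseGenerationConfig(request, gen);

// After: the config is plain data; sampling randomness comes from the
// AesCtrEngine owned by Gemma (added later in this diff), so concurrent
// handlers need no RNG coordination.
RuntimeConfig after = ParseGenerationConfig(request);
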
@@ -175,12 +175,12 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
       }}},
      {"promptFeedback", {{"safetyRatings", json::array()}}}
   };
 
   // Only add finishReason for non-streaming chunks
   if (!is_streaming_chunk) {
     response["candidates"][0]["finishReason"] = "STOP";
   }
 
   return response;
 }
@@ -188,11 +188,11 @@
 void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
   try {
     json request = json::parse(req.body);
 
     // Get or create session
     std::string session_id = request.value("sessionId", GenerateSessionId());
     auto& session = state.GetOrCreateSession(session_id);
 
     // Extract prompt from API format
     std::string prompt;
     if (request.contains("contents")) {
@@ -202,32 +202,29 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
       res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
       return;
     }
 
     // Lock for inference
     std::lock_guard<std::mutex> lock(state.inference_mutex);
 
     // Set up runtime config
-    std::mt19937 gen;
-    RuntimeConfig runtime_config = ParseGenerationConfig(request, gen);
+    RuntimeConfig runtime_config = ParseGenerationConfig(request);
 
     // Collect full response
     std::string full_response;
     runtime_config.stream_token = [&full_response](int token, float) {
       // Skip EOS token
       return true;
     };
 
     // Tokenize prompt
-    std::vector<int> tokens = WrapAndTokenize(state.gemma->Tokenizer(),
-                                              state.gemma->ChatTemplate(),
-                                              state.gemma->Config().wrapping,
-                                              session.abs_pos,
-                                              prompt);
+    std::vector<int> tokens = WrapAndTokenize(
+        state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
+        state.gemma->Config().wrapping, session.abs_pos, prompt);
 
     // Run inference with KV cache
     TimingInfo timing_info = {.verbosity = 0};
     size_t prefix_end = 0;
 
     // Temporarily redirect output to capture response
     std::stringstream output;
     runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
@@ -236,25 +233,25 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
         session.abs_pos++;
         return true;
       }
 
       session.abs_pos++;
 
       // Check for EOS
       if (state.gemma->Config().IsEOS(token)) {
         return true;
       }
 
       // Decode token
       std::string token_text;
       state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
       output << token_text;
 
       return true;
     };
 
     state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
                           *session.kv_cache, *state.env, timing_info);
 
     // Create response
     json response = CreateAPIResponse(output.str(), false);
     response["usageMetadata"] = {
@@ -262,17 +259,22 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
         {"candidatesTokenCount", session.abs_pos - tokens.size()},
         {"totalTokenCount", session.abs_pos}
     };
 
     res.set_content(response.dump(), "application/json");
 
   } catch (const json::exception& e) {
     res.status = 400;
-    res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(),
-                    "application/json");
+    res.set_content(
+        json{{"error",
+              {{"message", std::string("JSON parsing error: ") + e.what()}}}}
+            .dump(),
+        "application/json");
   } catch (const std::exception& e) {
     res.status = 500;
-    res.set_content(json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}.dump(),
-                    "application/json");
+    res.set_content(
+        json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}
+            .dump(),
+        "application/json");
   }
 }
@@ -280,11 +282,11 @@
 void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
   try {
     json request = json::parse(req.body);
 
     // Get or create session
     std::string session_id = request.value("sessionId", GenerateSessionId());
     auto& session = state.GetOrCreateSession(session_id);
 
     // Extract prompt from API format
     std::string prompt;
     if (request.contains("contents")) {
@@ -294,13 +296,13 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
       res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
       return;
     }
 
     // Set up SSE headers
     res.set_header("Content-Type", "text/event-stream");
     res.set_header("Cache-Control", "no-cache");
     res.set_header("Connection", "keep-alive");
     res.set_header("X-Session-Id", session_id);
 
     // Set up chunked content provider for SSE
     res.set_chunked_content_provider(
         "text/event-stream",
@@ -309,18 +311,15 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
           // Lock for inference
           std::lock_guard<std::mutex> lock(state.inference_mutex);
           auto& session = state.GetOrCreateSession(session_id);
 
           // Set up runtime config
-          std::mt19937 gen;
-          RuntimeConfig runtime_config = ParseGenerationConfig(request, gen);
+          RuntimeConfig runtime_config = ParseGenerationConfig(request);
 
           // Tokenize prompt
-          std::vector<int> tokens = WrapAndTokenize(state.gemma->Tokenizer(),
-                                                    state.gemma->ChatTemplate(),
-                                                    state.gemma->Config().wrapping,
-                                                    session.abs_pos,
-                                                    prompt);
+          std::vector<int> tokens = WrapAndTokenize(
+              state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
+              state.gemma->Config().wrapping, session.abs_pos, prompt);
 
           // Stream token callback
           std::string accumulated_text;
           auto stream_token = [&](int token, float) {
@@ -329,37 +328,38 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
             session.abs_pos++;
             return true;
           }
 
           session.abs_pos++;
 
           // Check for EOS
           if (state.gemma->Config().IsEOS(token)) {
             return true;
           }
 
           // Decode token
           std::string token_text;
           state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
           accumulated_text += token_text;
 
           // Send SSE event using unified formatter
           json event = CreateAPIResponse(token_text, true);
 
           std::string sse_data = "data: " + event.dump() + "\n\n";
           sink.write(sse_data.data(), sse_data.size());
 
           return true;
         };
 
         runtime_config.stream_token = stream_token;
 
         // Run inference with KV cache
         TimingInfo timing_info = {.verbosity = 0};
         size_t prefix_end = 0;
 
-        state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
-                              *session.kv_cache, *state.env, timing_info);
+        state.gemma->Generate(runtime_config, tokens, session.abs_pos,
+                              prefix_end, *session.kv_cache, *state.env,
+                              timing_info);
 
         // Send final event using unified formatter
         json final_event = CreateAPIResponse("", false);
         final_event["usageMetadata"] = {
@@ -367,18 +367,18 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
             {"candidatesTokenCount", session.abs_pos - tokens.size()},
             {"totalTokenCount", session.abs_pos}
         };
 
         std::string final_sse = "data: " + final_event.dump() + "\n\n";
         sink.write(final_sse.data(), final_sse.size());
 
         // Send done event
         sink.write("data: [DONE]\n\n", 15);
 
         // Ensure all data is sent
         sink.done();
 
         return false;  // End streaming
 
       } catch (const std::exception& e) {
         json error_event = {{"error", {{"message", e.what()}}}};
         std::string error_sse = "data: " + error_event.dump() + "\n\n";
@@ -387,11 +387,14 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
       }
     }
     );
 
   } catch (const json::exception& e) {
     res.status = 400;
-    res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(),
-                    "application/json");
+    res.set_content(
+        json{{"error",
+              {{"message", std::string("JSON parsing error: ") + e.what()}}}}
+            .dump(),
+        "application/json");
   }
 }
@@ -410,7 +413,7 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
           {"topK", 1}
       }}}
   };
 
   res.set_content(response.dump(), "application/json");
 }
@@ -419,40 +422,40 @@
 //   server_running = false;
 // }
 
 void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
                const InferenceArgs& inference) {
   std::cerr << "Loading model..." << std::endl;
 
   // Initialize model
   ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
 
   ServerState state;
   state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
   state.env = &env;
   state.ctx = &ctx;
 
   httplib::Server server;
 
   // Set up routes
   server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
     res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain");
   });
 
   // API endpoints
   server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) {
     HandleListModels(state, inference, req, res);
   });
 
   std::string model_endpoint = "/v1beta/models/" + inference.model;
   server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) {
     HandleGenerateContentNonStreaming(state, req, res);
   });
 
   server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) {
     HandleGenerateContentStreaming(state, req, res);
   });
 
   // Periodic cleanup of old sessions
   std::thread cleanup_thread([&state]() {
     while (server_running) {
@@ -460,18 +463,18 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
       state.CleanupOldSessions();
     }
   });
 
   std::cerr << "Starting API server on port " << inference.port << std::endl;
   std::cerr << "Model loaded successfully" << std::endl;
   std::cerr << "Endpoints:" << std::endl;
   std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
   std::cerr << "  POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl;
   std::cerr << "  GET  /v1beta/models" << std::endl;
 
   if (!server.listen("0.0.0.0", inference.port)) {
     std::cerr << "Failed to start server on port " << inference.port << std::endl;
   }
 
   cleanup_thread.join();
 }
@@ -479,11 +482,11 @@
 int main(int argc, char** argv) {
   gcpp::InternalInit();
 
   gcpp::LoaderArgs loader(argc, argv);
   gcpp::ThreadingArgs threading(argc, argv);
   gcpp::InferenceArgs inference(argc, argv);
 
   if (gcpp::HasHelp(argc, argv)) {
     std::cerr << "\n\nAPI server for gemma.cpp\n";
     std::cout << "========================\n\n";
@@ -501,14 +504,14 @@ int main(int argc, char** argv) {
     std::cerr << "\n";
     return 0;
   }
 
   // Arguments are now handled by InferenceArgs
 
   // // Set up signal handler
   // signal(SIGINT, gcpp::HandleShutdown);
   // signal(SIGTERM, gcpp::HandleShutdown);
 
   gcpp::RunServer(loader, threading, inference);
 
   return 0;
 }
@@ -155,8 +155,9 @@ void SingleDotSoftmaxWeightedSum(
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  MaybeLogitsSoftCap(att_cap, att, att_len, p, worker);
-  Softmax(att, att_len, p, worker, /*temperature=*/1.0f);
+  const Logits logits(att, att_len);
+  MaybeLogitsSoftCap(att_cap, logits, p, worker);
+  Softmax(logits, p, worker, /*temperature=*/1.0f);
 
   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
                worker);
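
Same pattern in the attention path: the raw score buffer att is wrapped once as Logits(att, att_len), and both MaybeLogitsSoftCap and Softmax receive the view, so the two calls cannot disagree about the valid length — note that att_len is a prefix, HWY_MIN(last_pos + 1, seq_len), not the whole buffer. The prefix-view idea in isolation, as a self-contained toy:

#include <cstddef>
#include <cstdio>

struct Span {  // stand-in for the Logits view sketched earlier
  float* data;
  size_t size;
};

// Helpers take one argument; the length cannot drift from the pointer.
static void NormalizeToOne(Span x) {
  float sum = 0.0f;
  for (size_t i = 0; i < x.size; ++i) sum += x.data[i];
  for (size_t i = 0; i < x.size; ++i) x.data[i] /= sum;
}

int main() {
  float att[4] = {1, 1, 2, 4};
  NormalizeToOne(Span{att, 3});  // operate on the valid prefix only
  std::printf("%g %g %g | untouched: %g\n", att[0], att[1], att[2], att[3]);
}
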
@@ -23,7 +23,6 @@
 #include <string>
 #include <vector>
 
-#include "evals/benchmark_helper.h"  // InitGenerator
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "gemma/tokenizer.h"  // WrapAndTokenize
@@ -135,8 +134,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   std::stringstream ss;
   result_buffer.clear();
 
-  InitGenerator(inference_args, gen);
-
   // Ensure we have an active conversation
   if (!active_conversation || !active_conversation->kv_cache) {
     LogDebug("Generate called with null active_conversation or kv_cache");
@@ -174,8 +171,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
 
   // set up runtime config
   TimingInfo timing_info = {};
-  RuntimeConfig runtime_config = {.gen = &gen,
-                                  .stream_token = stream_token,
+  RuntimeConfig runtime_config = {.stream_token = stream_token,
                                   .use_spinning = threading_args.spin};
   inference_args.CopyTo(runtime_config);
   size_t prefix_end = 0;
@@ -256,7 +252,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
     // If not multiturn, or Paligemma (which handles turns differently),
     // reset the *active* conversation's position.
     active_conversation->abs_pos = 0;
-    InitGenerator(inference_args, gen);
   } else {
     // Multi-turn Gemma: Rewind position in the active conversation
     // The last token was either EOS, then it should be ignored because it is
@@ -17,7 +17,6 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
 
 #include <memory>  // For std::shared_ptr, std::make_shared
-#include <random>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -107,10 +106,6 @@ class GemmaContext {
   // Set deterministic flag
   void SetDeterministic(bool value) {
     inference_args.deterministic = value;
-    // Reset the random number generator for deterministic generation
-    if (value) {
-      gen.seed(0x87654321);
-    }
     LogDebug("Setting deterministic flag to configured value");
   }
@@ -289,9 +284,6 @@ class GemmaContext {
   // Model itself (don't move this, needs to be below the args above)
   Gemma model;
 
-  // Random generator (remains global for the context)
-  std::mt19937 gen;
-
   // Static members for logging
   static GemmaLogCallback s_log_callback;
   static void* s_log_user_data;
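
Worth noting in the GemmaContext change above: SetDeterministic used to reseed the live std::mt19937 immediately, so it changed behavior mid-context. Now it only records inference_args.deterministic; since the AesCtrEngine is keyed once in Gemma's constructor (later in this diff), the flag plausibly takes effect for a Gemma constructed after it is set, rather than for an in-flight generator. That reading is an inference from the deleted reseed call, not something the diff states.
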
@@ -440,8 +440,7 @@ static void SampleAndStream(
   // TODO: parallelize
   non_eos.Foreach([&](size_t qi) {
-    float* HWY_RESTRICT logits = activations.logits.Row(qi);
-    const TokenAndProb tp = sample_token(logits, config.vocab_size);
+    const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi));
 
     // We streamed all prefill tokens, but pos is still one behind because we
     // started generation at pos = prompt.size() - 1. We want the pos argument
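
RowSpan, the other addition named in the commit message, appears here: activations.logits.RowSpan(qi) replaces the Row(qi) pointer plus the separate config.vocab_size argument — which is also why SampleFunc loses its size_t parameter. A sketch of such an accessor on a row-major matrix; names other than Row/RowSpan/Logits are assumptions:

#include <cstddef>
#include <vector>

struct Logits {  // as sketched earlier: non-owning (pointer, length) view
  float* data;
  size_t size;
};

class RowMajorMatrix {  // stand-in for the activations' logits matrix type
 public:
  RowMajorMatrix(size_t rows, size_t cols)
      : cols_(cols), storage_(rows * cols) {}
  float* Row(size_t r) { return storage_.data() + r * cols_; }
  // One call yields a bounds-carrying view of row r.
  Logits RowSpan(size_t r) { return Logits{Row(r), cols_}; }

 private:
  size_t cols_;  // e.g. vocab_size
  std::vector<float> storage_;
};
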
@@ -453,7 +452,8 @@
 }
 
 static HWY_INLINE SampleFunc
-ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
+ChooseSampleFunc(const RuntimeConfig& runtime_config,
+                 const AesCtrEngine& engine, ThreadingContext& ctx) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;
@@ -462,27 +462,28 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
 
   // Fast path for top-1 with no accept_token.
   if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
-    return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
+    return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE3(ctx.profiler, worker, zone);
-      return Top1OfSoftmax(logits, vocab_size);
+      return Top1OfSoftmax(logits);
     };
   }
 
   // General case: Softmax with top-k sampling.
-  return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
+    RngStream gen(engine, qi);
     return FusedSoftmaxAndSampleTopK(
-        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
-        runtime_config.temperature, runtime_config.accept_token, ctx.profiler,
-        worker);
+        logits, runtime_config.top_k, gen, runtime_config.temperature,
+        runtime_config.accept_token, ctx.profiler, worker);
   };
 }
 
 // Decode: generates one continuation token for each query in `qbatch`.
 static void GenerateT(const ModelConfig& config,
                       const RuntimeConfig& runtime_config,
-                      const WeightsPtrs& weights, Activations& activations,
-                      QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
+                      const AesCtrEngine& engine, const WeightsPtrs& weights,
+                      Activations& activations, QBatch& qbatch, MatMulEnv& env,
+                      TimingInfo& timing_info) {
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     if (qbatch.MutablePos(qi) == 0) {
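
The two lines above are the core of the commit: the general sampling path builds RngStream gen(engine, qi) per query, so each query in a batch draws from its own stream of the shared immutable engine, instead of all queries advancing one mutable std::mt19937 through *runtime_config.gen. Because a counter-based generator's output is a pure function of (key, stream id, counter), query qi's samples no longer depend on how many values other queries consumed or in what order they ran — which is what makes the "TODO: parallelize" in SampleAndStream safe. A self-contained illustration of that independence property, again with a SplitMix64-style mixer standing in for AES-CTR:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Pure function of (key, stream, counter): no shared mutable state.
static uint64_t Generate(uint64_t key, uint64_t stream, uint64_t counter) {
  uint64_t x = key ^ (stream * 0xBF58476D1CE4E5B9u) ^ counter;
  x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9u;
  x = (x ^ (x >> 27)) * 0x94D049BB133111EBu;
  return x ^ (x >> 31);
}

int main() {
  const uint64_t key = 0x12345678;  // deterministic key
  // Whether queries run in order 0,1,2,3 or 3,1,0,2, on one thread or many,
  // stream qi always yields the same sequence:
  for (uint64_t qi : {3u, 1u, 0u, 2u}) {
    std::printf("qi=%llu draws %llu %llu\n", (unsigned long long)qi,
                (unsigned long long)Generate(key, qi, 0),
                (unsigned long long)Generate(key, qi, 1));
  }
}
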
@@ -554,7 +555,8 @@ static void GenerateT(const ModelConfig& config,
     max_gen_steps = seq_len - max_prompt_size;
   }
 
-  const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
+  const SampleFunc sample_token =
+      ChooseSampleFunc(runtime_config, engine, env.ctx);
 
   timing_info.generate_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
@@ -568,15 +570,16 @@ static void GenerateT(const ModelConfig& config,
 void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const ModelConfig& config,
                      const RuntimeConfig& runtime_config,
-                     const WeightsPtrs& weights, KVCache& kv_cache,
-                     MatMulEnv& env, TimingInfo& timing_info) {
+                     const AesCtrEngine& engine, const WeightsPtrs& weights,
+                     KVCache& kv_cache, MatMulEnv& env,
+                     TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
                           kv_cache.SeqLen(), env.ctx, env.row_ptrs);
 
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));
   QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
-  GenerateT(config, runtime_config, weights, activations, qbatch, env,
+  GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
             timing_info);
 }
@@ -584,8 +587,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
 // queries, and calls `GenerateT` on each batch.
 void GenerateBatchT(const ModelConfig& config,
                     const RuntimeConfig& runtime_config,
-                    const WeightsPtrs& weights, AllQueries& all_queries,
-                    MatMulEnv& env, TimingInfo& timing_info) {
+                    const AesCtrEngine& engine, const WeightsPtrs& weights,
+                    AllQueries& all_queries, MatMulEnv& env,
+                    TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
   Activations activations(config, max_batch_size,
@@ -596,7 +600,7 @@ void GenerateBatchT(const ModelConfig& config,
        start += runtime_config.decode_qbatch_size) {
     QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
     // Generate a batch of one token for each of `qbatch.Size()` queries.
-    GenerateT(config, runtime_config, weights, activations, qbatch, env,
+    GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
               timing_info);
   }
 }
@@ -637,7 +641,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
       model_(reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
-      inference_(inference) {
+      inference_(inference),
+      aes_ctr_engine_(inference.deterministic) {
   // Negligible CPU time in the ctor body (except ReadFromBlobs).
   weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
                                              mat_owners_, ctx);
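
The engine is keyed exactly once per Gemma, from inference.deterministic, and both Generate paths below pass it down to GenerateT. A hedged usage sketch built from signatures visible in this diff (argument parsing elided; the Generate call is shown as a comment because its surrounding setup is not):

// Construction as in the server's RunServer above; reproducibility now hangs
// off one flag instead of caller-managed seeding.
gcpp::ThreadingContext ctx(threading);
gcpp::InferenceArgs inference(argc, argv);
inference.deterministic = true;  // keys aes_ctr_engine_ with a fixed key
gcpp::Gemma gemma(loader, inference, ctx);
// gemma.Generate(runtime_config, tokens, pos, prefix_end, kv_cache, env,
//                timing_info);  // no RNG argument anywhere
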
@@ -661,9 +666,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                      TimingInfo& timing_info) const {
   env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
-                                        model_.Config(), runtime_config,
-                                        weights_, kv_cache, env, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(
+      prompt, pos, prefix_end, model_.Config(), runtime_config, aes_ctr_engine_,
+      weights_, kv_cache, env, timing_info);
 
   env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
@@ -674,7 +679,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
   env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
-                                       weights_, all_queries, env, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
+                                       aes_ctr_engine_, weights_, all_queries,
+                                       env, timing_info);
 
   env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
@@ -278,6 +278,7 @@ class Gemma {
   WeightsPtrs::Mode weight_read_mode_;
   GemmaChatTemplate chat_template_;
   InferenceArgs inference_;
+  AesCtrEngine aes_ctr_engine_;
 };
 
 }  // namespace gcpp
@@ -22,7 +22,6 @@
 #include <stdio.h>
 
 #include <functional>
-#include <random>
 #include <string>
 
 #include "io/io.h"  // Path
@@ -90,10 +89,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
 using AcceptFunc = std::function<bool(int, float)>;
-// If not empty, SampleFunc is called with the logits for the next token, which
-// it may modify/overwrite, and its return value is the next generated token
-// together with its probability.
-using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
+// If not empty, SampleFunc is called with the query_idx and logits for the
+// next token, which it may modify/overwrite. It returns the next generated
+// token together with its probability.
+using SampleFunc = std::function<TokenAndProb(size_t, Logits)>;
 // If not empty, LayersOutputFunc is called for layer outputs, specified with:
 // - index of query within containing batch (if any); zero otherwise.
 // - position in the tokens sequence
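
Note: the new SampleFunc signature passes the query index along with a Logits span, so a custom sampler can derive a per-query RNG stream instead of sharing one generator. A minimal sketch (not part of this commit; the greedy logic and helper name are illustrative only):

  #include <cstddef>
  #include "gemma/gemma.h"  // SampleFunc, TokenAndProb, Logits (path assumed)

  // Hypothetical caller-provided sampler for RuntimeConfig::sample_func.
  gcpp::SampleFunc MakeGreedySampleFunc() {
    return [](size_t query_idx, gcpp::Logits logits) -> gcpp::TokenAndProb {
      (void)query_idx;  // could instead select RngStream(engine, query_idx)
      size_t best = 0;
      for (size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) best = i;
      }
      return gcpp::TokenAndProb{.token = static_cast<int>(best),
                                .prob = logits[best]};
    };
  }
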
@@ -136,8 +135,7 @@ struct RuntimeConfig {
   // Sampling-related parameters.
   float temperature;  // Temperature for sampling.

   size_t top_k = 1;  // Top-k for sampling.
-  std::mt19937* gen;  // Random number generator used for sampling.

   int verbosity;  // Controls verbosity of printed messages.
11  gemma/run.cc

@@ -18,7 +18,6 @@
 #include <stdio.h>

 #include <iostream>
-#include <random>
 #include <string>
 #include <string_view>
 #include <vector>
@@ -98,9 +97,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   size_t prompt_size = 0;
   const ModelConfig& config = gemma.Config();

-  std::mt19937 gen;
-  InitGenerator(inference, gen);
-
   const bool have_image = !inference.image_file.path.empty();
   Image image;
   const size_t pool_dim = config.vit_config.pool_dim;
@@ -117,8 +113,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     HWY_ASSERT(image.ReadPPM(inference.image_file.path));
     const size_t image_size = config.vit_config.image_size;
     image.Resize(image_size, image_size);
-    RuntimeConfig runtime_config = {.gen = &gen,
-                                    .verbosity = inference.verbosity,
+    RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                     .use_spinning = threading.spin};
     double image_tokens_start = hwy::platform::Now();
     gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
@@ -188,8 +183,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,

   // Set up runtime config.
   TimingInfo timing_info = {.verbosity = inference.verbosity};
-  RuntimeConfig runtime_config = {.gen = &gen,
-                                  .verbosity = inference.verbosity,
+  RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                   .batch_stream_token = batch_stream_token,
                                   .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
@@ -239,7 +233,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     // Prepare for the next turn. Works only for PaliGemma.
     if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
       abs_pos = 0;  // Start a new turn at position 0.
-      InitGenerator(inference, gen);
     } else {
       // The last token was either EOS, then it should be ignored because it is
       // never part of the dialog, see Table 5 in the Gemma-2 paper:
@@ -110,8 +110,7 @@ class VitAttention {
     CallMatMul(Q, K, nullptr, env_, C);

     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
-      float* HWY_RESTRICT c = C.Row(task);
-      Softmax(c, C.Cols(), env_.ctx.profiler, worker);
+      Softmax(C.RowSpan(task), env_.ctx.profiler, worker);
     });

     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
@@ -154,7 +153,7 @@ class VitAttention {
         head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
       }
       // SoftMax yields "probabilities" in head_att.
-      Softmax(head_att, seq_len, env_.ctx.profiler, worker);
+      Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker);
       // Compute weighted sum of v into att_out.
       float* HWY_RESTRICT att_out =
           activations_.attention.att_out.Row(token) + head * qkv_dim;
@@ -812,7 +812,7 @@ class DotStats {

   // Forward relative error, lower is better.
   void CheckRel() const {
-    ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 4E-3);
+    ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 7E-3);
     ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 1.23f);

     // Compensated and Double are very accurate.
@@ -822,22 +822,22 @@ class DotStats {
     ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f);

     // Naive and OnlyTwoProd are considerably higher, but not huge.
-    ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2);
+    ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 3.5E-1);
     ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(),
-                  0.072);
+                  7.5E-2);

     // Kahan (FastTwoSum) is decent:
-    ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3);
+    ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 1E-2);
     ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f);

     // TwoProducts and TwoSums are a bit better.
     ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(),
-                  3E-3);
-    ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f);
+                  1.1E-2);
+    ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 1.0f);
     ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(),
-                  2.6E-3);
+                  1.1E-2);

-    ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
+    ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 5.2E-2);
     // Extremely high error on aarch64.
     ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f);
   }
@@ -857,7 +857,7 @@ class DotStats {
     ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f);

     // But TwoProducts/TwoSums help a bit.
-    ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 0.19f);
+    ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 1.0f);
     ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_rels[kAddTwoSum].Max(), 0.34f);

     // Extremely high error on aarch64.
@@ -893,7 +893,7 @@ class DotStats {
 };

 // Returns normalized value in [-1, 1).
-float RandomFloat(std::mt19937& rng) {
+float RandomFloat(RngStream& rng) {
   const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
   const uint32_t mantissa_mask = hwy::MantissaMask<float>();
   const uint32_t representation = exp | (rng() & mantissa_mask);
@@ -908,7 +908,7 @@ float RandomFloat(std::mt19937& rng) {
 // error from the Dot algorithms, not the compression.
 template <typename Packed>
 void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
-                                   std::mt19937& rng,
+                                   RngStream& rng,
                                    const PackedSpan<Packed>& packed,
                                    CompressWorkingSet& work) {
   std::uniform_int_distribution<int> e_dist(0, 6);
@@ -934,7 +934,7 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
 // Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
 template <typename WT, typename VT>
 double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v,
-                                    std::mt19937& rng) {
+                                    RngStream& rng) {
   PROFILER_FUNC;
   const size_t half = HWY_MAX(1, num / 2);  // generate at least one random
   HWY_DASSERT(half != 0);
@@ -1002,8 +1002,8 @@ struct TestShortDotsT {
     ThreadingArgs threading_args;
     ThreadingContext ctx(threading_args);
     CompressWorkingSet work;
-    std::mt19937 rng;
-    rng.seed(12345);
+    AesCtrEngine engine(/*deterministic=*/true);
+    RngStream rng(engine, 0);

     hwy::Stats s_l1[kVariants];
@@ -1108,9 +1108,10 @@ void TestAllDot() {
   {  // ensure no profiler zones are active
     const hn::ScalableTag<float> df;

-    std::mt19937 rngs[kMaxWorkers];
+    AesCtrEngine engine(/*deterministic=*/true);
+    RngStream rngs[kMaxWorkers];
     for (size_t i = 0; i < kMaxWorkers; ++i) {
-      rngs[i].seed(12345 + 65537 * i);
+      rngs[i] = RngStream(engine, i);
     }

     constexpr size_t kReps = hn::AdjustedReps(40);
150  ops/ops-inl.h

@@ -29,7 +29,7 @@

 #include "ops/matmul.h"
 #include "util/allocator.h"
-#include "util/basics.h"  // TokenAndProb
+#include "util/basics.h"  // TokenAndProb, RngStream
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/base.h"
@@ -614,12 +614,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
 }

 // See below for a specialized version for top-1 sampling.
-static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 hwy::Profiler& p, const size_t worker,
+static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
+                                 const size_t worker,
                                  float temperature = 1.0f) {
   static const auto zone = p.AddZone("Ops.Softmax");
   PROFILER_ZONE3(p, worker, zone);
-  HWY_DASSERT(size != 0);
+  HWY_DASSERT(logits.size() != 0);

   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -629,24 +629,25 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const V vmin = hn::Set(d, hwy::LowestValue<float>());
   V vmax = vmin;
   V* pmax = &vmax;  // workaround for SVE: cannot capture &vector directly
-  hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR {
-    *pmax = hn::Max(*pmax, value);
-  });
+  hn::Foreach(d, logits.data(), logits.size(), vmin,
+              [pmax](const auto d, const V value)
+                  HWY_ATTR { *pmax = hn::Max(*pmax, value); });
   vmax = hn::MaxOfLanes(d, vmax);

   // Subtract max (avoid precision loss for large exponents) and exponentiate.
-  hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR {
-    if constexpr (HWY_TARGET & HWY_ALL_SVE) {
-      // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
-      return hn::CallExp(d, hn::Sub(value, *pmax));
-    } else {
-      return hn::Exp(d, hn::Sub(value, *pmax));
-    }
-  });
+  hn::Transform(d, logits.data(), logits.size(),
+                [pmax](const auto d, const V value) HWY_ATTR {
+                  if constexpr (HWY_TARGET & HWY_ALL_SVE) {
+                    // Workaround for buggy SVE codegen: avoid inlined Exp().
+                    return hn::CallExp(d, hn::Sub(value, *pmax));
+                  } else {
+                    return hn::Exp(d, hn::Sub(value, *pmax));
+                  }
+                });

   if (temperature != 1.0f) {
     const float temperature_inv = 1.0f / temperature;
-    hn::Transform(d, x, size,
+    hn::Transform(d, logits.data(), logits.size(),
                   [temperature_inv](const auto d, const V value) HWY_ATTR {
                     return hn::Mul(value, hn::Set(d, temperature_inv));
                   });
@@ -656,10 +657,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   // not make a huge difference. It halves the standard deviation of the sum of
   // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
   // the generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, size);
+  const float sum_exp = Sum(d, logits.data(), logits.size());
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size, p, worker);
+  MulByConst(mul, logits.data(), logits.size(), p, worker);
 }

 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
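
Note: as a reference for the vector code above, here is a hedged scalar sketch (illustrative only, not part of this commit) of the same numerically stable softmax: subtract the maximum before exponentiating so the largest exponent becomes exp(0) == 1, then normalize by the sum.

  #include <algorithm>
  #include <cmath>
  #include <vector>

  // Scalar model of the Softmax above, without temperature or SIMD.
  void ScalarSoftmax(std::vector<float>& x) {
    float max = x[0];
    for (float v : x) max = std::max(max, v);
    double sum_exp = 0.0;
    for (float& v : x) {
      v = std::exp(v - max);  // in (0, 1]; avoids overflow for large logits
      sum_exp += v;
    }
    for (float& v : x) v = static_cast<float>(v / sum_exp);
  }
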
@@ -669,8 +670,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
 // which already knows the max value which top-1 sampling would again seek.

 // Returns the argmax and x[argmax].
-static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x,
-                                            const size_t num) {
+static HWY_INLINE TokenAndProb ArgmaxAndMax(Logits logits) {
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   using V = hn::Vec<D>;
@@ -680,16 +680,16 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x,
   using TI = hn::TFromD<decltype(di)>;
   using VI = hn::Vec<decltype(di)>;
   const size_t N = hn::Lanes(d);
-  HWY_ASSERT(num % (2 * N) == 0);
+  HWY_ASSERT(logits.size() % (2 * N) == 0);

   V max0 = hn::Set(d, hwy::LowestValue<float>());
   V max1 = max0;
   VI argmax0 = hn::Zero(di);
   VI argmax1 = argmax0;

-  for (size_t i = 0; i < num; i += 2 * N) {
-    const V v0 = hn::LoadU(d, x + i);
-    const V v1 = hn::LoadU(d, x + i + N);
+  for (size_t i = 0; i < logits.size(); i += 2 * N) {
+    const V v0 = hn::LoadU(d, &logits[i]);
+    const V v1 = hn::LoadU(d, &logits[i + N]);
     const VI vi0 = hn::Iota(di, static_cast<TI>(i));
     const VI vi1 = hn::Iota(di, static_cast<TI>(i + N));
     const M gt0 = hn::Gt(v0, max0);
@@ -714,43 +714,43 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x,
   return TokenAndProb{.token = argmax, .prob = hn::GetLane(max)};
 }

-// Returns argmax of softmax and its probability. This overwrites `x`, but not
-// with normalized probabilities. Only equivalent to `Softmax` + `sample_func`
-// if `kTopK` == 1. This is worthwhile because `num` is typically `kVocabSize`
-// == 256K, and this avoids writing and then scanning again for the max.
-// However, this is not enough to make parallelization worthwhile.
-static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
-                                                   const size_t num) {
+// Returns argmax of softmax and its probability. This overwrites `logits`, but
+// not with normalized probabilities. Only equivalent to `Softmax` +
+// `sample_func` if `kTopK` == 1. This is worthwhile because `logits.size()` is
+// typically `kVocabSize == 256K`, and this avoids writing and then scanning
+// again for the max.
+static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) {
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> d;
   using V = hn::Vec<decltype(d)>;

-  const TokenAndProb argmax = ArgmaxAndMax(x, num);
+  const TokenAndProb argmax = ArgmaxAndMax(logits);

   // Subtract max (avoid precision loss for large exponents) and exponentiate.
   const V max = hn::Set(d, argmax.prob);
   const V* pmax = &max;
-  hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR {
-    if constexpr (HWY_TARGET & HWY_ALL_SVE) {
-      // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
-      return hn::CallExp(d, hn::Sub(value, *pmax));
-    } else {
-      return hn::Exp(d, hn::Sub(value, *pmax));
-    }
-  });
+  hn::Transform(d, logits.data(), logits.size(),
+                [pmax](const auto d, const V value) HWY_ATTR {
+                  if constexpr (HWY_TARGET & HWY_ALL_SVE) {
+                    // Temporary workaround for buggy SVE codegen: avoid inlined
+                    // Exp().
+                    return hn::CallExp(d, hn::Sub(value, *pmax));
+                  } else {
+                    return hn::Exp(d, hn::Sub(value, *pmax));
+                  }
+                });

   // Normalize to a single probability. The exact sum seems like it should not
   // make a huge difference. It halves the standard deviation of the sum of the
   // normalized probabilities from 1E-7 to 5E-8, but actually also changes the
   // generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, num);
-  const float prob = x[argmax.token] / sum_exp;
+  const float sum_exp = Sum(d, logits.data(), logits.size());
+  const float prob = logits[argmax.token] / sum_exp;
   return TokenAndProb{.token = argmax.token, .prob = prob};
 }

-static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
-                                       const size_t size, hwy::Profiler& p,
-                                       const size_t worker) {
+static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
+                                       hwy::Profiler& p, const size_t worker) {
   static const auto zone = p.AddZone("Ops.LogitsSoftCap");
   PROFILER_ZONE3(p, worker, zone);
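
Note: for intuition about the fusion above, a hedged scalar sketch (names hypothetical, not this file's SIMD code). After subtracting the max, the argmax token's numerator is exp(0) == 1, so its softmax probability is simply 1 / sum_exp; no second scan over the 256K-entry vector is needed.

  #include <cmath>
  #include <cstddef>
  #include <vector>

  struct TokenProb { int token; float prob; };

  // Scalar model of Top1OfSoftmax: argmax first, then only that token's prob.
  TokenProb ScalarTop1OfSoftmax(const std::vector<float>& logits) {
    size_t argmax = 0;
    for (size_t i = 1; i < logits.size(); ++i) {
      if (logits[i] > logits[argmax]) argmax = i;
    }
    double sum_exp = 0.0;
    for (float x : logits) sum_exp += std::exp(x - logits[argmax]);
    // softmax(argmax) = exp(0) / sum_exp = 1 / sum_exp.
    return TokenProb{static_cast<int>(argmax),
                     static_cast<float>(1.0 / sum_exp)};
  }
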
@@ -763,18 +763,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
   const VF* HWY_RESTRICT pcap = &vcap;
   const VF* HWY_RESTRICT pinv_cap = &vinv_cap;

-  DecompressAndCompressInplace(
-      DF(), x, size, [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF {
-        return hn::Mul(*pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap)));
-      });
+  DecompressAndCompressInplace(DF(), logits.data(), logits.size(),
+                               [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF {
+                                 return hn::Mul(
+                                     *pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap)));
+                               });
 }

 // Calls LogitsSoftCap if cap != 0.0f.
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
-    const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
-    const size_t worker) {
+    const float cap, Logits logits, hwy::Profiler& p, const size_t worker) {
   if (cap != 0.0f) {
-    LogitsSoftCap(cap, x, size, p, worker);
+    LogitsSoftCap(cap, logits, p, worker);
   }
 }
@@ -785,20 +785,18 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
   ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
               [&](uint64_t task, size_t worker) {
                 if (non_eos.Get(task)) {
-                  LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
-                                worker);
+                  LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker);
                 }
               });
 }

-static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
-SampleArgmax(const float* probabilities, size_t vocab_size) {
+static HWY_NOINLINE HWY_MAYBE_UNUSED size_t SampleArgmax(Logits logits) {
   size_t max_index = 0;
-  float max_prob = probabilities[0];
-  for (size_t i = 1; i < vocab_size; ++i) {
-    if (probabilities[i] > max_prob) {
+  float max_prob = logits[0];
+  for (size_t i = 1; i < logits.size(); ++i) {
+    if (logits[i] > max_prob) {
       max_index = i;
-      max_prob = probabilities[i];
+      max_prob = logits[i];
     }
   }
   return max_index;
@@ -828,16 +826,15 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(

 template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
-    const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
-    TAcceptToken& accept_token) {
+    Logits logits, size_t k, TAcceptToken& accept_token) {
   HWY_ASSERT(k != 0);
-  HWY_ASSERT(k <= vocab_size);
+  HWY_ASSERT(k <= logits.size());
   std::vector<double> packed_token_probs;
-  for (int32_t i = 0; i < static_cast<int32_t>(vocab_size); ++i) {
-    if (accept_token && !accept_token(i, probabilities[i])) {
+  for (int32_t i = 0; i < static_cast<int32_t>(logits.size()); ++i) {
+    if (accept_token && !accept_token(i, logits[i])) {
       continue;
     }
-    packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
+    packed_token_probs.push_back(PackTokenAndProb(i, logits[i]));
   }

   hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
@@ -853,11 +850,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
 }

 template <typename TAcceptToken>
-HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
-    const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
-    std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
-  std::vector<TokenAndProb> token_probs =
-      TopK(probabilities, vocab_size, k, accept_token);
+HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k,
+                                             RngStream& gen, float temperature,
+                                             TAcceptToken& accept_token) {
+  std::vector<TokenAndProb> token_probs = TopK(logits, k, accept_token);
   std::vector<int> topk_indices(k);
   std::vector<float> topk_probs(k);
   for (size_t i = 0; i < k; ++i) {
@@ -869,14 +865,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(

 template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
-    const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
-    std::mt19937& gen, float temperature, TAcceptToken& accept_token,
-    hwy::Profiler& p, size_t worker) {
+    Logits logits, size_t k, RngStream& gen, float temperature,
+    TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) {
   // Softmax and sample top-K is equivalent to taking the top-K logits and
   // sampling from the softmax of the top-K logits. The latter is faster as it
   // avoids computing the softmax of all logits.
-  std::vector<TokenAndProb> token_logits =
-      TopK(logits, vocab_size, k, accept_token);
+  std::vector<TokenAndProb> token_logits = TopK(logits, k, accept_token);
   std::vector<int> topk_indices(k);
   std::vector<float> topk_logits(k);
   for (size_t i = 0; i < token_logits.size(); ++i) {
@@ -884,8 +878,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
     topk_logits[i] = token_logits[i].prob;
   }

-  size_t mask = token_logits.size();
-  Softmax(topk_logits.data(), mask, p, worker, temperature);
+  const size_t mask = token_logits.size();
+  Softmax(Logits(topk_logits.data(), mask), p, worker, temperature);
   auto distribution = std::discrete_distribution<int>(
       std::begin(topk_logits), std::begin(topk_logits) + mask);
   int topk_sampled_index = distribution(gen);
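
Note: the equivalence claimed in the comment of FusedSoftmaxAndSampleTopK holds because restricting the softmax to the K surviving logits and renormalizing drops the same global normalizer from every term. A hedged scalar sketch of the fused path (illustrative only; any UniformRandomBitGenerator such as RngStream works for Gen):

  #include <algorithm>
  #include <cmath>
  #include <random>
  #include <vector>

  // Softmax over just the top-K logits; the sampling distribution matches
  // renormalizing the full-vocabulary softmax over those same K tokens.
  template <class Gen>
  int SampleAmongTopK(std::vector<float> topk_logits, Gen& gen) {
    float max = topk_logits[0];
    for (float v : topk_logits) max = std::max(max, v);
    double sum = 0.0;
    for (float& v : topk_logits) sum += (v = std::exp(v - max));
    for (float& v : topk_logits) v = static_cast<float>(v / sum);
    std::discrete_distribution<int> dist(topk_logits.begin(),
                                         topk_logits.end());
    return dist(gen);  // position within the top-K list, not the vocabulary
  }
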
@@ -57,6 +57,12 @@ namespace HWY_NAMESPACE {

 namespace hn = hwy::HWY_NAMESPACE;

+static RngStream MakeRng() {
+  static AesCtrEngine engine(/*deterministic=*/true);
+  static uint64_t stream = 0;
+  return RngStream(engine, ++stream);
+}
+
 template <class Test>
 struct ForeachCountAndMisalign {
   template <typename T, class D>
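
Note: MakeRng hands each call the next stream of one process-wide deterministic engine, so separate tests get unrelated but reproducible sequences. A sketch of the property this provides (assumes the MakeRng helper above):

  // Successive MakeRng() calls use streams 1, 2, ... of the same engine.
  void CheckTestStreamsAreIndependent() {
    RngStream rng_a = MakeRng();
    RngStream rng_b = MakeRng();
    const uint64_t a0 = rng_a();
    const uint64_t a1 = rng_a();
    const uint64_t b0 = rng_b();
    const uint64_t b1 = rng_b();
    // Matching both 64-bit draws would be overwhelmingly unlikely.
    HWY_ASSERT(a0 != b0 || a1 != b1);
  }
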
@@ -304,7 +310,7 @@ class TestSoftmax {
     }

     SimpleSoftmax(e, count);
-    Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0);
+    Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);

     T sum = 0.0f;
     for (size_t i = 0; i < count; ++i) {
@@ -438,10 +444,9 @@ void TestRopeAndMulBy() {
   const size_t dim_qkv = config.layer_configs[0].qkv_dim;
   MatStorageT<float> x("x", dim_qkv, ctx.allocator);

-  std::mt19937 gen;
-  gen.seed(0x12345678);
+  RngStream rng = MakeRng();
   std::normal_distribution<float> r{0.0, 5.0};
-  auto random_float = [&r, &gen] { return r(gen); };
+  auto random_float = [&r, &rng] { return r(rng); };

   for (size_t i = 0; i < dim_qkv; ++i) {
     x.Row(0)[i] = random_float();
@@ -704,38 +709,34 @@ void TestSampleTopK() {
   hwy::Profiler& p = hwy::Profiler::Get();
   const size_t worker = 0;
   const size_t kSize = 52;
-  std::vector<float> logits(kSize);
+  std::vector<float> logits_vec(kSize);
+  Logits logits(logits_vec.data(), kSize);
   // Create a vector going from -100 to -100+51=49 and take Softmax.
   std::iota(logits.begin(), logits.end(), -100.0f);
-  Softmax(logits.data(), kSize, p, worker);
-  std::mt19937 gen;
-  gen.seed(0x12345678);
+  Softmax(logits, p, worker);
+  RngStream rng = MakeRng();
   float temperature = 1.0f;
   // SampleTopK<1> should return the argmax.
   std::function<bool(int, float)> accept_token;
-  int sample =
-      SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
+  int sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token);
   EXPECT_EQ(sample, 51);  // Last is largest.
   // Only accept even tokens, expect the last (largest) even index.
   accept_token = [](int i, float) { return i % 2 == 0; };
-  sample =
-      SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
+  sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token);
   EXPECT_EQ(sample, 50);  // Last even index.
   // Reset the logits to a positive, increasing sequence and take Softmax.
   std::iota(logits.begin(), logits.end(), 1.0f);
-  Softmax(logits.data(), kSize, p, worker);
+  Softmax(logits, p, worker);
   // Sample from the top 3, expect one of the top 3 even indices.
   for (int i = 0; i < 100; ++i) {
-    sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
-                        accept_token);
+    sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
     EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46);
   }
   // Now set the temperature to 0.0f, which should always return the argmax,
   // even for k=3.
   temperature = 0.0f;
   for (int i = 0; i < 100; ++i) {
-    sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
-                        accept_token);
+    sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
     EXPECT_EQ(sample, 50);
   }
 }
@@ -27,42 +27,38 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
     HWY_ASSERT(image.ReadPPM(path));
     const size_t image_size = config.vit_config.image_size;
     image.Resize(image_size, image_size);
-    RuntimeConfig runtime_config = {.gen = &env_->MutableGen(),
-                                    .verbosity = 0};
+    RuntimeConfig runtime_config = {.verbosity = 0};
     gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
                               image, *image_tokens_, env_->MutableEnv());
   }

 std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
   const Gemma& model = *(env_->GetGemma());
-  env_->MutableGen().seed(0x12345678);

   std::string response;
   auto stream_token = [&](int token, float) {
     std::string token_text;
-    HWY_ASSERT(
-        model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     response += token_text;
     return true;
   };

   std::string mutable_prompt = prompt_text;
   std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
   tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);

   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                   // PrefixLM sees/attends to all tokens.
                                   .prefill_tbatch_size = tokens.size(),
-                                  .gen = &env_->MutableGen(),
                                   .verbosity = 0,
                                   .stream_token = stream_token,
                                   .image_tokens = image_tokens_.get()};

   const size_t prefix_end = tokens.size();
   TimingInfo timing_info = {.verbosity = 0};
   model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
                  env_->MutableKVCache(), env_->MutableEnv(), timing_info);
   return response;
 }

 }  // namespace gcpp
@@ -53,9 +53,8 @@ class GemmaModel {
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
   void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
-                  size_t max_generated_tokens, float temperature, float seed,
-                  gcpp::AcceptFunc accept, bool skip_prompt) {
-    env_.MutableGen().seed(seed);
+                  size_t max_generated_tokens, float temperature,
+                  float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
     std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
@@ -77,7 +76,7 @@ class GemmaModel {

   // Generates a single example, given a prompt, and returns the result.
   std::string Generate(std::string prompt, size_t max_generated_tokens,
-                       float temperature, float seed,
+                       float temperature, float /*seed*/,
                        const std::vector<std::string>& accept,
                        const std::vector<std::string>& end) {
     std::set<int> end_token_set{};
@@ -124,7 +123,6 @@ class GemmaModel {
       }
     };

-    env_.MutableGen().seed(seed);
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
@@ -144,14 +142,13 @@ class GemmaModel {
   // results.
   std::vector<std::string> GenerateBatch(const std::vector<std::string>& inputs,
                                          size_t max_generated_tokens,
-                                         float temperature, float seed,
+                                         float temperature, float /*seed*/,
                                          size_t top_k) {
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.top_k = top_k;
    config.verbosity = 0;
-    env_.MutableGen().seed(seed);

     std::vector<gcpp::QueryResult> outputs = env_.BatchQueryModel(inputs);
     std::vector<std::string> result;
@@ -187,8 +184,7 @@ class GemmaModel {
         "image_tokens",
         gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
         env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd));
-    gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
-                                          .verbosity = 0};
+    gcpp::RuntimeConfig runtime_config = {.verbosity = 0};
     gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
                               c_image, *image_tokens_, env_.MutableEnv());
   }
@@ -197,10 +193,9 @@ class GemmaModel {
   // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
   std::pair<std::string, std::vector<int>> GenerateWithImage(
       std::string prompt, size_t max_generated_tokens, float temperature,
-      float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
+      float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
     if (!image_tokens_) throw std::invalid_argument("No image set.");
     const gcpp::Gemma& model = *env_.GetGemma();
-    env_.MutableGen().seed(seed);
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
@@ -273,6 +268,7 @@ PYBIND11_MODULE(gemma, mod) {
            }),
            py::arg("tokenizer_path"), py::arg("weights_path"),
            py::arg("max_threads") = 0)
+      // seed arguments are ignored.
       .def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"),
            py::arg("stream"), py::arg("max_generated_tokens") = 1024,
           py::arg("temperature") = 0.9, py::arg("seed") = 123456789,
@@ -24,7 +24,7 @@

 namespace gcpp {

-RNG::RNG(bool deterministic) {
+AesCtrEngine::AesCtrEngine(bool deterministic) {
   // Pi-based nothing up my sleeve numbers from Randen.
   key_[0] = 0x243F6A8885A308D3ull;
   key_[1] = 0x13198A2E03707344ull;
@@ -54,9 +54,10 @@ static V Load(const uint64_t* ptr) {
   return hn::Load(D(), reinterpret_cast<const uint8_t*>(ptr));
 }

-RNG::result_type RNG::operator()() {
-  V state = Load(counter_);
-  counter_[0]++;
+uint64_t AesCtrEngine::operator()(uint64_t stream, uint64_t counter) const {
+  const hn::Repartition<uint64_t, D> d64;
+
+  V state = hn::BitCast(D(), hn::Dup128VecFromValues(d64, counter, stream));
   state = hn::Xor(state, Load(key_));  // initial whitening

   static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t));
@@ -68,7 +69,6 @@ RNG::result_type RNG::operator()() {
   state = hn::AESRound(state, Load(key_ + 10));

   // Return lower 64 bits of the u8 vector.
-  const hn::Repartition<uint64_t, D> d64;
   return hn::GetLane(hn::BitCast(d64, state));
 }
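
Note: because the counter and stream now arrive as arguments rather than being mutated members, the engine is a pure function of (stream, counter): replaying the same inputs reproduces the same value, which is what makes concurrent use safe. A minimal sketch under this commit's headers (the function name is hypothetical):

  #include <cstdint>
  #include "hwy/base.h"     // HWY_ASSERT
  #include "util/basics.h"  // AesCtrEngine

  void CheckEngineIsPure() {
    const gcpp::AesCtrEngine engine(/*deterministic=*/true);
    const uint64_t a = engine(/*stream=*/7, /*counter=*/0);
    const uint64_t b = engine(/*stream=*/7, /*counter=*/0);
    HWY_ASSERT(a == b);  // same inputs, same output; no shared mutable state
  }
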
@@ -20,7 +20,7 @@
 #include <stddef.h>
 #include <stdint.h>

-#include "hwy/aligned_allocator.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
 // IWYU pragma: end_exports
@@ -120,39 +120,60 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
   return IndexRange(begin, HWY_MIN(begin + max_size, end));
 }

+using Logits = hwy::Span<float>;  // size() is vocab_size.
+
 // Non-cryptographic 64-bit pseudo-random number generator. Supports random or
-// deterministic seeding. Conforms to C++ `UniformRandomBitGenerator`.
+// deterministic seeding.
 //
 // Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This
 // is useful for parallel sampling. Each thread can generate the stream for a
 // particular task, without caring about prior/subsequent generations.
-class alignas(16) RNG {
+class alignas(16) AesCtrEngine {
   // "Large-scale randomness study of security margins for 100+ cryptographic
   // functions": at least four.
   // "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant.
   static constexpr size_t kRounds = 5;

  public:
-  explicit RNG(bool deterministic);
-
-  void SetStream(uint64_t stream) {
-    counter_[1] = stream;
-    counter_[0] = 0;
-  }
+  // If `deterministic` is true, uses a fixed seed; otherwise, attempts to
+  // grab entropy from the OS.
+  explicit AesCtrEngine(bool deterministic);
+
+  // Pure and thread safe; typically called via `RngStream`, which increments
+  // `counter`. Throughput is about 100M/s on 3 GHz Skylake. It could be
+  // increased 4x via unrolling by the AES latency (4-7 cycles), but because
+  // users generally call once at a time, this requires buffering, which is not
+  // worth the complexity in this application.
+  uint64_t operator()(uint64_t stream, uint64_t counter) const;
+
+ private:
+  uint64_t key_[2 * (1 + kRounds)];
+};
+
+// Flyweight per-thread adapter that maintains the counter. Conforms to C++
+// `UniformRandomBitGenerator`.
+class RngStream {
+ public:
+  RngStream() = default;  // Allow C arrays with subsequent initialization.
+
+  // Binds to an engine, which holds the seed and must outlive this object.
+  // Sets the stream; any other `RngStream` with the same `counter_rng` and
+  // `stream` will return the same sequence. This is typically the task ID, so
+  // that threads can independently generate values for each task.
+  RngStream(const AesCtrEngine& counter_rng, uint64_t stream)
+      : engine_(&counter_rng), stream_(stream), counter_(0) {}

   using result_type = uint64_t;
   static constexpr result_type min() { return 0; }
   static constexpr result_type max() { return ~result_type{0}; }
-
-  // About 100M/s on 3 GHz Skylake. Throughput could be increased 4x via
-  // unrolling by the AES latency (4-7 cycles). `std::discrete_distribution`
-  // makes individual calls to the generator, which would require buffering,
-  // which is not worth the complexity.
-  result_type operator()();
+  result_type operator()() { return (*engine_)(stream_, counter_++); }

  private:
-  uint64_t counter_[2] = {};
-  uint64_t key_[2 * (1 + kRounds)];
+  const AesCtrEngine* engine_ = nullptr;
+  uint64_t stream_ = 0;  // immutable after ctor
+  uint64_t counter_ = 0;
+  // Prevent false sharing if used by multiple threads.
+  HWY_MAYBE_UNUSED uint8_t padding_[HWY_ALIGNMENT - 16 - sizeof(engine_)];
 };

 }  // namespace gcpp
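
Note: the split into an immutable AesCtrEngine and per-thread RngStream is what enables the parallel sampling this commit targets: one shared const engine, one cheap stream object per task, no locking, and results independent of scheduling order. A hedged sketch of the intended pattern (the loop stands in for a thread pool such as pool_.Run; function names are hypothetical):

  #include <cstdint>
  #include "util/basics.h"  // AesCtrEngine, RngStream

  void SampleAllQueries(const gcpp::AesCtrEngine& engine, size_t num_queries) {
    for (uint64_t query = 0; query < num_queries; ++query) {
      // Each task derives its own stream from its ID; iterations could run
      // concurrently because the engine is never mutated.
      gcpp::RngStream rng(engine, /*stream=*/query);
      const uint64_t r = rng();  // feed e.g. std::discrete_distribution
      (void)r;
    }
  }
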
@@ -25,9 +25,11 @@
 namespace gcpp {
 namespace {

-TEST(BasicsTest, IsDeterministic) {
-  RNG rng1(/*deterministic=*/true);
-  RNG rng2(/*deterministic=*/true);
+TEST(BasicsTest, EngineIsDeterministic) {
+  const AesCtrEngine engine1(/*deterministic=*/true);
+  const AesCtrEngine engine2(/*deterministic=*/true);
+  RngStream rng1(engine1, 0);
+  RngStream rng2(engine2, 0);
   // Remember for later testing after resetting the stream.
   const uint64_t r0 = rng1();
   const uint64_t r1 = rng1();
@@ -42,15 +44,17 @@ TEST(BasicsTest, IsDeterministic) {
     HWY_ASSERT(rng1() == rng2());
   }

-  // Reset counter, ensure it matches the default-constructed RNG.
-  rng1.SetStream(0);
+  // Reset counter, ensure it matches the prior sequence.
+  rng1 = RngStream(engine1, 0);
   HWY_ASSERT(r0 == rng1());
   HWY_ASSERT(r1 == rng1());
 }

-TEST(BasicsTest, IsSeeded) {
-  RNG rng1(/*deterministic=*/true);
-  RNG rng2(/*deterministic=*/false);
+TEST(BasicsTest, EngineIsSeeded) {
+  AesCtrEngine engine1(/*deterministic=*/true);
+  AesCtrEngine engine2(/*deterministic=*/false);
+  RngStream rng1(engine1, 0);
+  RngStream rng2(engine2, 0);
   // It would be very unlucky to have even one 64-bit value match, and two are
   // extremely unlikely.
   const uint64_t a0 = rng1();
@@ -60,9 +64,27 @@ TEST(BasicsTest, IsSeeded) {
   HWY_ASSERT(a0 != b0 || a1 != b1);
 }

+TEST(BasicsTest, StreamsDiffer) {
+  AesCtrEngine engine(/*deterministic=*/true);
+  // Compare random streams for more coverage than just the first N streams.
+  RngStream rng_for_stream(engine, 0);
+  for (size_t i = 0; i < 1000; ++i) {
+    RngStream rng1(engine, rng_for_stream());
+    RngStream rng2(engine, rng_for_stream());
+    // It would be very unlucky to have even one 64-bit value match, and two
+    // are extremely unlikely.
+    const uint64_t a0 = rng1();
+    const uint64_t a1 = rng1();
+    const uint64_t b0 = rng2();
+    const uint64_t b1 = rng2();
+    HWY_ASSERT(a0 != b0 || a1 != b1);
+  }
+}
+
 // If not close to 50% 1-bits, the RNG is quite broken.
 TEST(BasicsTest, BitDistribution) {
-  RNG rng(/*deterministic=*/true);
+  AesCtrEngine engine(/*deterministic=*/true);
+  RngStream rng(engine, 0);
   constexpr size_t kU64 = 2 * 1000 * 1000;
   const hwy::Timestamp t0;
   uint64_t one_bits = 0;
@@ -78,7 +100,8 @@ TEST(BasicsTest, BitDistribution) {
 }

 TEST(BasicsTest, ChiSquared) {
-  RNG rng(/*deterministic=*/true);
+  AesCtrEngine engine(/*deterministic=*/true);
+  RngStream rng(engine, 0);
   constexpr size_t kU64 = 1 * 1000 * 1000;

   // Test each byte separately.
@@ -301,6 +301,13 @@ class MatPtrT : public MatPtr {
     return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
   }

+  hwy::Span<MatT> RowSpan(size_t row) {
+    return hwy::Span<MatT>(Row(row), Cols());
+  }
+  hwy::Span<const MatT> RowSpan(size_t row) const {
+    return hwy::Span<const MatT>(Row(row), Cols());
+  }
+
   PackedSpan<const MatT> PaddedSpan() const {
     return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
   }
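
Note: RowSpan pairs a row pointer with Cols(), and for float matrices the result is the same type as Logits (hwy::Span<float>), so per-row operations such as Softmax and LogitsSoftCap above can take it directly without a separate size argument. A small sketch assuming this commit's headers (function name hypothetical):

  #include <cstddef>
  #include "util/basics.h"  // Logits
  #include "util/mat.h"     // MatPtrT

  // View one row of a [batch x vocab] logits matrix as Logits, no copy.
  gcpp::Logits RowAsLogits(gcpp::MatPtrT<float>& logits_batch, size_t row) {
    return logits_batch.RowSpan(row);  // hwy::Span<float>{Row(row), Cols()}
  }
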