mirror of https://github.com/google/gemma.cpp.git
Replace mt19937 with new generator to enable parallel sampling
Split it into immutable AesCtrEngine and RngStream.
Also add RowSpan and Logits span.

PiperOrigin-RevId: 803336423

This commit is contained in:
parent 5d1693e806
commit 56186193c1
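For orientation, a minimal sketch of how the two new pieces fit together. The constructor arguments mirror the call sites in the diff below (e.g. AesCtrEngine(deterministic), RngStream(engine, stream)); the function name and loop are illustrative only, not part of this commit:

  namespace gcpp {
  // The engine is immutable, so one instance can be shared across threads.
  // Each query derives its own RngStream keyed by a stream id, so sampling
  // no longer serializes on a single mutable std::mt19937.
  void SketchParallelSampling(const AesCtrEngine& engine, size_t num_queries) {
    for (size_t qi = 0; qi < num_queries; ++qi) {
      RngStream rng(engine, qi);  // independent, reproducible stream per query
      (void)rng();  // usable as a standard uniform random bit generator,
                    // as with std::discrete_distribution in ops-inl.h below
    }
  }
  }  // namespace gcpp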
@@ -24,7 +24,6 @@
 #include <algorithm>  // std::shuffle
 #include <array>
-#include <random>
 
 #include "compression/distortion.h"
 #include "util/test_util.h"
@@ -104,8 +103,8 @@ struct TestPlateaus {
     HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
   }
 
-  std::random_device rd;  // NOLINT
-  std::mt19937 rng(rd());
+  AesCtrEngine engine(/*deterministic=*/true);
+  RngStream rng(engine, 0);
   std::shuffle(in.get(), in.get() + kGroupSize, rng);
 
   NuqStream::ClusterBuf buf;
@@ -151,8 +150,8 @@ struct TestRamp {
     HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
   }
 
-  std::random_device rd;  // NOLINT
-  std::mt19937 rng(rd());
+  AesCtrEngine engine(/*deterministic=*/true);
+  RngStream rng(engine, 0);
   std::shuffle(in.get(), in.get() + kGroupSize, rng);
 
   NuqStream::ClusterBuf buf;
@@ -20,7 +20,6 @@
 
 #include <iostream>
 #include <ostream>
-#include <random>
 #include <string>
 #include <vector>
@@ -37,17 +36,6 @@
 namespace gcpp {
 
-void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
-  if (inference.deterministic) {
-    // Nothing up my sleeve number, at least some upper bits set.
-    gen.seed(0x12345678);
-  } else {
-    // Depending on the library implementation, this may still be deterministic.
-    std::random_device rd;  // NOLINT
-    gen.seed(rd());
-  }
-}
-
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
     : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
@@ -60,12 +48,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                   ctx_);
   }
 
-  InitGenerator(inference, gen_);
-
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
-      .gen = &gen_,
       .verbosity = inference.verbosity,
   };
   inference.CopyTo(runtime_config_);
@@ -18,7 +18,6 @@
 #include <stddef.h>
 
-#include <random>
 #include <string>
 #include <vector>
@@ -32,8 +31,6 @@
 namespace gcpp {
 
-void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
-
 // Return type for query model calls.
 struct QueryResult {
   std::string response;
@@ -107,7 +104,6 @@ class GemmaEnv {
   int Verbosity() const { return runtime_config_.verbosity; }
   RuntimeConfig& MutableConfig() { return runtime_config_; }
-  std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return kv_caches_[0]; }
   MatMulEnv& MutableEnv() { return env_; }
@@ -115,7 +111,6 @@ class GemmaEnv {
   ThreadingContext ctx_;
   MatMulEnv env_;
   Gemma gemma_;
-  std::mt19937 gen_;  // Random number generator.
   std::vector<KVCache> kv_caches_;  // Same number as query batch.
   RuntimeConfig runtime_config_;
 };
@@ -56,11 +56,10 @@ static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
   return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
 }
 
-void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
-             size_t k) {
-  std::vector<std::pair<float, int>> sorted(len);
-  for (size_t i = 0; i < len; ++i) {
-    sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
+void LogTopK(const GemmaTokenizer& tokenizer, Logits logits, size_t k) {
+  std::vector<std::pair<float, int>> sorted(logits.size());
+  for (size_t i = 0; i < logits.size(); ++i) {
+    sorted[i] = std::make_pair(logits[i], static_cast<int>(i));
   }
   std::sort(sorted.begin(), sorted.end(),
             [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
@@ -84,9 +83,8 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size,
-                 hwy::Profiler& p) {
-  Softmax(logits, vocab_size, p, hwy::Profiler::Thread());
+void CallSoftmax(Logits logits, hwy::Profiler& p) {
+  Softmax(logits, p, hwy::Profiler::Thread());
 }
 
 }  // namespace HWY_NAMESPACE
@@ -107,19 +105,19 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
   float cross_entropy = std::log(vocab_size);  // first token; == -log(1/v_s)
   size_t pos = 1;
 
-  const SampleFunc sample_token = [&](float* probs,
-                                      size_t vocab_size) -> TokenAndProb {
+  const SampleFunc sample_token = [&](size_t qi,
+                                      Logits logits) -> TokenAndProb {
     // input is logits, not yet probabilities
-    HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler);
+    HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
     // We are called for each token, but pos starts at 1. Clamping
     // max_generated_tokens to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
     const int token = prompt[pos];
-    const float prob = probs[token];
+    const float prob = logits[token];
     cross_entropy -= std::max(std::log(prob), -64.0f);
 
     if (verbosity >= 4) {
-      LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
+      LogTopK(gemma.Tokenizer(), logits, 10);
     }
     if (verbosity >= 3) {
       printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
@@ -139,7 +137,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
   RuntimeConfig runtime = {
       .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
-      .gen = nullptr,
      .verbosity = verbosity,
      .stream_token = stream_token,
      .sample_func = sample_token,
@@ -115,7 +115,6 @@ TEST_F(GemmaTest, Multiturn) {
   RuntimeConfig runtime_config{
       .max_generated_tokens = 64,
       .temperature = 0.0f,
-      .gen = &s_env->MutableGen(),
      .verbosity = 2,
      .batch_stream_token = stream_token,
   };
@@ -126,7 +126,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 30,
       .temperature = 0.0f,
-      .gen = &env.MutableGen(),
      .verbosity = env.Verbosity(),
      .stream_token = stream_token,
   };
@@ -17,8 +17,8 @@
 #include <cstdlib>
 #include <cstring>
+#include <functional>
 #include <iostream>
-#include <random>
 #include <set>
 #include <string>
 #include <vector>
@@ -44,7 +44,7 @@ int main(int argc, char** argv) {
   for (int arg = 0; arg < argc; ++arg) {
     // Find a --reject flag and consume everything after it.
     if (strcmp(argv[arg], "--reject") == 0) {
-      while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
+      while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));  // NOLINT
     }
   }
 
@@ -55,11 +55,6 @@ int main(int argc, char** argv) {
   gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
   size_t generated = 0;
 
-  // Initialize random number generator
-  std::mt19937 gen;
-  std::random_device rd;  // NOLINT
-  gen.seed(rd());
-
   // Tokenize instructions.
   std::string prompt = "Write a greeting to the world.";
   const std::vector<int> tokens =
@@ -84,7 +79,6 @@ int main(int argc, char** argv) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
-      .gen = &gen,
      .verbosity = 0,
      .stream_token = stream_token,
      .accept_token =
@@ -18,7 +18,6 @@
 #include <cstdlib>
 #include <cstring>
 #include <iostream>
-#include <random>
 #include <set>
 #include <string>
 #include <vector>
@@ -38,11 +37,7 @@ class SimplifiedGemma {
       : ctx_(threading),
         env_(ctx_),
         gemma_(loader, inference, ctx_),
-        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
-    // Initialize random number generator
-    std::random_device rd;
-    gen_.seed(rd());
-  }
+        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {}
 
   SimplifiedGemma(int argc, char** argv)
       : SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
@@ -76,7 +71,6 @@ class SimplifiedGemma {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = max_generated_tokens,
       .temperature = temperature,
-      .gen = &gen_,
      .verbosity = 0,
      .stream_token = stream_token,
      .accept_token =
@@ -93,6 +87,5 @@ class SimplifiedGemma {
   gcpp::MatMulEnv env_;
   gcpp::Gemma gemma_;
   gcpp::KVCache kv_cache_;
-  std::mt19937 gen_;
   std::string validation_error_;
 };
@@ -60,18 +60,18 @@ struct ServerState {
   std::unique_ptr<Gemma> gemma;
   MatMulEnv* env;
   ThreadingContext* ctx;
 
   // Session-based KV cache storage
   struct Session {
     std::unique_ptr<KVCache> kv_cache;
     size_t abs_pos = 0;
     std::chrono::steady_clock::time_point last_access;
   };
 
   std::unordered_map<std::string, Session> sessions;
   std::mutex sessions_mutex;
   std::mutex inference_mutex;
 
   // Cleanup old sessions after 30 minutes of inactivity
   void CleanupOldSessions() {
     std::lock_guard<std::mutex> lock(sessions_mutex);
@@ -84,7 +84,7 @@ struct ServerState {
       }
     }
   }
 
   // Get or create session with KV cache
   Session& GetOrCreateSession(const std::string& session_id) {
     std::lock_guard<std::mutex> lock(sessions_mutex);
@@ -101,24 +101,25 @@ struct ServerState {
 std::string GenerateSessionId() {
   static std::atomic<uint64_t> counter{0};
   std::stringstream ss;
-  ss << "session_" << std::hex << std::chrono::steady_clock::now().time_since_epoch().count()
-     << "_" << counter.fetch_add(1);
+  ss << "session_" << std::hex
+     << std::chrono::steady_clock::now().time_since_epoch().count() << "_"
+     << counter.fetch_add(1);
   return ss.str();
 }
 
 // Wraps messages with start_of_turn markers - handles both with and without roles
 std::string WrapMessagesWithTurnMarkers(const json& contents) {
   std::string prompt;
 
   for (const auto& content : contents) {
     if (content.contains("parts")) {
       // Check if role is specified (public API format) or not (local format)
       std::string role = content.value("role", "");
 
       for (const auto& part : content["parts"]) {
         if (part.contains("text")) {
           std::string text = part["text"];
 
           if (role == "user") {
             prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
           } else if (role == "model") {
@@ -131,24 +132,23 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) {
       }
     }
   }
 
   return prompt;
 }
 
 // Parse generation config
-RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) {
+RuntimeConfig ParseGenerationConfig(const json& request) {
   RuntimeConfig config;
-  config.gen = &gen;
   config.verbosity = 0;
 
   // Set defaults matching public API
   config.temperature = 1.0f;
   config.top_k = 1;
   config.max_generated_tokens = 8192;
 
   if (request.contains("generationConfig")) {
     auto& gen_config = request["generationConfig"];
 
     if (gen_config.contains("temperature")) {
       config.temperature = gen_config["temperature"].get<float>();
     }
@@ -159,7 +159,7 @@ RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) {
     config.max_generated_tokens = gen_config["maxOutputTokens"].get<size_t>();
   }
 }
 
 return config;
 }
@@ -175,12 +175,12 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
     }}},
     {"promptFeedback", {{"safetyRatings", json::array()}}}
   };
 
   // Only add finishReason for non-streaming chunks
   if (!is_streaming_chunk) {
     response["candidates"][0]["finishReason"] = "STOP";
   }
 
   return response;
 }
@@ -188,11 +188,11 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false)
 void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
   try {
     json request = json::parse(req.body);
 
     // Get or create session
     std::string session_id = request.value("sessionId", GenerateSessionId());
     auto& session = state.GetOrCreateSession(session_id);
 
     // Extract prompt from API format
     std::string prompt;
     if (request.contains("contents")) {
@@ -202,32 +202,29 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
     res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
     return;
   }
 
   // Lock for inference
   std::lock_guard<std::mutex> lock(state.inference_mutex);
 
   // Set up runtime config
-  std::mt19937 gen;
-  RuntimeConfig runtime_config = ParseGenerationConfig(request, gen);
+  RuntimeConfig runtime_config = ParseGenerationConfig(request);
 
   // Collect full response
   std::string full_response;
   runtime_config.stream_token = [&full_response](int token, float) {
     // Skip EOS token
     return true;
   };
 
   // Tokenize prompt
-  std::vector<int> tokens = WrapAndTokenize(state.gemma->Tokenizer(),
-                                            state.gemma->ChatTemplate(),
-                                            state.gemma->Config().wrapping,
-                                            session.abs_pos,
-                                            prompt);
+  std::vector<int> tokens = WrapAndTokenize(
+      state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
+      state.gemma->Config().wrapping, session.abs_pos, prompt);
 
   // Run inference with KV cache
   TimingInfo timing_info = {.verbosity = 0};
   size_t prefix_end = 0;
 
   // Temporarily redirect output to capture response
   std::stringstream output;
   runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
@@ -236,25 +233,25 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
       session.abs_pos++;
       return true;
     }
 
     session.abs_pos++;
 
     // Check for EOS
     if (state.gemma->Config().IsEOS(token)) {
       return true;
     }
 
     // Decode token
     std::string token_text;
     state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
     output << token_text;
 
     return true;
   };
 
   state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
                         *session.kv_cache, *state.env, timing_info);
 
   // Create response
   json response = CreateAPIResponse(output.str(), false);
   response["usageMetadata"] = {
@@ -262,17 +259,22 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
       {"candidatesTokenCount", session.abs_pos - tokens.size()},
       {"totalTokenCount", session.abs_pos}
   };
 
   res.set_content(response.dump(), "application/json");
 
 } catch (const json::exception& e) {
   res.status = 400;
-  res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(),
-                  "application/json");
+  res.set_content(
+      json{{"error",
+            {{"message", std::string("JSON parsing error: ") + e.what()}}}}
+          .dump(),
+      "application/json");
 } catch (const std::exception& e) {
   res.status = 500;
-  res.set_content(json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}.dump(),
-                  "application/json");
+  res.set_content(
+      json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}
+          .dump(),
+      "application/json");
 }
 }
@@ -280,11 +282,11 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques
 void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) {
   try {
     json request = json::parse(req.body);
 
     // Get or create session
     std::string session_id = request.value("sessionId", GenerateSessionId());
     auto& session = state.GetOrCreateSession(session_id);
 
     // Extract prompt from API format
     std::string prompt;
     if (request.contains("contents")) {
@@ -294,13 +296,13 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
     res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json");
     return;
   }
 
   // Set up SSE headers
   res.set_header("Content-Type", "text/event-stream");
   res.set_header("Cache-Control", "no-cache");
   res.set_header("Connection", "keep-alive");
   res.set_header("X-Session-Id", session_id);
 
   // Set up chunked content provider for SSE
   res.set_chunked_content_provider(
       "text/event-stream",
@@ -309,18 +311,15 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
       // Lock for inference
       std::lock_guard<std::mutex> lock(state.inference_mutex);
       auto& session = state.GetOrCreateSession(session_id);
 
       // Set up runtime config
-      std::mt19937 gen;
-      RuntimeConfig runtime_config = ParseGenerationConfig(request, gen);
+      RuntimeConfig runtime_config = ParseGenerationConfig(request);
 
       // Tokenize prompt
-      std::vector<int> tokens = WrapAndTokenize(state.gemma->Tokenizer(),
-                                                state.gemma->ChatTemplate(),
-                                                state.gemma->Config().wrapping,
-                                                session.abs_pos,
-                                                prompt);
+      std::vector<int> tokens = WrapAndTokenize(
+          state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
+          state.gemma->Config().wrapping, session.abs_pos, prompt);
 
       // Stream token callback
       std::string accumulated_text;
       auto stream_token = [&](int token, float) {
@@ -329,37 +328,38 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
         session.abs_pos++;
         return true;
       }
 
       session.abs_pos++;
 
       // Check for EOS
       if (state.gemma->Config().IsEOS(token)) {
         return true;
       }
 
       // Decode token
       std::string token_text;
       state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
       accumulated_text += token_text;
 
       // Send SSE event using unified formatter
       json event = CreateAPIResponse(token_text, true);
 
       std::string sse_data = "data: " + event.dump() + "\n\n";
       sink.write(sse_data.data(), sse_data.size());
 
       return true;
     };
 
     runtime_config.stream_token = stream_token;
 
     // Run inference with KV cache
     TimingInfo timing_info = {.verbosity = 0};
     size_t prefix_end = 0;
 
-    state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
-                          *session.kv_cache, *state.env, timing_info);
+    state.gemma->Generate(runtime_config, tokens, session.abs_pos,
+                          prefix_end, *session.kv_cache, *state.env,
+                          timing_info);
 
     // Send final event using unified formatter
     json final_event = CreateAPIResponse("", false);
     final_event["usageMetadata"] = {
@@ -367,18 +367,18 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
       {"candidatesTokenCount", session.abs_pos - tokens.size()},
       {"totalTokenCount", session.abs_pos}
   };
 
   std::string final_sse = "data: " + final_event.dump() + "\n\n";
   sink.write(final_sse.data(), final_sse.size());
 
   // Send done event
   sink.write("data: [DONE]\n\n", 15);
 
   // Ensure all data is sent
   sink.done();
 
   return false;  // End streaming
 
 } catch (const std::exception& e) {
   json error_event = {{"error", {{"message", e.what()}}}};
   std::string error_sse = "data: " + error_event.dump() + "\n\n";
@@ -387,11 +387,14 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
     }
   }
 );
 
 } catch (const json::exception& e) {
   res.status = 400;
-  res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(),
-                  "application/json");
+  res.set_content(
+      json{{"error",
+            {{"message", std::string("JSON parsing error: ") + e.what()}}}}
+          .dump(),
+      "application/json");
 }
 }
@@ -410,7 +413,7 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
     {"topK", 1}
   }}}
 };
 
 res.set_content(response.dump(), "application/json");
 }
@@ -419,40 +422,40 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const
 // server_running = false;
 // }
 
 void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
                const InferenceArgs& inference) {
   std::cerr << "Loading model..." << std::endl;
 
   // Initialize model
   ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
 
   ServerState state;
   state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
   state.env = &env;
   state.ctx = &ctx;
 
   httplib::Server server;
 
   // Set up routes
   server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
     res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain");
   });
 
   // API endpoints
   server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) {
     HandleListModels(state, inference, req, res);
   });
 
   std::string model_endpoint = "/v1beta/models/" + inference.model;
   server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) {
     HandleGenerateContentNonStreaming(state, req, res);
   });
 
   server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) {
     HandleGenerateContentStreaming(state, req, res);
   });
 
   // Periodic cleanup of old sessions
   std::thread cleanup_thread([&state]() {
     while (server_running) {
@@ -460,18 +463,18 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
     state.CleanupOldSessions();
   }
 });
 
 std::cerr << "Starting API server on port " << inference.port << std::endl;
 std::cerr << "Model loaded successfully" << std::endl;
 std::cerr << "Endpoints:" << std::endl;
 std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
 std::cerr << "  POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl;
 std::cerr << "  GET  /v1beta/models" << std::endl;
 
 if (!server.listen("0.0.0.0", inference.port)) {
   std::cerr << "Failed to start server on port " << inference.port << std::endl;
 }
 
 cleanup_thread.join();
 }
@@ -479,11 +482,11 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
 
 int main(int argc, char** argv) {
   gcpp::InternalInit();
 
   gcpp::LoaderArgs loader(argc, argv);
   gcpp::ThreadingArgs threading(argc, argv);
   gcpp::InferenceArgs inference(argc, argv);
 
   if (gcpp::HasHelp(argc, argv)) {
     std::cerr << "\n\nAPI server for gemma.cpp\n";
     std::cout << "========================\n\n";
@@ -501,14 +504,14 @@ int main(int argc, char** argv) {
   std::cerr << "\n";
   return 0;
 }
 
 // Arguments are now handled by InferenceArgs
 
 // // Set up signal handler
 // signal(SIGINT, gcpp::HandleShutdown);
 // signal(SIGTERM, gcpp::HandleShutdown);
 
 gcpp::RunServer(loader, threading, inference);
 
 return 0;
 }
@@ -155,8 +155,9 @@ void SingleDotSoftmaxWeightedSum(
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  MaybeLogitsSoftCap(att_cap, att, att_len, p, worker);
-  Softmax(att, att_len, p, worker, /*temperature=*/1.0f);
+  const Logits logits(att, att_len);
+  MaybeLogitsSoftCap(att_cap, logits, p, worker);
+  Softmax(logits, p, worker, /*temperature=*/1.0f);
 
   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
                worker);
@@ -23,7 +23,6 @@
 #include <string>
 #include <vector>
 
-#include "evals/benchmark_helper.h"  // InitGenerator
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "gemma/tokenizer.h"  // WrapAndTokenize
@@ -135,8 +134,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   std::stringstream ss;
   result_buffer.clear();
 
-  InitGenerator(inference_args, gen);
-
   // Ensure we have an active conversation
   if (!active_conversation || !active_conversation->kv_cache) {
     LogDebug("Generate called with null active_conversation or kv_cache");
@@ -174,8 +171,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
 
   // set up runtime config
   TimingInfo timing_info = {};
-  RuntimeConfig runtime_config = {.gen = &gen,
-                                  .stream_token = stream_token,
+  RuntimeConfig runtime_config = {.stream_token = stream_token,
                                   .use_spinning = threading_args.spin};
   inference_args.CopyTo(runtime_config);
   size_t prefix_end = 0;
@@ -256,7 +252,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
     // If not multiturn, or Paligemma (which handles turns differently),
     // reset the *active* conversation's position.
     active_conversation->abs_pos = 0;
-    InitGenerator(inference_args, gen);
   } else {
     // Multi-turn Gemma: Rewind position in the active conversation
     // The last token was either EOS, then it should be ignored because it is
@@ -17,7 +17,6 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
 
 #include <memory>  // For std::shared_ptr, std::make_shared
-#include <random>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -107,10 +106,6 @@ class GemmaContext {
   // Set deterministic flag
   void SetDeterministic(bool value) {
     inference_args.deterministic = value;
-    // Reset the random number generator for deterministic generation
-    if (value) {
-      gen.seed(0x87654321);
-    }
     LogDebug("Setting deterministic flag to configured value");
   }
 
@@ -289,9 +284,6 @@ class GemmaContext {
   // Model itself (don't move this, needs to be below the args above)
   Gemma model;
 
-  // Random generator (remains global for the context)
-  std::mt19937 gen;
-
   // Static members for logging
   static GemmaLogCallback s_log_callback;
   static void* s_log_user_data;
@@ -440,8 +440,7 @@ static void SampleAndStream(
 
   // TODO: parallelize
   non_eos.Foreach([&](size_t qi) {
-    float* HWY_RESTRICT logits = activations.logits.Row(qi);
-    const TokenAndProb tp = sample_token(logits, config.vocab_size);
+    const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi));
 
     // We streamed all prefill tokens, but pos is still one behind because we
     // started generation at pos = prompt.size() - 1. We want the pos argument
@@ -453,7 +452,8 @@ static void SampleAndStream(
 }
 
 static HWY_INLINE SampleFunc
-ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
+ChooseSampleFunc(const RuntimeConfig& runtime_config,
+                 const AesCtrEngine& engine, ThreadingContext& ctx) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;
 
@@ -462,27 +462,28 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
 
   // Fast path for top-1 with no accept_token.
   if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
-    return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
+    return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE3(ctx.profiler, worker, zone);
-      return Top1OfSoftmax(logits, vocab_size);
+      return Top1OfSoftmax(logits);
     };
   }
 
   // General case: Softmax with top-k sampling.
-  return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
+    RngStream gen(engine, qi);
     return FusedSoftmaxAndSampleTopK(
-        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
-        runtime_config.temperature, runtime_config.accept_token, ctx.profiler,
-        worker);
+        logits, runtime_config.top_k, gen, runtime_config.temperature,
+        runtime_config.accept_token, ctx.profiler, worker);
   };
 }
 
 // Decode: generates one continuation token for each query in `qbatch`.
 static void GenerateT(const ModelConfig& config,
                       const RuntimeConfig& runtime_config,
-                      const WeightsPtrs& weights, Activations& activations,
-                      QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
+                      const AesCtrEngine& engine, const WeightsPtrs& weights,
+                      Activations& activations, QBatch& qbatch, MatMulEnv& env,
+                      TimingInfo& timing_info) {
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     if (qbatch.MutablePos(qi) == 0) {
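Note on the hunk above (an assumption about the counter-based design, not stated in the diff): because the engine is immutable and a stream is fully determined by (engine, stream id), re-constructing RngStream(engine, qi) yields the same sequence regardless of which thread runs the query, e.g.:

  RngStream a(engine, /*stream=*/3);
  RngStream b(engine, /*stream=*/3);
  // a() == b() on each successive call, so per-query sampling is
  // reproducible under any thread scheduling.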
@@ -554,7 +555,8 @@ static void GenerateT(const ModelConfig& config,
     max_gen_steps = seq_len - max_prompt_size;
   }
 
-  const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
+  const SampleFunc sample_token =
+      ChooseSampleFunc(runtime_config, engine, env.ctx);
 
   timing_info.generate_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
@@ -568,15 +570,16 @@ static void GenerateT(const ModelConfig& config,
 void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const ModelConfig& config,
                      const RuntimeConfig& runtime_config,
-                     const WeightsPtrs& weights, KVCache& kv_cache,
-                     MatMulEnv& env, TimingInfo& timing_info) {
+                     const AesCtrEngine& engine, const WeightsPtrs& weights,
+                     KVCache& kv_cache, MatMulEnv& env,
+                     TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
                           kv_cache.SeqLen(), env.ctx, env.row_ptrs);
 
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));
   QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
-  GenerateT(config, runtime_config, weights, activations, qbatch, env,
+  GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
             timing_info);
 }
@@ -584,8 +587,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
 // queries, and calls `GenerateT` on each batch.
 void GenerateBatchT(const ModelConfig& config,
                     const RuntimeConfig& runtime_config,
-                    const WeightsPtrs& weights, AllQueries& all_queries,
-                    MatMulEnv& env, TimingInfo& timing_info) {
+                    const AesCtrEngine& engine, const WeightsPtrs& weights,
+                    AllQueries& all_queries, MatMulEnv& env,
+                    TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
   Activations activations(config, max_batch_size,
@@ -596,7 +600,7 @@ void GenerateBatchT(const ModelConfig& config,
              start += runtime_config.decode_qbatch_size) {
     QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
     // Generate a batch of one token for each of `qbatch.Size()` queries.
-    GenerateT(config, runtime_config, weights, activations, qbatch, env,
+    GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
               timing_info);
   }
 }
@@ -637,7 +641,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
       model_(reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
-      inference_(inference) {
+      inference_(inference),
+      aes_ctr_engine_(inference.deterministic) {
   // Negligible CPU time in the ctor body (except ReadFromBlobs).
   weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
                                              mat_owners_, ctx);
@@ -661,9 +666,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                      TimingInfo& timing_info) const {
   env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
-                                        model_.Config(), runtime_config,
-                                        weights_, kv_cache, env, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(
+      prompt, pos, prefix_end, model_.Config(), runtime_config, aes_ctr_engine_,
+      weights_, kv_cache, env, timing_info);
 
   env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
@@ -674,7 +679,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
   env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
-                                       weights_, all_queries, env, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
+                                       aes_ctr_engine_, weights_, all_queries,
+                                       env, timing_info);
 
   env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
@@ -278,6 +278,7 @@ class Gemma {
   WeightsPtrs::Mode weight_read_mode_;
   GemmaChatTemplate chat_template_;
   InferenceArgs inference_;
+  AesCtrEngine aes_ctr_engine_;
 };
 
 }  // namespace gcpp
@@ -22,7 +22,6 @@
 #include <stdio.h>
 
 #include <functional>
-#include <random>
 #include <string>
 
 #include "io/io.h"  // Path
@@ -90,10 +89,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
 using AcceptFunc = std::function<bool(int, float)>;
-// If not empty, SampleFunc is called with the logits for the next token, which
-// it may modify/overwrite, and its return value is the next generated token
-// together with its probability.
-using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
+// If not empty, SampleFunc is called with the query_idx and logits for the
+// next token, which it may modify/overwrite. It returns the next generated
+// token together with its probability.
+using SampleFunc = std::function<TokenAndProb(size_t, Logits)>;
 // If not empty, LayersOutputFunc is called for layer outputs, specified with:
 // - index of query within containing batch (if any); zero otherwise.
 // - position in the tokens sequence
@@ -136,8 +135,7 @@ struct RuntimeConfig {
   // Sampling-related parameters.
   float temperature;  // Temperature for sampling.
 
-  size_t top_k = 1;   // Top-k for sampling.
-  std::mt19937* gen;  // Random number generator used for sampling.
+  size_t top_k = 1;  // Top-k for sampling.
 
   int verbosity;  // Controls verbosity of printed messages.
 
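Callers that install their own sampler must adopt the new SampleFunc signature, std::function<TokenAndProb(size_t, Logits)>, shown in the gemma_args.h hunk above. A hedged sketch (greedy argmax for brevity; not code from this commit):

  runtime_config.sample_func = [](size_t /*query_idx*/,
                                  Logits logits) -> TokenAndProb {
    size_t best = 0;
    for (size_t i = 1; i < logits.size(); ++i) {
      if (logits[i] > logits[best]) best = i;
    }
    // The values are raw logits here, so .prob is only a score, not a
    // normalized probability, unless the sampler runs Softmax itself.
    return TokenAndProb{.token = static_cast<int>(best), .prob = logits[best]};
  };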
gemma/run.cc

@@ -18,7 +18,6 @@
 #include <stdio.h>
 
 #include <iostream>
-#include <random>
 #include <string>
 #include <string_view>
 #include <vector>
@@ -98,9 +97,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   size_t prompt_size = 0;
   const ModelConfig& config = gemma.Config();
 
-  std::mt19937 gen;
-  InitGenerator(inference, gen);
-
   const bool have_image = !inference.image_file.path.empty();
   Image image;
   const size_t pool_dim = config.vit_config.pool_dim;
@@ -117,8 +113,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     HWY_ASSERT(image.ReadPPM(inference.image_file.path));
     const size_t image_size = config.vit_config.image_size;
     image.Resize(image_size, image_size);
-    RuntimeConfig runtime_config = {.gen = &gen,
-                                    .verbosity = inference.verbosity,
+    RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                     .use_spinning = threading.spin};
     double image_tokens_start = hwy::platform::Now();
     gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
@@ -188,8 +183,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
 
   // Set up runtime config.
   TimingInfo timing_info = {.verbosity = inference.verbosity};
-  RuntimeConfig runtime_config = {.gen = &gen,
-                                  .verbosity = inference.verbosity,
+  RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
                                   .batch_stream_token = batch_stream_token,
                                   .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
@@ -239,7 +233,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     // Prepare for the next turn. Works only for PaliGemma.
     if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
       abs_pos = 0;  // Start a new turn at position 0.
-      InitGenerator(inference, gen);
     } else {
       // The last token was either EOS, then it should be ignored because it is
       // never part of the dialog, see Table 5 in the Gemma-2 paper:
@@ -110,8 +110,7 @@ class VitAttention {
     CallMatMul(Q, K, nullptr, env_, C);
 
     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
-      float* HWY_RESTRICT c = C.Row(task);
-      Softmax(c, C.Cols(), env_.ctx.profiler, worker);
+      Softmax(C.RowSpan(task), env_.ctx.profiler, worker);
     });
 
     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
@@ -154,7 +153,7 @@ class VitAttention {
         head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
       }
       // SoftMax yields "probabilities" in head_att.
-      Softmax(head_att, seq_len, env_.ctx.profiler, worker);
+      Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker);
       // Compute weighted sum of v into att_out.
       float* HWY_RESTRICT att_out =
           activations_.attention.att_out.Row(token) + head * qkv_dim;
@@ -812,7 +812,7 @@ class DotStats {
 
   // Forward relative error, lower is better.
   void CheckRel() const {
-    ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 4E-3);
+    ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 7E-3);
     ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 1.23f);
 
     // Compensated and Double are very accurate.
@@ -822,22 +822,22 @@ class DotStats {
     ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f);
 
     // Naive and OnlyTwoProd are considerably higher, but not huge.
-    ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2);
+    ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 3.5E-1);
     ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(),
-                  0.072);
+                  7.5E-2);
 
     // Kahan (FastTwoSum) is decent:
-    ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3);
+    ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 1E-2);
     ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f);
 
     // TwoProducts and TwoSums are a bit better.
     ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(),
-                  3E-3);
-    ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f);
+                  1.1E-2);
+    ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 1.0f);
     ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(),
-                  2.6E-3);
+                  1.1E-2);
 
-    ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
+    ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 5.2E-2);
     // Extremely high error on aarch64.
     ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f);
   }
@@ -857,7 +857,7 @@ class DotStats {
     ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f);
 
     // But TwoProducts/TwoSums help a bit.
-    ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 0.19f);
+    ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 1.0f);
     ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_rels[kAddTwoSum].Max(), 0.34f);
 
     // Extremely high error on aarch64.
@@ -893,7 +893,7 @@ class DotStats {
 };
 
 // Returns normalized value in [-1, 1).
-float RandomFloat(std::mt19937& rng) {
+float RandomFloat(RngStream& rng) {
   const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
   const uint32_t mantissa_mask = hwy::MantissaMask<float>();
   const uint32_t representation = exp | (rng() & mantissa_mask);
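The bit trick above splices a random mantissa onto the exponent bits of 1.0f, which yields a float uniformly distributed in [1, 2); the rescaling to [-1, 1) happens in lines outside this hunk. A standalone sketch, where the final mapping is an assumption:

  float RandomFloatSketch(uint32_t bits) {
    const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);  // exponent bits of 1.0f
    const uint32_t mantissa = bits & hwy::MantissaMask<float>();
    const float in_1_2 = hwy::BitCastScalar<float>(exp | mantissa);  // in [1, 2)
    return in_1_2 * 2.0f - 3.0f;  // one way to map [1, 2) onto [-1, 1)
  }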
@@ -908,7 +908,7 @@ float RandomFloat(std::mt19937& rng) {
 // error from the Dot algorithms, not the compression.
 template <typename Packed>
 void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
-                                   std::mt19937& rng,
+                                   RngStream& rng,
                                    const PackedSpan<Packed>& packed,
                                    CompressWorkingSet& work) {
   std::uniform_int_distribution<int> e_dist(0, 6);
@@ -934,7 +934,7 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
 // Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
 template <typename WT, typename VT>
 double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v,
-                                    std::mt19937& rng) {
+                                    RngStream& rng) {
   PROFILER_FUNC;
   const size_t half = HWY_MAX(1, num / 2);  // generate at least one random
   HWY_DASSERT(half != 0);
@@ -1002,8 +1002,8 @@ struct TestShortDotsT {
   ThreadingArgs threading_args;
   ThreadingContext ctx(threading_args);
   CompressWorkingSet work;
-  std::mt19937 rng;
-  rng.seed(12345);
+  AesCtrEngine engine(/*deterministic=*/true);
+  RngStream rng(engine, 0);
 
   hwy::Stats s_l1[kVariants];
@@ -1108,9 +1108,10 @@ void TestAllDot() {
   {  // ensure no profiler zones are active
     const hn::ScalableTag<float> df;
 
-    std::mt19937 rngs[kMaxWorkers];
+    AesCtrEngine engine(/*deterministic=*/true);
+    RngStream rngs[kMaxWorkers];
     for (size_t i = 0; i < kMaxWorkers; ++i) {
-      rngs[i].seed(12345 + 65537 * i);
+      rngs[i] = RngStream(engine, i);
     }
 
     constexpr size_t kReps = hn::AdjustedReps(40);
ops/ops-inl.h

@@ -29,7 +29,7 @@
 
 #include "ops/matmul.h"
 #include "util/allocator.h"
-#include "util/basics.h"  // TokenAndProb
+#include "util/basics.h"  // TokenAndProb, RngStream
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/base.h"
@@ -614,12 +614,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
 }
 
 // See below for a specialized version for top-1 sampling.
-static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 hwy::Profiler& p, const size_t worker,
+static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
+                                 const size_t worker,
                                  float temperature = 1.0f) {
   static const auto zone = p.AddZone("Ops.Softmax");
   PROFILER_ZONE3(p, worker, zone);
-  HWY_DASSERT(size != 0);
+  HWY_DASSERT(logits.size() != 0);
 
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -629,24 +629,25 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
   const V vmin = hn::Set(d, hwy::LowestValue<float>());
   V vmax = vmin;
   V* pmax = &vmax;  // workaround for SVE: cannot capture &vector directly
-  hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR {
-    *pmax = hn::Max(*pmax, value);
-  });
+  hn::Foreach(d, logits.data(), logits.size(), vmin,
+              [pmax](const auto d, const V value)
+                  HWY_ATTR { *pmax = hn::Max(*pmax, value); });
   vmax = hn::MaxOfLanes(d, vmax);
 
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
-  hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR {
-    if constexpr (HWY_TARGET & HWY_ALL_SVE) {
-      // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
-      return hn::CallExp(d, hn::Sub(value, *pmax));
-    } else {
-      return hn::Exp(d, hn::Sub(value, *pmax));
-    }
-  });
+  hn::Transform(d, logits.data(), logits.size(),
+                [pmax](const auto d, const V value) HWY_ATTR {
+                  if constexpr (HWY_TARGET & HWY_ALL_SVE) {
+                    // Workaround for buggy SVE codegen: avoid inlined Exp().
+                    return hn::CallExp(d, hn::Sub(value, *pmax));
+                  } else {
+                    return hn::Exp(d, hn::Sub(value, *pmax));
+                  }
+                });
 
   if (temperature != 1.0f) {
     const float temperature_inv = 1.0f / temperature;
-    hn::Transform(d, x, size,
+    hn::Transform(d, logits.data(), logits.size(),
                   [temperature_inv](const auto d, const V value) HWY_ATTR {
                     return hn::Mul(value, hn::Set(d, temperature_inv));
                   });
@@ -656,10 +657,10 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
   // not make a huge difference. It halves the standard deviation of the sum of
   // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
   // the generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, size);
+  const float sum_exp = Sum(d, logits.data(), logits.size());
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size, p, worker);
+  MulByConst(mul, logits.data(), logits.size(), p, worker);
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
@@ -669,8 +670,7 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
 // which already knows the max value which top-1 sampling would again seek.
 
 // Returns the argmax and x[argmax].
-static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x,
-                                            const size_t num) {
+static HWY_INLINE TokenAndProb ArgmaxAndMax(Logits logits) {
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   using V = hn::Vec<D>;
@@ -680,16 +680,16 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(Logits logits) {
   using TI = hn::TFromD<decltype(di)>;
   using VI = hn::Vec<decltype(di)>;
   const size_t N = hn::Lanes(d);
-  HWY_ASSERT(num % (2 * N) == 0);
+  HWY_ASSERT(logits.size() % (2 * N) == 0);
 
   V max0 = hn::Set(d, hwy::LowestValue<float>());
   V max1 = max0;
   VI argmax0 = hn::Zero(di);
   VI argmax1 = argmax0;
 
-  for (size_t i = 0; i < num; i += 2 * N) {
-    const V v0 = hn::LoadU(d, x + i);
-    const V v1 = hn::LoadU(d, x + i + N);
+  for (size_t i = 0; i < logits.size(); i += 2 * N) {
+    const V v0 = hn::LoadU(d, &logits[i]);
+    const V v1 = hn::LoadU(d, &logits[i + N]);
     const VI vi0 = hn::Iota(di, static_cast<TI>(i));
     const VI vi1 = hn::Iota(di, static_cast<TI>(i + N));
     const M gt0 = hn::Gt(v0, max0);
@@ -714,43 +714,43 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(Logits logits) {
   return TokenAndProb{.token = argmax, .prob = hn::GetLane(max)};
 }
 
-// Returns argmax of softmax and its probability. This overwrites `x`, but not
-// with normalized probabilities. Only equivalent to `Softmax` + `sample_func`
-// if `kTopK` == 1. This is worthwhile because `num` is typically `kVocabSize`
-// == 256K, and this avoids writing and then scanning again for the max.
-// However, this is not enough to make parallelization worthwhile.
-static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
-                                                   const size_t num) {
+// Returns argmax of softmax and its probability. This overwrites `logits`, but
+// not with normalized probabilities. Only equivalent to `Softmax` +
+// `sample_func` if `kTopK` == 1. This is worthwhile because `logits.size()` is
+// typically `kVocabSize == 256K`, and this avoids writing and then scanning
+// again for the max.
+static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) {
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> d;
   using V = hn::Vec<decltype(d)>;
 
-  const TokenAndProb argmax = ArgmaxAndMax(x, num);
+  const TokenAndProb argmax = ArgmaxAndMax(logits);
 
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
   const V max = hn::Set(d, argmax.prob);
   const V* pmax = &max;
-  hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR {
-    if constexpr (HWY_TARGET & HWY_ALL_SVE) {
-      // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
-      return hn::CallExp(d, hn::Sub(value, *pmax));
-    } else {
-      return hn::Exp(d, hn::Sub(value, *pmax));
-    }
-  });
+  hn::Transform(d, logits.data(), logits.size(),
+                [pmax](const auto d, const V value) HWY_ATTR {
+                  if constexpr (HWY_TARGET & HWY_ALL_SVE) {
+                    // Temporary workaround for buggy SVE codegen: avoid inlined
+                    // Exp().
+                    return hn::CallExp(d, hn::Sub(value, *pmax));
+                  } else {
+                    return hn::Exp(d, hn::Sub(value, *pmax));
+                  }
+                });
 
   // Normalize to a single probability. The exact sum seems like it should not
   // make a huge difference. It halves the standard deviation of the sum of the
   // normalized probabilities from 1E-7 to 5E-8, but actually also changes the
   // generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, num);
-  const float prob = x[argmax.token] / sum_exp;
+  const float sum_exp = Sum(d, logits.data(), logits.size());
+  const float prob = logits[argmax.token] / sum_exp;
   return TokenAndProb{.token = argmax.token, .prob = prob};
 }
 
-static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
-                                       const size_t size, hwy::Profiler& p,
-                                       const size_t worker) {
+static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
+                                       hwy::Profiler& p, const size_t worker) {
   static const auto zone = p.AddZone("Ops.LogitsSoftCap");
   PROFILER_ZONE3(p, worker, zone);
@@ -763,18 +763,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
   const VF* HWY_RESTRICT pcap = &vcap;
   const VF* HWY_RESTRICT pinv_cap = &vinv_cap;
 
-  DecompressAndCompressInplace(
-      DF(), x, size, [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF {
-        return hn::Mul(*pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap)));
-      });
+  DecompressAndCompressInplace(DF(), logits.data(), logits.size(),
+                               [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF {
+                                 return hn::Mul(
+                                     *pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap)));
+                               });
 }
 
 // Calls LogitsSoftCap if cap != 0.0f.
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
-    const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
-    const size_t worker) {
+    const float cap, Logits logits, hwy::Profiler& p, const size_t worker) {
   if (cap != 0.0f) {
-    LogitsSoftCap(cap, x, size, p, worker);
+    LogitsSoftCap(cap, logits, p, worker);
   }
 }
@@ -785,20 +785,18 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
   ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
               [&](uint64_t task, size_t worker) {
                 if (non_eos.Get(task)) {
-                  LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
-                                worker);
+                  LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker);
                 }
               });
 }
 
-static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
-SampleArgmax(const float* probabilities, size_t vocab_size) {
+static HWY_NOINLINE HWY_MAYBE_UNUSED size_t SampleArgmax(Logits logits) {
   size_t max_index = 0;
-  float max_prob = probabilities[0];
-  for (size_t i = 1; i < vocab_size; ++i) {
-    if (probabilities[i] > max_prob) {
+  float max_prob = logits[0];
+  for (size_t i = 1; i < logits.size(); ++i) {
+    if (logits[i] > max_prob) {
       max_index = i;
-      max_prob = probabilities[i];
+      max_prob = logits[i];
     }
   }
   return max_index;
@@ -828,16 +826,15 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
 
 template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
-    const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
-    TAcceptToken& accept_token) {
+    Logits logits, size_t k, TAcceptToken& accept_token) {
   HWY_ASSERT(k != 0);
-  HWY_ASSERT(k <= vocab_size);
+  HWY_ASSERT(k <= logits.size());
   std::vector<double> packed_token_probs;
-  for (int32_t i = 0; i < static_cast<int32_t>(vocab_size); ++i) {
-    if (accept_token && !accept_token(i, probabilities[i])) {
+  for (int32_t i = 0; i < static_cast<int32_t>(logits.size()); ++i) {
+    if (accept_token && !accept_token(i, logits[i])) {
       continue;
     }
-    packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
+    packed_token_probs.push_back(PackTokenAndProb(i, logits[i]));
   }
 
   hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
@ -853,11 +850,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
|
|||
}
|
||||
|
||||
template <typename TAcceptToken>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
||||
const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
|
||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
||||
std::vector<TokenAndProb> token_probs =
|
||||
TopK(probabilities, vocab_size, k, accept_token);
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k,
|
||||
RngStream& gen, float temperature,
|
||||
TAcceptToken& accept_token) {
|
||||
std::vector<TokenAndProb> token_probs = TopK(logits, k, accept_token);
|
||||
std::vector<int> topk_indices(k);
|
||||
std::vector<float> topk_probs(k);
|
||||
for (size_t i = 0; i < k; ++i) {
|
||||
|
|
@ -869,14 +865,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
|||
|
||||
template <typename TAcceptToken>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
||||
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
|
||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token,
|
||||
hwy::Profiler& p, size_t worker) {
|
||||
Logits logits, size_t k, RngStream& gen, float temperature,
|
||||
TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) {
|
||||
// Softmax and sample top-K is equivalent to taking the top-K logits and
|
||||
// sampling from the softmax of the top-K logits. The latter is faster as it
|
||||
// avoids computing the softmax of all logits.
|
||||
std::vector<TokenAndProb> token_logits =
|
||||
TopK(logits, vocab_size, k, accept_token);
|
||||
std::vector<TokenAndProb> token_logits = TopK(logits, k, accept_token);
|
||||
std::vector<int> topk_indices(k);
|
||||
std::vector<float> topk_logits(k);
|
||||
for (size_t i = 0; i < token_logits.size(); ++i) {
|
||||
|
|
@ -884,8 +878,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
|||
topk_logits[i] = token_logits[i].prob;
|
||||
}
|
||||
|
||||
size_t mask = token_logits.size();
|
||||
Softmax(topk_logits.data(), mask, p, worker, temperature);
|
||||
const size_t mask = token_logits.size();
|
||||
Softmax(Logits(topk_logits.data(), mask), p, worker, temperature);
|
||||
auto distribution = std::discrete_distribution<int>(
|
||||
std::begin(topk_logits), std::begin(topk_logits) + mask);
|
||||
int topk_sampled_index = distribution(gen);
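
// --- Editor's aside (sketch, not part of this commit) ---
// Standalone check of the fusion argument above: renormalizing the top-k
// entries of a full softmax equals a softmax over just the top-k logits,
// because the full-vocabulary denominator cancels. Values are arbitrary.
#include <cmath>
#include <cstdio>

int main() {
  const float logits[4] = {2.0f, 0.5f, -1.0f, 1.5f};  // top-2: indices 0, 3
  float sum_all = 0.0f;
  for (float v : logits) sum_all += std::exp(v);
  // Top-2 slice of the full softmax, renormalized to sum to 1.
  const float p0 = std::exp(logits[0]) / sum_all;
  const float p3 = std::exp(logits[3]) / sum_all;
  const float renormalized = p0 / (p0 + p3);
  // Softmax restricted to the top-2 logits.
  const float direct = std::exp(logits[0]) /
                       (std::exp(logits[0]) + std::exp(logits[3]));
  std::printf("%f vs %f\n", renormalized, direct);  // equal up to rounding
  return 0;
}
// --- End aside ---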

@@ -57,6 +57,12 @@ namespace HWY_NAMESPACE {

 namespace hn = hwy::HWY_NAMESPACE;

+static RngStream MakeRng() {
+  static AesCtrEngine engine(/*deterministic=*/true);
+  static uint64_t stream = 0;
+  return RngStream(engine, ++stream);
+}
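
// --- Editor's note (not part of this commit) ---
// MakeRng() shares one deterministic engine across all tests but bumps the
// static stream counter on every call, so each returned RngStream draws from
// its own independent sequence: two generators created in the same test do
// not correlate, yet reruns of the binary reproduce the same values.
// --- End note ---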

 template <class Test>
 struct ForeachCountAndMisalign {
   template <typename T, class D>

@@ -304,7 +310,7 @@ class TestSoftmax {
     }

     SimpleSoftmax(e, count);
-    Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0);
+    Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);

     T sum = 0.0f;
     for (size_t i = 0; i < count; ++i) {

@@ -438,10 +444,9 @@ void TestRopeAndMulBy() {
   const size_t dim_qkv = config.layer_configs[0].qkv_dim;
   MatStorageT<float> x("x", dim_qkv, ctx.allocator);

-  std::mt19937 gen;
-  gen.seed(0x12345678);
+  RngStream rng = MakeRng();
   std::normal_distribution<float> r{0.0, 5.0};
-  auto random_float = [&r, &gen] { return r(gen); };
+  auto random_float = [&r, &rng] { return r(rng); };

   for (size_t i = 0; i < dim_qkv; ++i) {
     x.Row(0)[i] = random_float();

@@ -704,38 +709,34 @@ void TestSampleTopK() {
   hwy::Profiler& p = hwy::Profiler::Get();
   const size_t worker = 0;
   const size_t kSize = 52;
-  std::vector<float> logits(kSize);
+  std::vector<float> logits_vec(kSize);
+  Logits logits(logits_vec.data(), kSize);
   // Create a vector going from -100 to -100+51=49 and take Softmax.
   std::iota(logits.begin(), logits.end(), -100.0f);
-  Softmax(logits.data(), kSize, p, worker);
-  std::mt19937 gen;
-  gen.seed(0x12345678);
+  Softmax(logits, p, worker);
+  RngStream rng = MakeRng();
   float temperature = 1.0f;
   // SampleTopK<1> should return the argmax.
   std::function<bool(int, float)> accept_token;
-  int sample =
-      SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
+  int sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token);
   EXPECT_EQ(sample, 51);  // Last is largest.
   // Only accept even tokens, expect the last (largest) even index.
   accept_token = [](int i, float) { return i % 2 == 0; };
-  sample =
-      SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
+  sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token);
   EXPECT_EQ(sample, 50);  // Last even index.
   // Reset the logits to a positive, increasing sequence and take Softmax.
   std::iota(logits.begin(), logits.end(), 1.0f);
-  Softmax(logits.data(), kSize, p, worker);
+  Softmax(logits, p, worker);
   // Sample from the top 3, expect one of the top 3 even indices.
   for (int i = 0; i < 100; ++i) {
-    sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
-                        accept_token);
+    sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
     EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46);
   }
   // Now set the temperature to 0.0f, which should always return the argmax,
   // even for k=3.
   temperature = 0.0f;
   for (int i = 0; i < 100; ++i) {
-    sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
-                        accept_token);
+    sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
     EXPECT_EQ(sample, 50);
   }
 }

@@ -27,42 +27,38 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
   HWY_ASSERT(image.ReadPPM(path));
   const size_t image_size = config.vit_config.image_size;
   image.Resize(image_size, image_size);
-  RuntimeConfig runtime_config = {.gen = &env_->MutableGen(),
-                                  .verbosity = 0};
+  RuntimeConfig runtime_config = {.verbosity = 0};
   gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
                             image, *image_tokens_, env_->MutableEnv());
 }

 std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
   const Gemma& model = *(env_->GetGemma());
-  env_->MutableGen().seed(0x12345678);
-
   std::string response;
   auto stream_token = [&](int token, float) {
     std::string token_text;
-    HWY_ASSERT(
-        model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     response += token_text;
     return true;
   };

   std::string mutable_prompt = prompt_text;
   std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
   tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);

   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                   // PrefixLM sees/attends to all tokens.
                                   .prefill_tbatch_size = tokens.size(),
-                                  .gen = &env_->MutableGen(),
                                   .verbosity = 0,
                                   .stream_token = stream_token,
                                   .image_tokens = image_tokens_.get()};

   const size_t prefix_end = tokens.size();
   TimingInfo timing_info = {.verbosity = 0};
   model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
                  env_->MutableKVCache(), env_->MutableEnv(), timing_info);
   return response;
 }

 }  // namespace gcpp

@@ -53,9 +53,8 @@ class GemmaModel {
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
   void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
-                  size_t max_generated_tokens, float temperature, float seed,
-                  gcpp::AcceptFunc accept, bool skip_prompt) {
-    env_.MutableGen().seed(seed);
+                  size_t max_generated_tokens, float temperature,
+                  float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
     std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;

@@ -77,7 +76,7 @@ class GemmaModel {
   // Generates a single example, given a prompt, and returns the result.
   std::string Generate(std::string prompt, size_t max_generated_tokens,
-                       float temperature, float seed,
+                       float temperature, float /*seed*/,
                        const std::vector<std::string>& accept,
                        const std::vector<std::string>& end) {
     std::set<int> end_token_set{};

@@ -124,7 +123,6 @@ class GemmaModel {
       }
     };

-    env_.MutableGen().seed(seed);
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;

@@ -144,14 +142,13 @@ class GemmaModel {
   // results.
   std::vector<std::string> GenerateBatch(const std::vector<std::string>& inputs,
                                          size_t max_generated_tokens,
-                                         float temperature, float seed,
+                                         float temperature, float /*seed*/,
                                          size_t top_k) {
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;
     config.top_k = top_k;
     config.verbosity = 0;
-    env_.MutableGen().seed(seed);

     std::vector<gcpp::QueryResult> outputs = env_.BatchQueryModel(inputs);
     std::vector<std::string> result;

@@ -187,8 +184,7 @@ class GemmaModel {
         "image_tokens",
         gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
         env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd));
-    gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
-                                          .verbosity = 0};
+    gcpp::RuntimeConfig runtime_config = {.verbosity = 0};
     gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
                               c_image, *image_tokens_, env_.MutableEnv());
   }

@@ -197,10 +193,9 @@ class GemmaModel {
   // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
   std::pair<std::string, std::vector<int>> GenerateWithImage(
       std::string prompt, size_t max_generated_tokens, float temperature,
-      float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
+      float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
     if (!image_tokens_) throw std::invalid_argument("No image set.");
     const gcpp::Gemma& model = *env_.GetGemma();
-    env_.MutableGen().seed(seed);
     gcpp::RuntimeConfig& config = env_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
     config.temperature = temperature;

@@ -273,6 +268,7 @@ PYBIND11_MODULE(gemma, mod) {
       }),
       py::arg("tokenizer_path"), py::arg("weights_path"),
       py::arg("max_threads") = 0)
+      // seed arguments are ignored.
       .def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"),
            py::arg("stream"), py::arg("max_generated_tokens") = 1024,
            py::arg("temperature") = 0.9, py::arg("seed") = 123456789,

@@ -24,7 +24,7 @@

 namespace gcpp {

-RNG::RNG(bool deterministic) {
+AesCtrEngine::AesCtrEngine(bool deterministic) {
   // Pi-based nothing up my sleeve numbers from Randen.
   key_[0] = 0x243F6A8885A308D3ull;
   key_[1] = 0x13198A2E03707344ull;

@@ -54,9 +54,10 @@ static V Load(const uint64_t* ptr) {
   return hn::Load(D(), reinterpret_cast<const uint8_t*>(ptr));
 }

-RNG::result_type RNG::operator()() {
-  V state = Load(counter_);
-  counter_[0]++;
+uint64_t AesCtrEngine::operator()(uint64_t stream, uint64_t counter) const {
+  const hn::Repartition<uint64_t, D> d64;
+
+  V state = hn::BitCast(D(), hn::Dup128VecFromValues(d64, counter, stream));
   state = hn::Xor(state, Load(key_));  // initial whitening

   static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t));

@@ -68,7 +69,6 @@ RNG::result_type RNG::operator()() {
   state = hn::AESRound(state, Load(key_ + 10));

   // Return lower 64 bits of the u8 vector.
-  const hn::Repartition<uint64_t, D> d64;
   return hn::GetLane(hn::BitCast(d64, state));
 }
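
// --- Editor's aside (sketch, not part of this commit) ---
// The engine is a pure function of (stream, counter): the same pair always
// yields the same 64-bit value, which is what makes lock-free parallel
// sampling possible. Illustration, assuming the declarations above:
//   AesCtrEngine engine(/*deterministic=*/true);
//   const uint64_t a = engine(/*stream=*/7, /*counter=*/0);
//   const uint64_t b = engine(/*stream=*/7, /*counter=*/0);
//   // a == b: reproducible regardless of which thread asks, or when.
// --- End aside ---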

@@ -20,7 +20,7 @@
 #include <stddef.h>
 #include <stdint.h>

-#include "hwy/aligned_allocator.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
 // IWYU pragma: end_exports

@@ -120,39 +120,60 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
   return IndexRange(begin, HWY_MIN(begin + max_size, end));
 }

+using Logits = hwy::Span<float>;  // size() is vocab_size.
+
 // Non-cryptographic 64-bit pseudo-random number generator. Supports random or
-// deterministic seeding. Conforms to C++ `UniformRandomBitGenerator`.
+// deterministic seeding.
 //
 // Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This
 // is useful for parallel sampling. Each thread can generate the stream for a
 // particular task, without caring about prior/subsequent generations.
-class alignas(16) RNG {
+class alignas(16) AesCtrEngine {
   // "Large-scale randomness study of security margins for 100+ cryptographic
   // functions": at least four.
   // "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant.
   static constexpr size_t kRounds = 5;

  public:
-  explicit RNG(bool deterministic);
-
-  void SetStream(uint64_t stream) {
-    counter_[1] = stream;
-    counter_[0] = 0;
-  }
+  // If `deterministic` is true, uses a fixed seed; otherwise, attempts to
+  // grab entropy from the OS.
+  explicit AesCtrEngine(bool deterministic);
+
+  // Pure and thread safe; typically called via `RngStream`, which increments
+  // `counter`. Throughput is about 100M/s on 3 GHz Skylake. It could be
+  // increased 4x via unrolling by the AES latency (4-7 cycles), but because
+  // users generally call once at a time, this requires buffering, which is not
+  // worth the complexity in this application.
+  uint64_t operator()(uint64_t stream, uint64_t counter) const;
+
+ private:
+  uint64_t key_[2 * (1 + kRounds)];
+};
+
+// Flyweight per-thread adapter that maintains the counter. Conforms to C++
+// `UniformRandomBitGenerator`.
+class RngStream {
+ public:
+  RngStream() = default;  // Allow C arrays with subsequent initialization.
+
+  // Binds to an engine, which holds the seed and must outlive this object.
+  // Sets the stream; any other `RngStream` with the same `counter_rng` and
+  // `stream` will return the same sequence. This is typically the task ID, so
+  // that threads can independently generate values for each task.
+  RngStream(const AesCtrEngine& counter_rng, uint64_t stream)
+      : engine_(&counter_rng), stream_(stream), counter_(0) {}

   using result_type = uint64_t;
   static constexpr result_type min() { return 0; }
   static constexpr result_type max() { return ~result_type{0}; }

-  // About 100M/s on 3 GHz Skylake. Throughput could be increased 4x via
-  // unrolling by the AES latency (4-7 cycles). `std::discrete_distribution`
-  // makes individual calls to the generator, which would require buffering,
-  // which is not worth the complexity.
-  result_type operator()();
+  result_type operator()() { return (*engine_)(stream_, counter_++); }

  private:
-  uint64_t counter_[2] = {};
-  uint64_t key_[2 * (1 + kRounds)];
+  const AesCtrEngine* engine_ = nullptr;
+  uint64_t stream_ = 0;  // immutable after ctor
+  uint64_t counter_ = 0;
   // Prevent false sharing if used by multiple threads.
   HWY_MAYBE_UNUSED uint8_t padding_[HWY_ALIGNMENT - 16 - sizeof(engine_)];
 };

 }  // namespace gcpp
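
// --- Editor's aside (sketch, not part of this commit) ---
// Intended parallel-sampling pattern: one shared immutable engine, one
// RngStream per task. Workers may run tasks in any order and still produce
// identical, uncorrelated draws. The function and loop are hypothetical.
#include <cstddef>
#include <cstdint>
#include <vector>

void SampleQueriesSketch(const gcpp::AesCtrEngine& engine, size_t num_queries) {
  std::vector<uint64_t> first_draw(num_queries);
  // Could be a parallel loop; no locking is needed because the engine is
  // const and each task owns its stream.
  for (uint64_t task = 0; task < num_queries; ++task) {
    gcpp::RngStream rng(engine, /*stream=*/task);  // stream = task ID
    first_draw[task] = rng();  // same value for `task` under any schedule
  }
}
// --- End aside ---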

@@ -25,9 +25,11 @@
 namespace gcpp {
 namespace {

-TEST(BasicsTest, IsDeterministic) {
-  RNG rng1(/*deterministic=*/true);
-  RNG rng2(/*deterministic=*/true);
+TEST(BasicsTest, EngineIsDeterministic) {
+  const AesCtrEngine engine1(/*deterministic=*/true);
+  const AesCtrEngine engine2(/*deterministic=*/true);
+  RngStream rng1(engine1, 0);
+  RngStream rng2(engine2, 0);
   // Remember for later testing after resetting the stream.
   const uint64_t r0 = rng1();
   const uint64_t r1 = rng1();

@@ -42,15 +44,17 @@ TEST(BasicsTest, IsDeterministic) {
     HWY_ASSERT(rng1() == rng2());
   }

-  // Reset counter, ensure it matches the default-constructed RNG.
-  rng1.SetStream(0);
+  // Reset counter, ensure it matches the prior sequence.
+  rng1 = RngStream(engine1, 0);
   HWY_ASSERT(r0 == rng1());
   HWY_ASSERT(r1 == rng1());
 }

-TEST(BasicsTest, IsSeeded) {
-  RNG rng1(/*deterministic=*/true);
-  RNG rng2(/*deterministic=*/false);
+TEST(BasicsTest, EngineIsSeeded) {
+  AesCtrEngine engine1(/*deterministic=*/true);
+  AesCtrEngine engine2(/*deterministic=*/false);
+  RngStream rng1(engine1, 0);
+  RngStream rng2(engine2, 0);
   // It would be very unlucky to have even one 64-bit value match, and two are
   // extremely unlikely.
   const uint64_t a0 = rng1();

@@ -60,9 +64,27 @@ TEST(BasicsTest, IsSeeded) {
   HWY_ASSERT(a0 != b0 || a1 != b1);
 }

+TEST(BasicsTest, StreamsDiffer) {
+  AesCtrEngine engine(/*deterministic=*/true);
+  // Compare random streams for more coverage than just the first N streams.
+  RngStream rng_for_stream(engine, 0);
+  for (size_t i = 0; i < 1000; ++i) {
+    RngStream rng1(engine, rng_for_stream());
+    RngStream rng2(engine, rng_for_stream());
+    // It would be very unlucky to have even one 64-bit value match, and two are
+    // extremely unlikely.
+    const uint64_t a0 = rng1();
+    const uint64_t a1 = rng1();
+    const uint64_t b0 = rng2();
+    const uint64_t b1 = rng2();
+    HWY_ASSERT(a0 != b0 || a1 != b1);
+  }
+}
+
 // If not close to 50% 1-bits, the RNG is quite broken.
 TEST(BasicsTest, BitDistribution) {
-  RNG rng(/*deterministic=*/true);
+  AesCtrEngine engine(/*deterministic=*/true);
+  RngStream rng(engine, 0);
   constexpr size_t kU64 = 2 * 1000 * 1000;
   const hwy::Timestamp t0;
   uint64_t one_bits = 0;

@@ -78,7 +100,8 @@ TEST(BasicsTest, BitDistribution) {
 }

 TEST(BasicsTest, ChiSquared) {
-  RNG rng(/*deterministic=*/true);
+  AesCtrEngine engine(/*deterministic=*/true);
+  RngStream rng(engine, 0);
   constexpr size_t kU64 = 1 * 1000 * 1000;

   // Test each byte separately.

@@ -301,6 +301,13 @@ class MatPtrT : public MatPtr {
     return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
   }

+  hwy::Span<MatT> RowSpan(size_t row) {
+    return hwy::Span<MatT>(Row(row), Cols());
+  }
+  hwy::Span<const MatT> RowSpan(size_t row) const {
+    return hwy::Span<const MatT>(Row(row), Cols());
+  }
+
   PackedSpan<const MatT> PaddedSpan() const {
     return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
   }
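
// --- Editor's aside (sketch, not part of this commit) ---
// How RowSpan ties into the Logits alias above: one row of a row-major
// activation matrix becomes that query's logits view. Illustration, assuming
// a MatPtrT<float> `x` with one row per query (the cap value is hypothetical):
//   Logits logits = x.RowSpan(query_idx);  // hwy::Span<float> of Cols() elems
//   MaybeLogitsSoftCap(30.0f, logits, profiler, worker);
// --- End aside ---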