diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index df300f4..7ddee9e 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -24,7 +24,6 @@ #include // std::shuffle #include -#include #include "compression/distortion.h" #include "util/test_util.h" @@ -104,8 +103,8 @@ struct TestPlateaus { HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f); } - std::random_device rd; // NOLINT - std::mt19937 rng(rd()); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); std::shuffle(in.get(), in.get() + kGroupSize, rng); NuqStream::ClusterBuf buf; @@ -151,8 +150,8 @@ struct TestRamp { HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f); } - std::random_device rd; // NOLINT - std::mt19937 rng(rd()); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); std::shuffle(in.get(), in.get() + kGroupSize, rng); NuqStream::ClusterBuf buf; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 3b999b4..55e99cf 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -20,7 +20,6 @@ #include #include -#include #include #include @@ -37,17 +36,6 @@ namespace gcpp { -void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { - if (inference.deterministic) { - // Nothing up my sleeve number, at least some upper bits set. - gen.seed(0x12345678); - } else { - // Depending on the library implementation, this may still be deterministic. - std::random_device rd; // NOLINT - gen.seed(rd()); - } -} - GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference) : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) { @@ -60,12 +48,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, ctx_); } - InitGenerator(inference, gen_); - runtime_config_ = { .max_generated_tokens = inference.max_generated_tokens, .temperature = inference.temperature, - .gen = &gen_, .verbosity = inference.verbosity, }; inference.CopyTo(runtime_config_); diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 8f4d96f..261daa4 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -18,7 +18,6 @@ #include -#include #include #include @@ -32,8 +31,6 @@ namespace gcpp { -void InitGenerator(const InferenceArgs& inference, std::mt19937& gen); - // Return type for query model calls. struct QueryResult { std::string response; @@ -107,7 +104,6 @@ class GemmaEnv { int Verbosity() const { return runtime_config_.verbosity; } RuntimeConfig& MutableConfig() { return runtime_config_; } - std::mt19937& MutableGen() { return gen_; } KVCache& MutableKVCache() { return kv_caches_[0]; } MatMulEnv& MutableEnv() { return env_; } @@ -115,7 +111,6 @@ class GemmaEnv { ThreadingContext ctx_; MatMulEnv env_; Gemma gemma_; - std::mt19937 gen_; // Random number generator. std::vector kv_caches_; // Same number as query batch. 
RuntimeConfig runtime_config_; }; diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index c150041..b7abb10 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -56,11 +56,10 @@ static std::string TokenString(const GemmaTokenizer& tokenizer, int token) { return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'"; } -void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len, - size_t k) { - std::vector> sorted(len); - for (size_t i = 0; i < len; ++i) { - sorted[i] = std::make_pair(dist[i], static_cast(i)); +void LogTopK(const GemmaTokenizer& tokenizer, Logits logits, size_t k) { + std::vector> sorted(logits.size()); + for (size_t i = 0; i < logits.size(); ++i) { + sorted[i] = std::make_pair(logits[i], static_cast(i)); } std::sort(sorted.begin(), sorted.end(), [](const std::pair& a, const std::pair& b) { @@ -84,9 +83,8 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size, - hwy::Profiler& p) { - Softmax(logits, vocab_size, p, hwy::Profiler::Thread()); +void CallSoftmax(Logits logits, hwy::Profiler& p) { + Softmax(logits, p, hwy::Profiler::Thread()); } } // namespace HWY_NAMESPACE @@ -107,19 +105,19 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s) size_t pos = 1; - const SampleFunc sample_token = [&](float* probs, - size_t vocab_size) -> TokenAndProb { + const SampleFunc sample_token = [&](size_t qi, + Logits logits) -> TokenAndProb { // input is logits, not yet probabilities - HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler); + HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler); // We are called for each token, but pos starts at 1. Clamping // max_generated_tokens to prompt.size() should prevent overrun. 
HWY_ASSERT(pos < prompt.size()); const int token = prompt[pos]; - const float prob = probs[token]; + const float prob = logits[token]; cross_entropy -= std::max(std::log(prob), -64.0f); if (verbosity >= 4) { - LogTopK(gemma.Tokenizer(), probs, vocab_size, 10); + LogTopK(gemma.Tokenizer(), logits, 10); } if (verbosity >= 3) { printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token, @@ -139,7 +137,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, RuntimeConfig runtime = { .max_generated_tokens = max_generated_tokens - 1, .temperature = 0.0f, - .gen = nullptr, .verbosity = verbosity, .stream_token = stream_token, .sample_func = sample_token, diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 77efbae..26313c1 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -115,7 +115,6 @@ TEST_F(GemmaTest, Multiturn) { RuntimeConfig runtime_config{ .max_generated_tokens = 64, .temperature = 0.0f, - .gen = &s_env->MutableGen(), .verbosity = 2, .batch_stream_token = stream_token, }; diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc index b6537fe..04a6e00 100644 --- a/evals/run_mmlu.cc +++ b/evals/run_mmlu.cc @@ -126,7 +126,6 @@ void Run(GemmaEnv& env, JsonArgs& json) { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = 30, .temperature = 0.0f, - .gen = &env.MutableGen(), .verbosity = env.Verbosity(), .stream_token = stream_token, }; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 193903f..f67324d 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -17,8 +17,8 @@ #include #include +#include #include -#include #include #include #include @@ -44,7 +44,7 @@ int main(int argc, char** argv) { for (int arg = 0; arg < argc; ++arg) { // Find a --reject flag and consume everything after it. if (strcmp(argv[arg], "--reject") == 0) { - while (++arg < argc) reject_tokens.insert(atoi(argv[arg])); + while (++arg < argc) reject_tokens.insert(atoi(argv[arg])); // NOLINT } } @@ -55,11 +55,6 @@ int main(int argc, char** argv) { gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator); size_t generated = 0; - // Initialize random number generator - std::mt19937 gen; - std::random_device rd; // NOLINT - gen.seed(rd()); - // Tokenize instructions. 
std::string prompt = "Write a greeting to the world."; const std::vector tokens = @@ -84,7 +79,6 @@ int main(int argc, char** argv) { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = 1024, .temperature = 1.0, - .gen = &gen, .verbosity = 0, .stream_token = stream_token, .accept_token = diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index 7800233..e5bb1d8 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -38,11 +37,7 @@ class SimplifiedGemma { : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_), - kv_cache_(gemma_.Config(), inference, ctx_.allocator) { - // Initialize random number generator - std::random_device rd; - gen_.seed(rd()); - } + kv_cache_(gemma_.Config(), inference, ctx_.allocator) {} SimplifiedGemma(int argc, char** argv) : SimplifiedGemma(gcpp::LoaderArgs(argc, argv), @@ -76,7 +71,6 @@ class SimplifiedGemma { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = max_generated_tokens, .temperature = temperature, - .gen = &gen_, .verbosity = 0, .stream_token = stream_token, .accept_token = @@ -93,6 +87,5 @@ class SimplifiedGemma { gcpp::MatMulEnv env_; gcpp::Gemma gemma_; gcpp::KVCache kv_cache_; - std::mt19937 gen_; std::string validation_error_; }; diff --git a/gemma/api_server.cc b/gemma/api_server.cc index 70b3115..ea5377d 100644 --- a/gemma/api_server.cc +++ b/gemma/api_server.cc @@ -60,18 +60,18 @@ struct ServerState { std::unique_ptr gemma; MatMulEnv* env; ThreadingContext* ctx; - + // Session-based KV cache storage struct Session { std::unique_ptr kv_cache; size_t abs_pos = 0; std::chrono::steady_clock::time_point last_access; }; - + std::unordered_map sessions; std::mutex sessions_mutex; std::mutex inference_mutex; - + // Cleanup old sessions after 30 minutes of inactivity void CleanupOldSessions() { std::lock_guard lock(sessions_mutex); @@ -84,7 +84,7 @@ struct ServerState { } } } - + // Get or create session with KV cache Session& GetOrCreateSession(const std::string& session_id) { std::lock_guard lock(sessions_mutex); @@ -101,24 +101,25 @@ struct ServerState { std::string GenerateSessionId() { static std::atomic counter{0}; std::stringstream ss; - ss << "session_" << std::hex << std::chrono::steady_clock::now().time_since_epoch().count() - << "_" << counter.fetch_add(1); + ss << "session_" << std::hex + << std::chrono::steady_clock::now().time_since_epoch().count() << "_" + << counter.fetch_add(1); return ss.str(); } // Wraps messages with start_of_turn markers - handles both with and without roles std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string prompt; - + for (const auto& content : contents) { if (content.contains("parts")) { // Check if role is specified (public API format) or not (local format) std::string role = content.value("role", ""); - + for (const auto& part : content["parts"]) { if (part.contains("text")) { std::string text = part["text"]; - + if (role == "user") { prompt += "user\n" + text + "\nmodel\n"; } else if (role == "model") { @@ -131,24 +132,23 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) { } } } - + return prompt; } // Parse generation config -RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) { +RuntimeConfig ParseGenerationConfig(const json& request) { RuntimeConfig config; - config.gen = &gen; config.verbosity = 0; - + // Set defaults matching public API config.temperature = 
1.0f; config.top_k = 1; config.max_generated_tokens = 8192; - + if (request.contains("generationConfig")) { auto& gen_config = request["generationConfig"]; - + if (gen_config.contains("temperature")) { config.temperature = gen_config["temperature"].get(); } @@ -159,7 +159,7 @@ RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) { config.max_generated_tokens = gen_config["maxOutputTokens"].get(); } } - + return config; } @@ -175,12 +175,12 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) }}}, {"promptFeedback", {{"safetyRatings", json::array()}}} }; - + // Only add finishReason for non-streaming chunks if (!is_streaming_chunk) { response["candidates"][0]["finishReason"] = "STOP"; } - + return response; } @@ -188,11 +188,11 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { try { json request = json::parse(req.body); - + // Get or create session std::string session_id = request.value("sessionId", GenerateSessionId()); auto& session = state.GetOrCreateSession(session_id); - + // Extract prompt from API format std::string prompt; if (request.contains("contents")) { @@ -202,32 +202,29 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); return; } - + // Lock for inference std::lock_guard lock(state.inference_mutex); - + // Set up runtime config - std::mt19937 gen; - RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); - + RuntimeConfig runtime_config = ParseGenerationConfig(request); + // Collect full response std::string full_response; runtime_config.stream_token = [&full_response](int token, float) { // Skip EOS token return true; }; - + // Tokenize prompt - std::vector tokens = WrapAndTokenize(state.gemma->Tokenizer(), - state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, - session.abs_pos, - prompt); - + std::vector tokens = WrapAndTokenize( + state.gemma->Tokenizer(), state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, session.abs_pos, prompt); + // Run inference with KV cache TimingInfo timing_info = {.verbosity = 0}; size_t prefix_end = 0; - + // Temporarily redirect output to capture response std::stringstream output; runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { @@ -236,25 +233,25 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques session.abs_pos++; return true; } - + session.abs_pos++; - + // Check for EOS if (state.gemma->Config().IsEOS(token)) { return true; } - + // Decode token std::string token_text; state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); output << token_text; - + return true; }; - - state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, - *session.kv_cache, *state.env, timing_info); - + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, + *session.kv_cache, *state.env, timing_info); + // Create response json response = CreateAPIResponse(output.str(), false); response["usageMetadata"] = { @@ -262,17 +259,22 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques {"candidatesTokenCount", session.abs_pos - tokens.size()}, {"totalTokenCount", session.abs_pos} }; - + res.set_content(response.dump(), "application/json"); - + } catch (const 
json::exception& e) { res.status = 400; - res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), - "application/json"); + res.set_content( + json{{"error", + {{"message", std::string("JSON parsing error: ") + e.what()}}}} + .dump(), + "application/json"); } catch (const std::exception& e) { res.status = 500; - res.set_content(json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}.dump(), - "application/json"); + res.set_content( + json{{"error", {{"message", std::string("Server error: ") + e.what()}}}} + .dump(), + "application/json"); } } @@ -280,11 +282,11 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { try { json request = json::parse(req.body); - + // Get or create session std::string session_id = request.value("sessionId", GenerateSessionId()); auto& session = state.GetOrCreateSession(session_id); - + // Extract prompt from API format std::string prompt; if (request.contains("contents")) { @@ -294,13 +296,13 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); return; } - + // Set up SSE headers res.set_header("Content-Type", "text/event-stream"); res.set_header("Cache-Control", "no-cache"); res.set_header("Connection", "keep-alive"); res.set_header("X-Session-Id", session_id); - + // Set up chunked content provider for SSE res.set_chunked_content_provider( "text/event-stream", @@ -309,18 +311,15 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& // Lock for inference std::lock_guard lock(state.inference_mutex); auto& session = state.GetOrCreateSession(session_id); - + // Set up runtime config - std::mt19937 gen; - RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); - + RuntimeConfig runtime_config = ParseGenerationConfig(request); + // Tokenize prompt - std::vector tokens = WrapAndTokenize(state.gemma->Tokenizer(), - state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, - session.abs_pos, - prompt); - + std::vector tokens = WrapAndTokenize( + state.gemma->Tokenizer(), state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, session.abs_pos, prompt); + // Stream token callback std::string accumulated_text; auto stream_token = [&](int token, float) { @@ -329,37 +328,38 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& session.abs_pos++; return true; } - + session.abs_pos++; - + // Check for EOS if (state.gemma->Config().IsEOS(token)) { return true; } - + // Decode token std::string token_text; state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); accumulated_text += token_text; - + // Send SSE event using unified formatter json event = CreateAPIResponse(token_text, true); - + std::string sse_data = "data: " + event.dump() + "\n\n"; sink.write(sse_data.data(), sse_data.size()); - + return true; }; - + runtime_config.stream_token = stream_token; - + // Run inference with KV cache TimingInfo timing_info = {.verbosity = 0}; size_t prefix_end = 0; - - state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, - *session.kv_cache, *state.env, timing_info); - + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, + prefix_end, *session.kv_cache, *state.env, + timing_info); + // Send final event using unified formatter 
json final_event = CreateAPIResponse("", false); final_event["usageMetadata"] = { @@ -367,18 +367,18 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& {"candidatesTokenCount", session.abs_pos - tokens.size()}, {"totalTokenCount", session.abs_pos} }; - + std::string final_sse = "data: " + final_event.dump() + "\n\n"; sink.write(final_sse.data(), final_sse.size()); - + // Send done event sink.write("data: [DONE]\n\n", 15); - + // Ensure all data is sent sink.done(); - + return false; // End streaming - + } catch (const std::exception& e) { json error_event = {{"error", {{"message", e.what()}}}}; std::string error_sse = "data: " + error_event.dump() + "\n\n"; @@ -387,11 +387,14 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& } } ); - + } catch (const json::exception& e) { res.status = 400; - res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), - "application/json"); + res.set_content( + json{{"error", + {{"message", std::string("JSON parsing error: ") + e.what()}}}} + .dump(), + "application/json"); } } @@ -410,7 +413,7 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const {"topK", 1} }}} }; - + res.set_content(response.dump(), "application/json"); } @@ -419,40 +422,40 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const // server_running = false; // } -void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, +void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference) { std::cerr << "Loading model..." << std::endl; - + // Initialize model ThreadingContext ctx(threading); MatMulEnv env(ctx); - + ServerState state; state.gemma = std::make_unique(loader, inference, ctx); state.env = &env; state.ctx = &ctx; - + httplib::Server server; - + // Set up routes server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); }); - + // API endpoints server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { HandleListModels(state, inference, req, res); }); - + std::string model_endpoint = "/v1beta/models/" + inference.model; server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { HandleGenerateContentNonStreaming(state, req, res); }); - + server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { HandleGenerateContentStreaming(state, req, res); }); - + // Periodic cleanup of old sessions std::thread cleanup_thread([&state]() { while (server_running) { @@ -460,18 +463,18 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, state.CleanupOldSessions(); } }); - + std::cerr << "Starting API server on port " << inference.port << std::endl; std::cerr << "Model loaded successfully" << std::endl; std::cerr << "Endpoints:" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; std::cerr << " GET /v1beta/models" << std::endl; - + if (!server.listen("0.0.0.0", inference.port)) { std::cerr << "Failed to start server on port " << inference.port << std::endl; } - + cleanup_thread.join(); } @@ -479,11 +482,11 @@ 
void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, int main(int argc, char** argv) { gcpp::InternalInit(); - + gcpp::LoaderArgs loader(argc, argv); gcpp::ThreadingArgs threading(argc, argv); gcpp::InferenceArgs inference(argc, argv); - + if (gcpp::HasHelp(argc, argv)) { std::cerr << "\n\nAPI server for gemma.cpp\n"; std::cout << "========================\n\n"; @@ -501,14 +504,14 @@ int main(int argc, char** argv) { std::cerr << "\n"; return 0; } - + // Arguments are now handled by InferenceArgs - + // // Set up signal handler // signal(SIGINT, gcpp::HandleShutdown); // signal(SIGTERM, gcpp::HandleShutdown); - + gcpp::RunServer(loader, threading, inference); - + return 0; } diff --git a/gemma/attention.cc b/gemma/attention.cc index 21e5019..8afd561 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -155,8 +155,9 @@ void SingleDotSoftmaxWeightedSum( // SoftMax with optional SoftCap yields "probabilities" in att. const size_t att_len = HWY_MIN(last_pos + 1, seq_len); - MaybeLogitsSoftCap(att_cap, att, att_len, p, worker); - Softmax(att, att_len, p, worker, /*temperature=*/1.0f); + const Logits logits(att, att_len); + MaybeLogitsSoftCap(att_cap, logits, p, worker); + Softmax(logits, p, worker, /*temperature=*/1.0f); WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p, worker); diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index a6ebe30..5741d70 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -23,7 +23,6 @@ #include #include -#include "evals/benchmark_helper.h" // InitGenerator #include "gemma/gemma.h" #include "gemma/gemma_args.h" #include "gemma/tokenizer.h" // WrapAndTokenize @@ -135,8 +134,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string, std::stringstream ss; result_buffer.clear(); - InitGenerator(inference_args, gen); - // Ensure we have an active conversation if (!active_conversation || !active_conversation->kv_cache) { LogDebug("Generate called with null active_conversation or kv_cache"); @@ -174,8 +171,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // set up runtime config TimingInfo timing_info = {}; - RuntimeConfig runtime_config = {.gen = &gen, - .stream_token = stream_token, + RuntimeConfig runtime_config = {.stream_token = stream_token, .use_spinning = threading_args.spin}; inference_args.CopyTo(runtime_config); size_t prefix_end = 0; @@ -256,7 +252,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // If not multiturn, or Paligemma (which handles turns differently), // reset the *active* conversation's position. 
active_conversation->abs_pos = 0; - InitGenerator(inference_args, gen); } else { // Multi-turn Gemma: Rewind position in the active conversation // The last token was either EOS, then it should be ignored because it is diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h index 859a644..00648fc 100644 --- a/gemma/bindings/context.h +++ b/gemma/bindings/context.h @@ -17,7 +17,6 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_ #include // For std::shared_ptr, std::make_shared -#include #include #include #include @@ -107,10 +106,6 @@ class GemmaContext { // Set deterministic flag void SetDeterministic(bool value) { inference_args.deterministic = value; - // Reset the random number generator for deterministic generation - if (value) { - gen.seed(0x87654321); - } LogDebug("Setting deterministic flag to configured value"); } @@ -289,9 +284,6 @@ class GemmaContext { // Model itself (don't move this, needs to be below the args above) Gemma model; - // Random generator (remains global for the context) - std::mt19937 gen; - // Static members for logging static GemmaLogCallback s_log_callback; static void* s_log_user_data; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a7e73ca..0177c92 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -440,8 +440,7 @@ static void SampleAndStream( // TODO: parallelize non_eos.Foreach([&](size_t qi) { - float* HWY_RESTRICT logits = activations.logits.Row(qi); - const TokenAndProb tp = sample_token(logits, config.vocab_size); + const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi)); // We streamed all prefill tokens, but pos is still one behind because we // started generation at pos = prompt.size() - 1. We want the pos argument @@ -453,7 +452,8 @@ static void SampleAndStream( } static HWY_INLINE SampleFunc -ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) { +ChooseSampleFunc(const RuntimeConfig& runtime_config, + const AesCtrEngine& engine, ThreadingContext& ctx) { // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; @@ -462,27 +462,28 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) { // Fast path for top-1 with no accept_token. if (runtime_config.top_k == 1 && !runtime_config.accept_token) { - return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { + return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb { PROFILER_ZONE3(ctx.profiler, worker, zone); - return Top1OfSoftmax(logits, vocab_size); + return Top1OfSoftmax(logits); }; } // General case: Softmax with top-k sampling. - return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { + return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample general"); + RngStream gen(engine, qi); return FusedSoftmaxAndSampleTopK( - logits, runtime_config.top_k, vocab_size, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token, ctx.profiler, - worker); + logits, runtime_config.top_k, gen, runtime_config.temperature, + runtime_config.accept_token, ctx.profiler, worker); }; } // Decode: generates one continuation token for each query in `qbatch`. 
static void GenerateT(const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, Activations& activations, - QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { + const AesCtrEngine& engine, const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, MatMulEnv& env, + TimingInfo& timing_info) { // Griffin assumes that the recurrent block cache is zero-initialized. for (size_t qi = 0; qi < qbatch.Size(); ++qi) { if (qbatch.MutablePos(qi) == 0) { @@ -554,7 +555,8 @@ static void GenerateT(const ModelConfig& config, max_gen_steps = seq_len - max_prompt_size; } - const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx); + const SampleFunc sample_token = + ChooseSampleFunc(runtime_config, engine, env.ctx); timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { @@ -568,15 +570,16 @@ static void GenerateT(const ModelConfig& config, void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, KVCache& kv_cache, - MatMulEnv& env, TimingInfo& timing_info) { + const AesCtrEngine& engine, const WeightsPtrs& weights, + KVCache& kv_cache, MatMulEnv& env, + TimingInfo& timing_info) { Activations activations(config, runtime_config.prefill_tbatch_size, kv_cache.SeqLen(), env.ctx, env.row_ptrs); AllQueries all_queries(prompt, pos, prefix_end, hwy::Span(&kv_cache, 1)); QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries); - GenerateT(config, runtime_config, weights, activations, qbatch, env, + GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, timing_info); } @@ -584,8 +587,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, // queries, and calls `GenerateT` on each batch. void GenerateBatchT(const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, AllQueries& all_queries, - MatMulEnv& env, TimingInfo& timing_info) { + const AesCtrEngine& engine, const WeightsPtrs& weights, + AllQueries& all_queries, MatMulEnv& env, + TimingInfo& timing_info) { const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); Activations activations(config, max_batch_size, @@ -596,7 +600,7 @@ void GenerateBatchT(const ModelConfig& config, start += runtime_config.decode_qbatch_size) { QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); // Generate a batch of one token for each of `qbatch.Size()` queries. - GenerateT(config, runtime_config, weights, activations, qbatch, env, + GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, timing_info); } } @@ -637,7 +641,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, model_(reader_, loader.tokenizer, loader.wrapping), weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), - inference_(inference) { + inference_(inference), + aes_ctr_engine_(inference.deterministic) { // Negligible CPU time in the ctor body (except ReadFromBlobs). 
   weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader,
                                              inference, mat_owners_, ctx);
@@ -661,9 +666,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                      TimingInfo& timing_info) const {
   env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
-                                        model_.Config(), runtime_config,
-                                        weights_, kv_cache, env, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(
+      prompt, pos, prefix_end, model_.Config(), runtime_config, aes_ctr_engine_,
+      weights_, kv_cache, env, timing_info);
 
   env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
@@ -674,7 +679,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
   env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
   HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
-                                       weights_, all_queries, env, timing_info);
+                                       aes_ctr_engine_, weights_, all_queries,
+                                       env, timing_info);
   env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 0f9aae2..2f06ab8 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -278,6 +278,7 @@ class Gemma {
   WeightsPtrs::Mode weight_read_mode_;
   GemmaChatTemplate chat_template_;
   InferenceArgs inference_;
+  AesCtrEngine aes_ctr_engine_;
 };
 
 }  // namespace gcpp
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 2a49349..59e3a6c 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -22,7 +22,6 @@
 #include
 #include
-#include <random>
 #include
 
 #include "io/io.h"  // Path
@@ -90,10 +89,10 @@ using BatchStreamFunc = std::function;
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
 using AcceptFunc = std::function<bool(int, float)>;
-// If not empty, SampleFunc is called with the logits for the next token, which
-// it may modify/overwrite, and its return value is the next generated token
-// together with its probability.
-using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
+// If not empty, SampleFunc is called with the query_idx and logits for the
+// next token, which it may modify/overwrite. It returns the next generated
+// token together with its probability.
+using SampleFunc = std::function<TokenAndProb(size_t, Logits)>;
 // If not empty, LayersOutputFunc is called for layer outputs, specified with:
 // - index of query within containing batch (if any); zero otherwise.
 // - position in the tokens sequence
@@ -136,8 +135,7 @@ struct RuntimeConfig {
 
   // Sampling-related parameters.
   float temperature;  // Temperature for sampling.
-  size_t top_k = 1;   // Top-k for sampling.
-  std::mt19937* gen;  // Random number generator used for sampling.
+  size_t top_k = 1;  // Top-k for sampling.
 
   int verbosity;  // Controls verbosity of printed messages.
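// Editor's note (not part of the patch): a minimal sketch of a custom sampler
// under the new SampleFunc signature above. It now receives the query index
// and a Logits span instead of a raw pointer plus vocab_size.
// `my_runtime_config` is a hypothetical RuntimeConfig instance; a real
// sampler would normalize (e.g. via Softmax) before reporting a probability.
//
//   my_runtime_config.sample_func =
//       [](size_t /*qi*/, gcpp::Logits logits) -> gcpp::TokenAndProb {
//     size_t best = 0;  // greedy: pick the largest logit
//     for (size_t i = 1; i < logits.size(); ++i) {
//       if (logits[i] > logits[best]) best = i;
//     }
//     return gcpp::TokenAndProb{.token = static_cast<int>(best),
//                               .prob = logits[best]};
//   };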
diff --git a/gemma/run.cc b/gemma/run.cc index 3915bf8..7e2059f 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -18,7 +18,6 @@ #include #include -#include #include #include #include @@ -98,9 +97,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, size_t prompt_size = 0; const ModelConfig& config = gemma.Config(); - std::mt19937 gen; - InitGenerator(inference, gen); - const bool have_image = !inference.image_file.path.empty(); Image image; const size_t pool_dim = config.vit_config.pool_dim; @@ -117,8 +113,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, HWY_ASSERT(image.ReadPPM(inference.image_file.path)); const size_t image_size = config.vit_config.image_size; image.Resize(image_size, image_size); - RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = inference.verbosity, + RuntimeConfig runtime_config = {.verbosity = inference.verbosity, .use_spinning = threading.spin}; double image_tokens_start = hwy::platform::Now(); gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image, @@ -188,8 +183,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // Set up runtime config. TimingInfo timing_info = {.verbosity = inference.verbosity}; - RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = inference.verbosity, + RuntimeConfig runtime_config = {.verbosity = inference.verbosity, .batch_stream_token = batch_stream_token, .use_spinning = threading.spin}; inference.CopyTo(runtime_config); @@ -239,7 +233,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // Prepare for the next turn. Works only for PaliGemma. if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) { abs_pos = 0; // Start a new turn at position 0. - InitGenerator(inference, gen); } else { // The last token was either EOS, then it should be ignored because it is // never part of the dialog, see Table 5 in the Gemma-2 paper: diff --git a/gemma/vit.cc b/gemma/vit.cc index 96d6d7f..1910091 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -110,8 +110,7 @@ class VitAttention { CallMatMul(Q, K, nullptr, env_, C); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { - float* HWY_RESTRICT c = C.Row(task); - Softmax(c, C.Cols(), env_.ctx.profiler, worker); + Softmax(C.RowSpan(task), env_.ctx.profiler, worker); }); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { @@ -154,7 +153,7 @@ class VitAttention { head_att[i] = Dot(q, k, qkv_dim); // score = q.k } // SoftMax yields "probabilities" in head_att. - Softmax(head_att, seq_len, env_.ctx.profiler, worker); + Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker); // Compute weighted sum of v into att_out. float* HWY_RESTRICT att_out = activations_.attention.att_out.Row(token) + head * qkv_dim; diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 8afb220..3cb565c 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -812,7 +812,7 @@ class DotStats { // Forward relative error, lower is better. void CheckRel() const { - ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 4E-3); + ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 7E-3); ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 1.23f); // Compensated and Double are very accurate. @@ -822,22 +822,22 @@ class DotStats { ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f); // Naive and OnlyTwoProd are considerably higher, but not huge. 
- ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2); + ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 3.5E-1); ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(), - 0.072); + 7.5E-2); // Kahan (FastTwoSum) is decent: - ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3); + ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 1E-2); ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f); // TwoProducts and TwoSums are a bit better. ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(), - 3E-3); - ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f); + 1.1E-2); + ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 1.0f); ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(), - 2.6E-3); + 1.1E-2); - ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2); + ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 5.2E-2); // Extremely high error on aarch64. ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f); } @@ -857,7 +857,7 @@ class DotStats { ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f); // But TwoProducts/TwoSums help a bit. - ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 0.19f); + ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 1.0f); ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_rels[kAddTwoSum].Max(), 0.34f); // Extremely high error on aarch64. @@ -893,7 +893,7 @@ class DotStats { }; // Returns normalized value in [-1, 1). -float RandomFloat(std::mt19937& rng) { +float RandomFloat(RngStream& rng) { const uint32_t exp = hwy::BitCastScalar(1.0f); const uint32_t mantissa_mask = hwy::MantissaMask(); const uint32_t representation = exp | (rng() & mantissa_mask); @@ -908,7 +908,7 @@ float RandomFloat(std::mt19937& rng) { // error from the Dot algorithms, not the compression. template void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw, - std::mt19937& rng, + RngStream& rng, const PackedSpan& packed, CompressWorkingSet& work) { std::uniform_int_distribution e_dist(0, 6); @@ -934,7 +934,7 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw, // Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf. 
template double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v, - std::mt19937& rng) { + RngStream& rng) { PROFILER_FUNC; const size_t half = HWY_MAX(1, num / 2); // generate at least one random HWY_DASSERT(half != 0); @@ -1002,8 +1002,8 @@ struct TestShortDotsT { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); CompressWorkingSet work; - std::mt19937 rng; - rng.seed(12345); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); hwy::Stats s_l1[kVariants]; @@ -1108,9 +1108,10 @@ void TestAllDot() { { // ensure no profiler zones are active const hn::ScalableTag df; - std::mt19937 rngs[kMaxWorkers]; + AesCtrEngine engine(/*deterministic=*/true); + RngStream rngs[kMaxWorkers]; for (size_t i = 0; i < kMaxWorkers; ++i) { - rngs[i].seed(12345 + 65537 * i); + rngs[i] = RngStream(engine, i); } constexpr size_t kReps = hn::AdjustedReps(40); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 19a39aa..18ee40f 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -29,7 +29,7 @@ #include "ops/matmul.h" #include "util/allocator.h" -#include "util/basics.h" // TokenAndProb +#include "util/basics.h" // TokenAndProb, RngStream #include "util/mat.h" #include "util/threading_context.h" #include "hwy/base.h" @@ -614,12 +614,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( } // See below for a specialized version for top-1 sampling. -static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, - hwy::Profiler& p, const size_t worker, +static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, + const size_t worker, float temperature = 1.0f) { static const auto zone = p.AddZone("Ops.Softmax"); PROFILER_ZONE3(p, worker, zone); - HWY_DASSERT(size != 0); + HWY_DASSERT(logits.size() != 0); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; @@ -629,24 +629,25 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const V vmin = hn::Set(d, hwy::LowestValue()); V vmax = vmin; V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly - hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR { - *pmax = hn::Max(*pmax, value); - }); + hn::Foreach(d, logits.data(), logits.size(), vmin, + [pmax](const auto d, const V value) + HWY_ATTR { *pmax = hn::Max(*pmax, value); }); vmax = hn::MaxOfLanes(d, vmax); // Subtract max (avoid precision loss for large exponents) and exponentiate. - hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR { - if constexpr (HWY_TARGET & HWY_ALL_SVE) { - // Temporary workaround for buggy SVE codegen: avoid inlined Exp(). - return hn::CallExp(d, hn::Sub(value, *pmax)); - } else { - return hn::Exp(d, hn::Sub(value, *pmax)); - } - }); + hn::Transform(d, logits.data(), logits.size(), + [pmax](const auto d, const V value) HWY_ATTR { + if constexpr (HWY_TARGET & HWY_ALL_SVE) { + // Workaround for buggy SVE codegen: avoid inlined Exp(). + return hn::CallExp(d, hn::Sub(value, *pmax)); + } else { + return hn::Exp(d, hn::Sub(value, *pmax)); + } + }); if (temperature != 1.0f) { const float temperature_inv = 1.0f / temperature; - hn::Transform(d, x, size, + hn::Transform(d, logits.data(), logits.size(), [temperature_inv](const auto d, const V value) HWY_ATTR { return hn::Mul(value, hn::Set(d, temperature_inv)); }); @@ -656,10 +657,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, // not make a huge difference. 
It halves the standard deviation of the sum of // the normalized probabilities from 1E-7 to 5E-8, but actually also changes // the generated text after a few hundred tokens. - const float sum_exp = Sum(d, x, size); + const float sum_exp = Sum(d, logits.data(), logits.size()); // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; - MulByConst(mul, x, size, p, worker); + MulByConst(mul, logits.data(), logits.size(), p, worker); } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / @@ -669,8 +670,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, // which already knows the max value which top-1 sampling would again seek. // Returns the argmax and x[argmax]. -static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, - const size_t num) { +static HWY_INLINE TokenAndProb ArgmaxAndMax(Logits logits) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; using V = hn::Vec; @@ -680,16 +680,16 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, using TI = hn::TFromD; using VI = hn::Vec; const size_t N = hn::Lanes(d); - HWY_ASSERT(num % (2 * N) == 0); + HWY_ASSERT(logits.size() % (2 * N) == 0); V max0 = hn::Set(d, hwy::LowestValue()); V max1 = max0; VI argmax0 = hn::Zero(di); VI argmax1 = argmax0; - for (size_t i = 0; i < num; i += 2 * N) { - const V v0 = hn::LoadU(d, x + i); - const V v1 = hn::LoadU(d, x + i + N); + for (size_t i = 0; i < logits.size(); i += 2 * N) { + const V v0 = hn::LoadU(d, &logits[i]); + const V v1 = hn::LoadU(d, &logits[i + N]); const VI vi0 = hn::Iota(di, static_cast(i)); const VI vi1 = hn::Iota(di, static_cast(i + N)); const M gt0 = hn::Gt(v0, max0); @@ -714,43 +714,43 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, return TokenAndProb{.token = argmax, .prob = hn::GetLane(max)}; } -// Returns argmax of softmax and its probability. This overwrites `x`, but not -// with normalized probabilities. Only equivalent to `Softmax` + `sample_func` -// if `kTopK` == 1. This is worthwhile because `num` is typically `kVocabSize` -// == 256K, and this avoids writing and then scanning again for the max. -// However, this is not enough to make parallelization worthwhile. -static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, - const size_t num) { +// Returns argmax of softmax and its probability. This overwrites `logits`, but +// not with normalized probabilities. Only equivalent to `Softmax` + +// `sample_func` if `kTopK` == 1. This is worthwhile because `logits.size()` is +// typically `kVocabSize == 256K`, and this avoids writing and then scanning +// again for the max. +static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) { namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag d; using V = hn::Vec; - const TokenAndProb argmax = ArgmaxAndMax(x, num); + const TokenAndProb argmax = ArgmaxAndMax(logits); // Subtract max (avoid precision loss for large exponents) and exponentiate. const V max = hn::Set(d, argmax.prob); const V* pmax = &max; - hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR { - if constexpr (HWY_TARGET & HWY_ALL_SVE) { - // Temporary workaround for buggy SVE codegen: avoid inlined Exp(). 
- return hn::CallExp(d, hn::Sub(value, *pmax)); - } else { - return hn::Exp(d, hn::Sub(value, *pmax)); - } - }); + hn::Transform(d, logits.data(), logits.size(), + [pmax](const auto d, const V value) HWY_ATTR { + if constexpr (HWY_TARGET & HWY_ALL_SVE) { + // Temporary workaround for buggy SVE codegen: avoid inlined + // Exp(). + return hn::CallExp(d, hn::Sub(value, *pmax)); + } else { + return hn::Exp(d, hn::Sub(value, *pmax)); + } + }); // Normalize to a single probability. The exact sum seems like it should not // make a huge difference. It halves the standard deviation of the sum of the // normalized probabilities from 1E-7 to 5E-8, but actually also changes the // generated text after a few hundred tokens. - const float sum_exp = Sum(d, x, num); - const float prob = x[argmax.token] / sum_exp; + const float sum_exp = Sum(d, logits.data(), logits.size()); + const float prob = logits[argmax.token] / sum_exp; return TokenAndProb{.token = argmax.token, .prob = prob}; } -static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - const size_t size, hwy::Profiler& p, - const size_t worker) { +static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits, + hwy::Profiler& p, const size_t worker) { static const auto zone = p.AddZone("Ops.LogitsSoftCap"); PROFILER_ZONE3(p, worker, zone); @@ -763,18 +763,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, const VF* HWY_RESTRICT pcap = &vcap; const VF* HWY_RESTRICT pinv_cap = &vinv_cap; - DecompressAndCompressInplace( - DF(), x, size, [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF { - return hn::Mul(*pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap))); - }); + DecompressAndCompressInplace(DF(), logits.data(), logits.size(), + [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF { + return hn::Mul( + *pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap))); + }); } // Calls LogitsSoftCap if cap != 0.0f. 
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( - const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, - const size_t worker) { + const float cap, Logits logits, hwy::Profiler& p, const size_t worker) { if (cap != 0.0f) { - LogitsSoftCap(cap, x, size, p, worker); + LogitsSoftCap(cap, logits, p, worker); } } @@ -785,20 +785,18 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { if (non_eos.Get(task)) { - LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, - worker); + LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker); } }); } -static HWY_NOINLINE HWY_MAYBE_UNUSED size_t -SampleArgmax(const float* probabilities, size_t vocab_size) { +static HWY_NOINLINE HWY_MAYBE_UNUSED size_t SampleArgmax(Logits logits) { size_t max_index = 0; - float max_prob = probabilities[0]; - for (size_t i = 1; i < vocab_size; ++i) { - if (probabilities[i] > max_prob) { + float max_prob = logits[0]; + for (size_t i = 1; i < logits.size(); ++i) { + if (logits[i] > max_prob) { max_index = i; - max_prob = probabilities[i]; + max_prob = logits[i]; } } return max_index; @@ -828,16 +826,15 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution create_distribution( template HWY_NOINLINE HWY_MAYBE_UNUSED std::vector TopK( - const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k, - TAcceptToken& accept_token) { + Logits logits, size_t k, TAcceptToken& accept_token) { HWY_ASSERT(k != 0); - HWY_ASSERT(k <= vocab_size); + HWY_ASSERT(k <= logits.size()); std::vector packed_token_probs; - for (int32_t i = 0; i < static_cast(vocab_size); ++i) { - if (accept_token && !accept_token(i, probabilities[i])) { + for (int32_t i = 0; i < static_cast(logits.size()); ++i) { + if (accept_token && !accept_token(i, logits[i])) { continue; } - packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i])); + packed_token_probs.push_back(PackTokenAndProb(i, logits[i])); } hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k, @@ -853,11 +850,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector TopK( } template -HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( - const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size, - std::mt19937& gen, float temperature, TAcceptToken& accept_token) { - std::vector token_probs = - TopK(probabilities, vocab_size, k, accept_token); +HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k, + RngStream& gen, float temperature, + TAcceptToken& accept_token) { + std::vector token_probs = TopK(logits, k, accept_token); std::vector topk_indices(k); std::vector topk_probs(k); for (size_t i = 0; i < k; ++i) { @@ -869,14 +865,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( template HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( - const float* HWY_RESTRICT logits, size_t k, size_t vocab_size, - std::mt19937& gen, float temperature, TAcceptToken& accept_token, - hwy::Profiler& p, size_t worker) { + Logits logits, size_t k, RngStream& gen, float temperature, + TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) { // Softmax and sample top-K is equivalent to taking the top-K logits and // sampling from the softmax of the top-K logits. The latter is faster as it // avoids computing the softmax of all logits. 
- std::vector token_logits = - TopK(logits, vocab_size, k, accept_token); + std::vector token_logits = TopK(logits, k, accept_token); std::vector topk_indices(k); std::vector topk_logits(k); for (size_t i = 0; i < token_logits.size(); ++i) { @@ -884,8 +878,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( topk_logits[i] = token_logits[i].prob; } - size_t mask = token_logits.size(); - Softmax(topk_logits.data(), mask, p, worker, temperature); + const size_t mask = token_logits.size(); + Softmax(Logits(topk_logits.data(), mask), p, worker, temperature); auto distribution = std::discrete_distribution( std::begin(topk_logits), std::begin(topk_logits) + mask); int topk_sampled_index = distribution(gen); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 7e63482..213fdd0 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -57,6 +57,12 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +static RngStream MakeRng() { + static AesCtrEngine engine(/*deterministic=*/true); + static uint64_t stream = 0; + return RngStream(engine, ++stream); +} + template struct ForeachCountAndMisalign { template @@ -304,7 +310,7 @@ class TestSoftmax { } SimpleSoftmax(e, count); - Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0); + Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0); T sum = 0.0f; for (size_t i = 0; i < count; ++i) { @@ -438,10 +444,9 @@ void TestRopeAndMulBy() { const size_t dim_qkv = config.layer_configs[0].qkv_dim; MatStorageT x("x", dim_qkv, ctx.allocator); - std::mt19937 gen; - gen.seed(0x12345678); + RngStream rng = MakeRng(); std::normal_distribution r{0.0, 5.0}; - auto random_float = [&r, &gen] { return r(gen); }; + auto random_float = [&r, &rng] { return r(rng); }; for (size_t i = 0; i < dim_qkv; ++i) { x.Row(0)[i] = random_float(); @@ -704,38 +709,34 @@ void TestSampleTopK() { hwy::Profiler& p = hwy::Profiler::Get(); const size_t worker = 0; const size_t kSize = 52; - std::vector logits(kSize); + std::vector logits_vec(kSize); + Logits logits(logits_vec.data(), kSize); // Create a vector going from -100 to -100+51=49 and take Softmax. std::iota(logits.begin(), logits.end(), -100.0f); - Softmax(logits.data(), kSize, p, worker); - std::mt19937 gen; - gen.seed(0x12345678); + Softmax(logits, p, worker); + RngStream rng = MakeRng(); float temperature = 1.0f; // SampleTopK<1> should return the argmax. std::function accept_token; - int sample = - SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token); + int sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token); EXPECT_EQ(sample, 51); // Last is largest. // Only accept even tokens, expect the last (largest) even index. accept_token = [](int i, float) { return i % 2 == 0; }; - sample = - SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token); + sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token); EXPECT_EQ(sample, 50); // Last even index. // Reset the logits to a positive, increasing sequence and take Softmax. std::iota(logits.begin(), logits.end(), 1.0f); - Softmax(logits.data(), kSize, p, worker); + Softmax(logits, p, worker); // Sample from the top 3, expect one of the top 3 even indices. 
for (int i = 0; i < 100; ++i) { - sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature, - accept_token); + sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token); EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46); } // Now set the temperature to 0.0f, which should always return the argmax, // even for k=3. temperature = 0.0f; for (int i = 0; i < 100; ++i) { - sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature, - accept_token); + sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token); EXPECT_EQ(sample, 50); } } diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc index 2c798b9..449ee00 100644 --- a/paligemma/paligemma_helper.cc +++ b/paligemma/paligemma_helper.cc @@ -27,42 +27,38 @@ void PaliGemmaHelper::InitVit(const std::string& path) { HWY_ASSERT(image.ReadPPM(path)); const size_t image_size = config.vit_config.image_size; image.Resize(image_size, image_size); - RuntimeConfig runtime_config = {.gen = &env_->MutableGen(), - .verbosity = 0}; + RuntimeConfig runtime_config = {.verbosity = 0}; gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(), image, *image_tokens_, env_->MutableEnv()); } std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const { const Gemma& model = *(env_->GetGemma()); - env_->MutableGen().seed(0x12345678); - std::string response; - auto stream_token = [&](int token, float) { - std::string token_text; - HWY_ASSERT( - model.Tokenizer().Decode(std::vector{token}, &token_text)); - response += token_text; - return true; - }; + std::string response; + auto stream_token = [&](int token, float) { + std::string token_text; + HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); + response += token_text; + return true; + }; - std::string mutable_prompt = prompt_text; - std::vector tokens = env_->WrapAndTokenize(mutable_prompt); - tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); + std::string mutable_prompt = prompt_text; + std::vector tokens = env_->WrapAndTokenize(mutable_prompt); + tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); - RuntimeConfig runtime_config = {.max_generated_tokens = 512, - // PrefixLM sees/attends to all tokens. - .prefill_tbatch_size = tokens.size(), - .gen = &env_->MutableGen(), - .verbosity = 0, - .stream_token = stream_token, - .image_tokens = image_tokens_.get()}; + RuntimeConfig runtime_config = {.max_generated_tokens = 512, + // PrefixLM sees/attends to all tokens. + .prefill_tbatch_size = tokens.size(), + .verbosity = 0, + .stream_token = stream_token, + .image_tokens = image_tokens_.get()}; - const size_t prefix_end = tokens.size(); - TimingInfo timing_info = {.verbosity = 0}; - model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end, - env_->MutableKVCache(), env_->MutableEnv(), timing_info); - return response; + const size_t prefix_end = tokens.size(); + TimingInfo timing_info = {.verbosity = 0}; + model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end, + env_->MutableKVCache(), env_->MutableEnv(), timing_info); + return response; } } // namespace gcpp diff --git a/python/gemma_py.cc b/python/gemma_py.cc index 9af07b3..2e39f68 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -53,9 +53,8 @@ class GemmaModel { // Generates a single example, given a prompt and a callback to stream the // generated tokens. 
void GenerateEx(std::string prompt, gcpp::StreamFunc stream, - size_t max_generated_tokens, float temperature, float seed, - gcpp::AcceptFunc accept, bool skip_prompt) { - env_.MutableGen().seed(seed); + size_t max_generated_tokens, float temperature, + float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) { std::vector prompt_tokens = env_.WrapAndTokenize(prompt); gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; @@ -77,7 +76,7 @@ class GemmaModel { // Generates a single example, given a prompt, and returns the result. std::string Generate(std::string prompt, size_t max_generated_tokens, - float temperature, float seed, + float temperature, float /*seed*/, const std::vector& accept, const std::vector& end) { std::set end_token_set{}; @@ -124,7 +123,6 @@ class GemmaModel { } }; - env_.MutableGen().seed(seed); gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; config.temperature = temperature; @@ -144,14 +142,13 @@ class GemmaModel { // results. std::vector GenerateBatch(const std::vector& inputs, size_t max_generated_tokens, - float temperature, float seed, + float temperature, float /*seed*/, size_t top_k) { gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; config.temperature = temperature; config.top_k = top_k; config.verbosity = 0; - env_.MutableGen().seed(seed); std::vector outputs = env_.BatchQueryModel(inputs); std::vector result; @@ -187,8 +184,7 @@ class GemmaModel { "image_tokens", gcpp::Extents2D(config.vit_config.seq_len, config.model_dim), env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd)); - gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(), - .verbosity = 0}; + gcpp::RuntimeConfig runtime_config = {.verbosity = 0}; gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(), c_image, *image_tokens_, env_.MutableEnv()); } @@ -197,10 +193,9 @@ class GemmaModel { // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string. std::pair> GenerateWithImage( std::string prompt, size_t max_generated_tokens, float temperature, - float seed, gcpp::AcceptFunc accept, std::vector prompt_tokens) { + float /*seed*/, gcpp::AcceptFunc accept, std::vector prompt_tokens) { if (!image_tokens_) throw std::invalid_argument("No image set."); const gcpp::Gemma& model = *env_.GetGemma(); - env_.MutableGen().seed(seed); gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; config.temperature = temperature; @@ -273,6 +268,7 @@ PYBIND11_MODULE(gemma, mod) { }), py::arg("tokenizer_path"), py::arg("weights_path"), py::arg("max_threads") = 0) + // seed arguments are ignored. .def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"), py::arg("stream"), py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9, py::arg("seed") = 123456789, diff --git a/util/basics.cc b/util/basics.cc index 4261510..d9fbc27 100644 --- a/util/basics.cc +++ b/util/basics.cc @@ -24,7 +24,7 @@ namespace gcpp { -RNG::RNG(bool deterministic) { +AesCtrEngine::AesCtrEngine(bool deterministic) { // Pi-based nothing up my sleeve numbers from Randen. 
  key_[0] = 0x243F6A8885A308D3ull;
  key_[1] = 0x13198A2E03707344ull;
@@ -54,9 +54,10 @@ static V Load(const uint64_t* ptr) {
   return hn::Load(D(), reinterpret_cast<const uint8_t*>(ptr));
 }
-RNG::result_type RNG::operator()() {
-  V state = Load(counter_);
-  counter_[0]++;
+uint64_t AesCtrEngine::operator()(uint64_t stream, uint64_t counter) const {
+  const hn::Repartition<uint64_t, D> d64;
+
+  V state = hn::BitCast(D(), hn::Dup128VecFromValues(d64, counter, stream));
   state = hn::Xor(state, Load(key_));  // initial whitening
   static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t));
@@ -68,7 +69,6 @@
   state = hn::AESRound(state, Load(key_ + 10));
   // Return lower 64 bits of the u8 vector.
-  const hn::Repartition<uint64_t, D> d64;
   return hn::GetLane(hn::BitCast(d64, state));
 }
diff --git a/util/basics.h b/util/basics.h
index 2429c72..7b1c7d3 100644
--- a/util/basics.h
+++ b/util/basics.h
@@ -20,7 +20,7 @@
 #include <stddef.h>
 #include <stdint.h>
-#include "hwy/aligned_allocator.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
 // IWYU pragma: end_exports
@@ -120,39 +120,60 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end,
                                         size_t max_size) {
   return IndexRange(begin, HWY_MIN(begin + max_size, end));
 }
+using Logits = hwy::Span<float>;  // size() is vocab_size.
+
 // Non-cryptographic 64-bit pseudo-random number generator. Supports random or
-// deterministic seeding. Conforms to C++ `UniformRandomBitGenerator`.
+// deterministic seeding.
 //
 // Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This
 // is useful for parallel sampling. Each thread can generate the stream for a
 // particular task, without caring about prior/subsequent generations.
-class alignas(16) RNG {
+class alignas(16) AesCtrEngine {
   // "Large-scale randomness study of security margins for 100+ cryptographic
   // functions": at least four.
   // "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant.
   static constexpr size_t kRounds = 5;
 public:
-  explicit RNG(bool deterministic);
+  // If `deterministic` is true, uses a fixed seed; otherwise, attempts to
+  // grab entropy from the OS.
+  explicit AesCtrEngine(bool deterministic);
-  void SetStream(uint64_t stream) {
-    counter_[1] = stream;
-    counter_[0] = 0;
-  }
+  // Pure and thread-safe; typically called via `RngStream`, which increments
+  // `counter`. Throughput is about 100M/s on 3 GHz Skylake. It could be
+  // increased 4x via unrolling by the AES latency (4-7 cycles), but because
+  // users generally request one value at a time, that would require buffering,
+  // which is not worth the complexity in this application.
+  uint64_t operator()(uint64_t stream, uint64_t counter) const;
+
+ private:
+  uint64_t key_[2 * (1 + kRounds)];
+};
+
+// Flyweight per-thread adapter that maintains the counter. Conforms to C++
+// `UniformRandomBitGenerator`.
+class RngStream {
+ public:
+  RngStream() = default;  // Allow C arrays with subsequent initialization.
+
+  // Binds to an engine, which holds the seed and must outlive this object.
+  // Sets the stream; any other `RngStream` with the same `counter_rng` and
+  // `stream` will return the same sequence. This is typically the task ID, so
+  // that threads can independently generate values for each task.
+  RngStream(const AesCtrEngine& counter_rng, uint64_t stream)
+      : engine_(&counter_rng), stream_(stream), counter_(0) {}
   using result_type = uint64_t;
   static constexpr result_type min() { return 0; }
   static constexpr result_type max() { return ~result_type{0}; }
-
-  // About 100M/s on 3 GHz Skylake.
Throughput could be increased 4x via - // unrolling by the AES latency (4-7 cycles). `std::discrete_distribution` - // makes individual calls to the generator, which would require buffering, - // which is not worth the complexity. - result_type operator()(); + result_type operator()() { return (*engine_)(stream_, counter_++); } private: - uint64_t counter_[2] = {}; - uint64_t key_[2 * (1 + kRounds)]; + const AesCtrEngine* engine_ = nullptr; + uint64_t stream_ = 0; // immutable after ctor + uint64_t counter_ = 0; + // Prevent false sharing if used by multiple threads. + HWY_MAYBE_UNUSED uint8_t padding_[HWY_ALIGNMENT - 16 - sizeof(engine_)]; }; } // namespace gcpp diff --git a/util/basics_test.cc b/util/basics_test.cc index 169d051..a1d805b 100644 --- a/util/basics_test.cc +++ b/util/basics_test.cc @@ -25,9 +25,11 @@ namespace gcpp { namespace { -TEST(BasicsTest, IsDeterministic) { - RNG rng1(/*deterministic=*/true); - RNG rng2(/*deterministic=*/true); +TEST(BasicsTest, EngineIsDeterministic) { + const AesCtrEngine engine1(/*deterministic=*/true); + const AesCtrEngine engine2(/*deterministic=*/true); + RngStream rng1(engine1, 0); + RngStream rng2(engine2, 0); // Remember for later testing after resetting the stream. const uint64_t r0 = rng1(); const uint64_t r1 = rng1(); @@ -42,15 +44,17 @@ TEST(BasicsTest, IsDeterministic) { HWY_ASSERT(rng1() == rng2()); } - // Reset counter, ensure it matches the default-constructed RNG. - rng1.SetStream(0); + // Reset counter, ensure it matches the prior sequence. + rng1 = RngStream(engine1, 0); HWY_ASSERT(r0 == rng1()); HWY_ASSERT(r1 == rng1()); } -TEST(BasicsTest, IsSeeded) { - RNG rng1(/*deterministic=*/true); - RNG rng2(/*deterministic=*/false); +TEST(BasicsTest, EngineIsSeeded) { + AesCtrEngine engine1(/*deterministic=*/true); + AesCtrEngine engine2(/*deterministic=*/false); + RngStream rng1(engine1, 0); + RngStream rng2(engine2, 0); // It would be very unlucky to have even one 64-bit value match, and two are // extremely unlikely. const uint64_t a0 = rng1(); @@ -60,9 +64,27 @@ TEST(BasicsTest, IsSeeded) { HWY_ASSERT(a0 != b0 || a1 != b1); } +TEST(BasicsTest, StreamsDiffer) { + AesCtrEngine engine(/*deterministic=*/true); + // Compare random streams for more coverage than just the first N streams. + RngStream rng_for_stream(engine, 0); + for (size_t i = 0; i < 1000; ++i) { + RngStream rng1(engine, rng_for_stream()); + RngStream rng2(engine, rng_for_stream()); + // It would be very unlucky to have even one 64-bit value match, and two are + // extremely unlikely. + const uint64_t a0 = rng1(); + const uint64_t a1 = rng1(); + const uint64_t b0 = rng2(); + const uint64_t b1 = rng2(); + HWY_ASSERT(a0 != b0 || a1 != b1); + } +} + // If not close to 50% 1-bits, the RNG is quite broken. TEST(BasicsTest, BitDistribution) { - RNG rng(/*deterministic=*/true); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); constexpr size_t kU64 = 2 * 1000 * 1000; const hwy::Timestamp t0; uint64_t one_bits = 0; @@ -78,7 +100,8 @@ TEST(BasicsTest, BitDistribution) { } TEST(BasicsTest, ChiSquared) { - RNG rng(/*deterministic=*/true); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); constexpr size_t kU64 = 1 * 1000 * 1000; // Test each byte separately. 
diff --git a/util/mat.h b/util/mat.h
index 9d838e2..c084e81 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -301,6 +301,13 @@ class MatPtrT : public MatPtr {
     return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
   }

+  hwy::Span<T> RowSpan(size_t row) {
+    return hwy::Span<T>(Row(row), Cols());
+  }
+  hwy::Span<const T> RowSpan(size_t row) const {
+    return hwy::Span<const T>(Row(row), Cols());
+  }
+
   PackedSpan PaddedSpan() const {
     return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
   }
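Note (not part of the patch): a minimal sketch of how the new RNG types compose, assuming only the AesCtrEngine/RngStream API declared in util/basics.h above. The function and variable names below are made up for illustration. One engine holds the key/seed and is safe to share across threads; each task binds its own RngStream, so the values a task draws are reproducible regardless of which thread runs it.

    #include <algorithm>
    #include <vector>

    #include "util/basics.h"

    // Hypothetical helper: shuffle each task's data with a per-task stream.
    void ShufflePerTask(std::vector<std::vector<float>>& per_task_data,
                        bool deterministic) {
      const gcpp::AesCtrEngine engine(deterministic);  // shared, thread-safe
      for (size_t task = 0; task < per_task_data.size(); ++task) {
        // Same engine + same stream id => same sequence, on any thread.
        gcpp::RngStream rng(engine, /*stream=*/task);
        std::shuffle(per_task_data[task].begin(), per_task_data[task].end(),
                     rng);
      }
    }

This mirrors the usage in nuq_test.cc and ops_test.cc: RngStream satisfies UniformRandomBitGenerator, so it drops into std::shuffle and the <random> distributions that previously consumed std::mt19937.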