From 41a86d41a9b2b937faeda6cafc225d50f266a0ad Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 15 Aug 2025 06:30:07 -0700 Subject: [PATCH 01/65] Fix preadv error: only enable if we have a handle PiperOrigin-RevId: 795455020 --- io/io.cc | 30 +++++++++++++++++------------- io/io.h | 2 +- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/io/io.cc b/io/io.cc index df39f7b..d19ec10 100644 --- a/io/io.cc +++ b/io/io.cc @@ -226,21 +226,26 @@ void InternalInit() { } uint64_t IOBatch::Read(const File& file) const { -#if GEMMA_IO_PREADV HWY_ASSERT(!spans_.empty()); - ssize_t bytes_read; - for (;;) { - bytes_read = - preadv(file.Handle(), reinterpret_cast(spans_.data()), - static_cast(spans_.size()), offset_); - if (bytes_read >= 0) break; - if (errno == EINTR) continue; // signal: retry - HWY_WARN("preadv failed, errno %d.", errno); - return 0; +#if GEMMA_IO_PREADV + if (file.Handle() != -1) { + ssize_t bytes_read; + for (;;) { + bytes_read = + preadv(file.Handle(), reinterpret_cast(spans_.data()), + static_cast(spans_.size()), offset_); + if (bytes_read >= 0) break; + if (errno == EINTR) continue; // signal: retry + HWY_WARN("preadv(%d) for %4zu spans from offset %12zu failed, errno %d.", + file.Handle(), spans_.size(), offset_, errno); + return 0; + } + return static_cast(bytes_read); } - return static_cast(bytes_read); -#else +#endif // GEMMA_IO_PREADV + + // preadv disabled or no handle: use normal reads (higher kernel overhead). uint64_t total = 0; uint64_t offset = offset_; for (const IOSpan& span : spans_) { @@ -249,7 +254,6 @@ uint64_t IOBatch::Read(const File& file) const { offset += span.bytes; } return total; -#endif } } // namespace gcpp diff --git a/io/io.h b/io/io.h index d9481bc..f90a636 100644 --- a/io/io.h +++ b/io/io.h @@ -68,7 +68,7 @@ class File { // modify internal state. This is only expected to be called once per file. virtual MapPtr Map() = 0; - // For use by `IOBatch::Read`. 
+ // Returns handle for use by `IOBatch::Read`, or -1 if not supported. virtual int Handle() const { return -1; } }; From 41321611fdbc321b8175922d565ca5df5d15d65a Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 20 Aug 2025 11:05:09 +0900 Subject: [PATCH 02/65] feature: add API server and client with Google protocol --- API_SERVER_README.md | 250 +++++++++++++++++++++ CMakeLists.txt | 31 +++ gemma/api_client.cc | 359 ++++++++++++++++++++++++++++++ gemma/api_server.cc | 514 +++++++++++++++++++++++++++++++++++++++++++ gemma/gemma_args.h | 34 +++ 5 files changed, 1188 insertions(+) create mode 100644 API_SERVER_README.md create mode 100644 gemma/api_client.cc create mode 100644 gemma/api_server.cc diff --git a/API_SERVER_README.md b/API_SERVER_README.md new file mode 100644 index 0000000..f7af504 --- /dev/null +++ b/API_SERVER_README.md @@ -0,0 +1,250 @@ +# Gemma.cpp API Server + +This is an HTTP API server for gemma.cpp that implements the Google API protocol, allowing you to interact with Gemma models through REST API endpoints compatible with the Google API format. 
+ +## Features + +- **API-compatible**: Implements Google API endpoints +- **Unified client/server**: Single codebase supports both local and public API modes +- **Text generation**: Support for `generateContent` endpoint +- **Streaming support**: Server-Sent Events (SSE) for `streamGenerateContent` +- **Model management**: Support for `/v1beta/models` endpoint +- **Session management**: Maintains conversation context with KV cache +- **JSON responses**: All responses in Google API format +- **Error handling**: Proper HTTP status codes and error messages + +## Building + +The API server is built alongside the main gemma.cpp project: + +```bash +# Configure the build +cmake -B build -DCMAKE_BUILD_TYPE=Release + +# Build the API server and client +cmake --build build --target gemma_api_server gemma_api_client -j 8 +``` + +The binaries will be created at: +- `build/gemma_api_server` - Local API server +- `build/gemma_api_client` - Unified client for both local and public APIs + +## Usage + +### Starting the Local API Server + +```bash +./build/gemma_api_server \ + --tokenizer path/to/tokenizer.spm \ + --weights path/to/model.sbs \ + --port 8080 +``` + +**Required arguments:** +- `--tokenizer`: Path to the tokenizer file (`.spm`) +- `--weights`: Path to the model weights file (`.sbs`) + +**Optional arguments:** +- `--port`: Port to listen on (default: 8080) +- `--model`: Model name for API endpoints (default: gemma3-4b) + +### Using the Unified Client + +#### With Local Server +```bash +# Interactive chat with local server +./build/gemma_api_client --interactive 1 --host localhost --port 8080 + +# Single prompt with local server +./build/gemma_api_client --prompt "Hello, how are you?" 
+``` + +#### With Public Google API +```bash +# Set API key and use public API +export GOOGLE_API_KEY="your-api-key-here" +./build/gemma_api_client --interactive 1 + +# Or pass API key directly +./build/gemma_api_client --api_key "your-api-key" --interactive 1 +``` + +## API Endpoints + +The server implements Google API endpoints: + +### 1. Generate Content - `POST /v1beta/models/gemma3-4b:generateContent` + +Generate a response for given content (non-streaming). + +**Request:** +```json +{ + "contents": [ + { + "parts": [ + {"text": "Why is the sky blue?"} + ] + } + ], + "generationConfig": { + "temperature": 0.9, + "topK": 1, + "maxOutputTokens": 1024 + } +} +``` + +**Response:** +```json +{ + "candidates": [ + { + "content": { + "parts": [ + {"text": "The sky appears blue because..."} + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "promptFeedback": { + "safetyRatings": [] + }, + "usageMetadata": { + "promptTokenCount": 5, + "candidatesTokenCount": 25, + "totalTokenCount": 30 + } +} +``` + +### 2. Stream Generate Content - `POST /v1beta/models/gemma3-4b:streamGenerateContent` + +Generate a response with Server-Sent Events (SSE) streaming. + +**Request:** Same as above + +**Response:** Stream of SSE events: +``` +data: {"candidates":[{"content":{"parts":[{"text":"The"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}} + +data: {"candidates":[{"content":{"parts":[{"text":" sky"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}} + +data: [DONE] +``` + +### 3. List Models - `GET /v1beta/models` + +List available models. 
+ +**Response:** +```json +{ + "models": [ + { + "name": "models/gemma3-4b", + "displayName": "Gemma3 4B", + "description": "Gemma3 4B model running locally" + } + ] +} +``` + +## Example Usage + +### Using curl with Local Server + +```bash +# Generate content (non-streaming) +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [{"parts": [{"text": "Hello, how are you?"}]}], + "generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024} + }' + +# Stream generate content (SSE) +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:streamGenerateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [{"parts": [{"text": "Tell me a story"}]}], + "generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024} + }' + +# List models +curl http://localhost:8080/v1beta/models +``` + +### Multi-turn Conversation with curl + +```bash +# First message +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"parts": [{"text": "Hi, my name is Alice"}]} + ] + }' + +# Follow-up message with conversation history +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"parts": [{"text": "Hi, my name is Alice"}]}, + {"parts": [{"text": "Hello Alice! 
Nice to meet you."}]}, + {"parts": [{"text": "What is my name?"}]} + ] + }' +``` + +### Using Python + +```python +import requests + +# Generate content +response = requests.post('http://localhost:8080/v1beta/models/gemma3-4b:generateContent', + json={ + 'contents': [{'parts': [{'text': 'Explain quantum computing in simple terms'}]}], + 'generationConfig': { + 'temperature': 0.9, + 'topK': 1, + 'maxOutputTokens': 1024 + } + } +) + +result = response.json() +if 'candidates' in result and result['candidates']: + text = result['candidates'][0]['content']['parts'][0]['text'] + print(text) +``` + +## Configuration Options + +The Google API supports various generation configuration options: + +- **temperature**: Controls randomness (0.0 to 2.0, default: 1.0) +- **topK**: Top-K sampling parameter (default: 1) +- **maxOutputTokens**: Maximum number of tokens to generate (default: 8192) + +## Key Features + +- **Unified Implementation**: Same codebase handles both local server and public API +- **Session Management**: Maintains conversation context using KV cache +- **Streaming Support**: Real-time token generation via Server-Sent Events +- **Error Handling**: Comprehensive error responses and HTTP status codes +- **Memory Efficient**: Optimized token processing and caching + +## Compatibility + +This implementation is compatible with: +- Google API format and endpoints +- Standard HTTP clients (curl, browsers, Python requests, etc.) 
+- Server-Sent Events (SSE) for streaming responses +- JSON request/response format diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e4bfd0..8309840 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,28 @@ FetchContent_MakeAvailable(sentencepiece) FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(json) +# Find OpenSSL for HTTPS support +find_package(OpenSSL) +if(OPENSSL_FOUND) + message(STATUS "OpenSSL found, enabling HTTPS support") + set(HTTPLIB_USE_OPENSSL_IF_AVAILABLE ON) +else() + message(STATUS "OpenSSL not found, HTTPS support disabled") + set(HTTPLIB_USE_OPENSSL_IF_AVAILABLE OFF) +endif() + +# HTTP library for API server +FetchContent_Declare(httplib GIT_REPOSITORY https://github.com/yhirose/cpp-httplib.git GIT_TAG v0.18.1 EXCLUDE_FROM_ALL) +FetchContent_MakeAvailable(httplib) + +# Create interface target for httplib (header-only library) +add_library(httplib_interface INTERFACE) +target_include_directories(httplib_interface INTERFACE ${httplib_SOURCE_DIR}) +if(OPENSSL_FOUND) + target_link_libraries(httplib_interface INTERFACE OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(httplib_interface INTERFACE CPPHTTPLIB_OPENSSL_SUPPORT) +endif() + set(BENCHMARK_ENABLE_TESTING OFF) set(BENCHMARK_ENABLE_GTEST_TESTS OFF) @@ -232,3 +254,12 @@ endif() # GEMMA_ENABLE_TESTS add_executable(migrate_weights io/migrate_weights.cc) target_link_libraries(migrate_weights libgemma hwy hwy_contrib) + + +# API server with SSE support +add_executable(gemma_api_server gemma/api_server.cc) +target_link_libraries(gemma_api_server libgemma hwy hwy_contrib nlohmann_json::nlohmann_json httplib_interface) + +# API client for testing +add_executable(gemma_api_client gemma/api_client.cc) +target_link_libraries(gemma_api_client libgemma hwy hwy_contrib nlohmann_json::nlohmann_json httplib_interface) diff --git a/gemma/api_client.cc 
b/gemma/api_client.cc new file mode 100644 index 0000000..1f64d96 --- /dev/null +++ b/gemma/api_client.cc @@ -0,0 +1,359 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test client for API server + +#include +#include +#include +#include +#include + +#include "httplib.h" +#include "nlohmann/json.hpp" +#include "gemma/gemma_args.h" + +using json = nlohmann::json; + +// ANSI color codes +const std::string RESET = "\033[0m"; +const std::string BOLD = "\033[1m"; +const std::string GREEN = "\033[32m"; +const std::string BLUE = "\033[34m"; +const std::string CYAN = "\033[36m"; +const std::string YELLOW = "\033[33m"; +const std::string RED = "\033[31m"; + +class APIClient { +public: + APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b") + : host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) { + if (use_https_) { + ssl_client_ = std::make_unique(host, port); + ssl_client_->set_read_timeout(60, 0); + ssl_client_->set_write_timeout(60, 0); + ssl_client_->enable_server_certificate_verification(false); + } else { + client_ = std::make_unique(host, port); + client_->set_read_timeout(60, 0); + client_->set_write_timeout(60, 0); + } + } + + // Unified request processing for both public and local APIs + json ProcessRequest(const json& request, bool stream = true) { + bool 
is_public_api = !api_key_.empty(); + + std::string endpoint; + if (is_public_api) { + endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" + : "/v1beta/models/gemini-2.0-flash:generateContent"; + } else { + endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent" + : "/v1beta/models/" + model_ + ":generateContent"; + } + + // Only show verbose output in non-interactive mode + if (!interactive_mode_) { + std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; + std::cout << "Request: " << request.dump(2) << std::endl; + } + + if (stream) { + return ProcessStreamingRequest(request, endpoint); + } else { + return ProcessNonStreamingRequest(request, endpoint); + } + } + + void TestGenerateContent(const std::string& prompt, bool stream = true) { + json request = CreateAPIRequest(prompt); + json response = ProcessRequest(request, stream); + + if (response.contains("error")) { + std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + } + } + + void TestListModels() { + std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; + + httplib::Headers headers; + if (!api_key_.empty()) { + headers.emplace("X-goog-api-key", api_key_); + } + auto res = use_https_ ? 
ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers); + + if (res && res->status == 200) { + json response = json::parse(res->body); + std::cout << GREEN << "✅ Available models:" << RESET << std::endl; + std::cout << response.dump(2) << std::endl; + } else { + std::cerr << RED << "❌ Request failed" << RESET << std::endl; + } + } + + void InteractiveChat() { + std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl; + std::cout << "Type ':gemma %q' to end.\n" << std::endl; + + interactive_mode_ = true; + json messages; + + while (true) { + std::cout << BOLD << BLUE << "You: " << RESET; + std::string input; + std::getline(std::cin, input); + + if (input == ":gemma %q") { + std::cout << BOLD << YELLOW << "👋 Goodbye!" << RESET << std::endl; + break; + } + + if (input.empty()) continue; + + // Add user message with proper role + json user_message = {{"parts", {{{"text", input}}}}}; + if (!api_key_.empty()) { + user_message["role"] = "user"; + } + messages.push_back(user_message); + + // Create request using unified logic + json request = CreateAPIRequest("", messages); + + std::cout << BOLD << GREEN << "Assistant: " << RESET; + + // Use unified processing - streaming for real-time output + json response = ProcessRequest(request, true); + + if (response.contains("candidates") && !response["candidates"].empty()) { + auto& candidate = response["candidates"][0]; + if (candidate.contains("content") && candidate["content"].contains("parts")) { + for (const auto& part : candidate["content"]["parts"]) { + if (part.contains("text")) { + std::string assistant_response = part["text"].get(); + + // For streaming, the response is already displayed in real-time + // Just add to message history for context + json assistant_message = {{"parts", {{{"text", assistant_response}}}}}; + if (!api_key_.empty()) { + assistant_message["role"] = "model"; + } + messages.push_back(assistant_message); + } + } + } + } else 
if (response.contains("error")) { + std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + } + + std::cout << std::endl; + } + } + +private: + json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) { + json request = { + {"generationConfig", { + {"temperature", 0.9}, + {"topK", 1}, + {"maxOutputTokens", 1024} + }} + }; + + if (messages.empty()) { + // Single prompt + json user_message = {{"parts", {{{"text", prompt}}}}}; + if (!api_key_.empty()) { + user_message["role"] = "user"; + } + request["contents"] = json::array({user_message}); + } else { + // Use provided message history + request["contents"] = messages; + } + + return request; + } + + json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) { + httplib::Headers headers = {{"Content-Type", "application/json"}}; + if (!api_key_.empty()) { + headers.emplace("X-goog-api-key", api_key_); + } + + auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json") + : client_->Post(endpoint, headers, request.dump(), "application/json"); + + if (res && res->status == 200) { + json response = json::parse(res->body); + if (!interactive_mode_) { + std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl; + std::cout << response.dump(2) << std::endl; + } + return response; + } else { + json error_response = { + {"error", { + {"message", "Request failed"}, + {"status", res ? res->status : -1} + }} + }; + if (res && !res->body.empty()) { + error_response["error"]["details"] = res->body; + } + std::cerr << RED << "❌ Request failed. Status: " << (res ? 
res->status : -1) << RESET << std::endl; + return error_response; + } + } + + json ProcessStreamingRequest(const json& request, const std::string& endpoint) { + std::string accumulated_response; + + // Use same SSE logic for both public and local APIs + httplib::Request req; + req.method = "POST"; + req.path = endpoint; + req.set_header("Content-Type", "application/json"); + if (!api_key_.empty()) { + req.set_header("X-goog-api-key", api_key_); + } + req.body = request.dump(); + + req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool { + std::string chunk(data, data_length); + std::istringstream stream(chunk); + std::string line; + + while (std::getline(stream, line)) { + if (line.substr(0, 6) == "data: ") { + std::string event_data = line.substr(6); + + if (event_data == "[DONE]") { + if (!interactive_mode_) { + std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl; + } + } else { + try { + json event = json::parse(event_data); + if (event.contains("candidates") && !event["candidates"].empty()) { + auto& candidate = event["candidates"][0]; + if (candidate.contains("content") && candidate["content"].contains("parts")) { + for (const auto& part : candidate["content"]["parts"]) { + if (part.contains("text")) { + std::string text = part["text"].get(); + std::cout << text << std::flush; + accumulated_response += text; + } + } + } + } + } catch (const json::exception& e) { + // Skip parse errors + } + } + } + } + return true; + }; + + httplib::Response res; + httplib::Error error; + bool success = use_https_ ? 
ssl_client_->send(req, res, error) : client_->send(req, res, error); + + if (res.status == 200 && !accumulated_response.empty()) { + return json{ + {"candidates", {{ + {"content", { + {"parts", {{{"text", accumulated_response}}}} + }} + }}} + }; + } else { + json error_response = { + {"error", { + {"message", "Streaming request failed"}, + {"status", res.status} + }} + }; + if (!res.body.empty()) { + error_response["error"]["details"] = res.body; + } + std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl; + return error_response; + } + } + +private: + std::unique_ptr client_; + std::unique_ptr ssl_client_; + std::string host_; + int port_; + std::string api_key_; + std::string model_; + bool use_https_; + bool interactive_mode_; +}; + +int main(int argc, char* argv[]) { + gcpp::ClientArgs client_args(argc, argv); + + if (gcpp::HasHelp(argc, argv)) { + std::cout << "\nAPI Client for gemma.cpp\n"; + std::cout << "========================\n\n"; + client_args.Help(); + std::cout << std::endl; + std::cout << "Environment Variables:" << std::endl; + std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl; + return 0; + } + + // Check for GOOGLE_API_KEY environment variable + const char* env_api_key = std::getenv("GOOGLE_API_KEY"); + if (env_api_key != nullptr && strlen(env_api_key) > 0) { + client_args.api_key = env_api_key; + client_args.host = "generativelanguage.googleapis.com"; + client_args.port = 443; + } + + // Handle API key override + if (!client_args.api_key.empty()) { + client_args.host = "generativelanguage.googleapis.com"; + client_args.port = 443; + } + + std::cout << BOLD << YELLOW << "🚀 Testing API Server at " + << client_args.host << ":" << client_args.port << RESET << std::endl; + + try { + APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model); + + if (client_args.interactive) { + client.InteractiveChat(); + } else { + 
client.TestListModels(); + client.TestGenerateContent(client_args.prompt, true); + } + + } catch (const std::exception& e) { + std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl; + std::cerr << "Make sure the API server is running:" << std::endl; + std::cerr << " ./build/gemma_api_server --tokenizer --weights " << std::endl; + return 1; + } + + return 0; +} diff --git a/gemma/api_server.cc b/gemma/api_server.cc new file mode 100644 index 0000000..70b3115 --- /dev/null +++ b/gemma/api_server.cc @@ -0,0 +1,514 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// HTTP API server for gemma.cpp with SSE support + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// HTTP server library +#undef CPPHTTPLIB_OPENSSL_SUPPORT +#undef CPPHTTPLIB_ZLIB_SUPPORT +#include "httplib.h" + +// JSON library +#include "nlohmann/json.hpp" + +#include "compression/types.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/tokenizer.h" +#include "ops/matmul.h" +#include "util/args.h" +#include "hwy/base.h" +#include "hwy/profiler.h" + +using json = nlohmann::json; + +namespace gcpp { + +static std::atomic server_running{true}; + +// Server state holding model and KV caches +struct ServerState { + std::unique_ptr gemma; + MatMulEnv* env; + ThreadingContext* ctx; + + // Session-based KV cache storage + struct Session { + std::unique_ptr kv_cache; + size_t abs_pos = 0; + std::chrono::steady_clock::time_point last_access; + }; + + std::unordered_map sessions; + std::mutex sessions_mutex; + std::mutex inference_mutex; + + // Cleanup old sessions after 30 minutes of inactivity + void CleanupOldSessions() { + std::lock_guard lock(sessions_mutex); + auto now = std::chrono::steady_clock::now(); + for (auto it = sessions.begin(); it != sessions.end();) { + if (now - it->second.last_access > std::chrono::minutes(30)) { + it = sessions.erase(it); + } else { + ++it; + } + } + } + + // Get or create session with KV cache + Session& GetOrCreateSession(const std::string& session_id) { + std::lock_guard lock(sessions_mutex); + auto& session = sessions[session_id]; + if (!session.kv_cache) { + session.kv_cache = std::make_unique(gemma->Config(), InferenceArgs(), env->ctx.allocator); + } + session.last_access = std::chrono::steady_clock::now(); + return session; + } +}; + +// Generate a unique session ID +std::string GenerateSessionId() { + static std::atomic counter{0}; + std::stringstream ss; + ss << "session_" << std::hex << 
std::chrono::steady_clock::now().time_since_epoch().count() + << "_" << counter.fetch_add(1); + return ss.str(); +} + +// Wraps messages with start_of_turn markers - handles both with and without roles +std::string WrapMessagesWithTurnMarkers(const json& contents) { + std::string prompt; + + for (const auto& content : contents) { + if (content.contains("parts")) { + // Check if role is specified (public API format) or not (local format) + std::string role = content.value("role", ""); + + for (const auto& part : content["parts"]) { + if (part.contains("text")) { + std::string text = part["text"]; + + if (role == "user") { + prompt += "user\n" + text + "\nmodel\n"; + } else if (role == "model") { + prompt += text + "\n"; + } else if (role.empty()) { + // Local format without roles - for now, treat as user input + prompt += "user\n" + text + "\nmodel\n"; + } + } + } + } + } + + return prompt; +} + +// Parse generation config +RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) { + RuntimeConfig config; + config.gen = &gen; + config.verbosity = 0; + + // Set defaults matching public API + config.temperature = 1.0f; + config.top_k = 1; + config.max_generated_tokens = 8192; + + if (request.contains("generationConfig")) { + auto& gen_config = request["generationConfig"]; + + if (gen_config.contains("temperature")) { + config.temperature = gen_config["temperature"].get(); + } + if (gen_config.contains("topK")) { + config.top_k = gen_config["topK"].get(); + } + if (gen_config.contains("maxOutputTokens")) { + config.max_generated_tokens = gen_config["maxOutputTokens"].get(); + } + } + + return config; +} + +// Unified response formatter - creates consistent format regardless of request type +json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) { + json response = { + {"candidates", {{ + {"content", { + {"parts", {{{"text", text}}}}, + {"role", "model"} + }}, + {"index", 0} + }}}, + {"promptFeedback", {{"safetyRatings", 
json::array()}}} + }; + + // Only add finishReason for non-streaming chunks + if (!is_streaming_chunk) { + response["candidates"][0]["finishReason"] = "STOP"; + } + + return response; +} + +// Handle generateContent endpoint (non-streaming) +void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { + try { + json request = json::parse(req.body); + + // Get or create session + std::string session_id = request.value("sessionId", GenerateSessionId()); + auto& session = state.GetOrCreateSession(session_id); + + // Extract prompt from API format + std::string prompt; + if (request.contains("contents")) { + prompt = WrapMessagesWithTurnMarkers(request["contents"]); + } else { + res.status = 400; + res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + return; + } + + // Lock for inference + std::lock_guard lock(state.inference_mutex); + + // Set up runtime config + std::mt19937 gen; + RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); + + // Collect full response + std::string full_response; + runtime_config.stream_token = [&full_response](int token, float) { + // Skip EOS token + return true; + }; + + // Tokenize prompt + std::vector tokens = WrapAndTokenize(state.gemma->Tokenizer(), + state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, + session.abs_pos, + prompt); + + // Run inference with KV cache + TimingInfo timing_info = {.verbosity = 0}; + size_t prefix_end = 0; + + // Temporarily redirect output to capture response + std::stringstream output; + runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { + // Skip prompt tokens + if (session.abs_pos < tokens.size()) { + session.abs_pos++; + return true; + } + + session.abs_pos++; + + // Check for EOS + if (state.gemma->Config().IsEOS(token)) { + return true; + } + + // Decode token + std::string token_text; + 
state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); + output << token_text; + + return true; + }; + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, + *session.kv_cache, *state.env, timing_info); + + // Create response + json response = CreateAPIResponse(output.str(), false); + response["usageMetadata"] = { + {"promptTokenCount", tokens.size()}, + {"candidatesTokenCount", session.abs_pos - tokens.size()}, + {"totalTokenCount", session.abs_pos} + }; + + res.set_content(response.dump(), "application/json"); + + } catch (const json::exception& e) { + res.status = 400; + res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), + "application/json"); + } catch (const std::exception& e) { + res.status = 500; + res.set_content(json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}.dump(), + "application/json"); + } +} + +// Handle streamGenerateContent endpoint with SSE) +void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { + try { + json request = json::parse(req.body); + + // Get or create session + std::string session_id = request.value("sessionId", GenerateSessionId()); + auto& session = state.GetOrCreateSession(session_id); + + // Extract prompt from API format + std::string prompt; + if (request.contains("contents")) { + prompt = WrapMessagesWithTurnMarkers(request["contents"]); + } else { + res.status = 400; + res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + return; + } + + // Set up SSE headers + res.set_header("Content-Type", "text/event-stream"); + res.set_header("Cache-Control", "no-cache"); + res.set_header("Connection", "keep-alive"); + res.set_header("X-Session-Id", session_id); + + // Set up chunked content provider for SSE + res.set_chunked_content_provider( + "text/event-stream", + [&state, request, prompt, session_id](size_t offset, 
httplib::DataSink& sink) { + try { + // Lock for inference + std::lock_guard lock(state.inference_mutex); + auto& session = state.GetOrCreateSession(session_id); + + // Set up runtime config + std::mt19937 gen; + RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); + + // Tokenize prompt + std::vector tokens = WrapAndTokenize(state.gemma->Tokenizer(), + state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, + session.abs_pos, + prompt); + + // Stream token callback + std::string accumulated_text; + auto stream_token = [&](int token, float) { + // Skip prompt tokens + if (session.abs_pos < tokens.size()) { + session.abs_pos++; + return true; + } + + session.abs_pos++; + + // Check for EOS + if (state.gemma->Config().IsEOS(token)) { + return true; + } + + // Decode token + std::string token_text; + state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); + accumulated_text += token_text; + + // Send SSE event using unified formatter + json event = CreateAPIResponse(token_text, true); + + std::string sse_data = "data: " + event.dump() + "\n\n"; + sink.write(sse_data.data(), sse_data.size()); + + return true; + }; + + runtime_config.stream_token = stream_token; + + // Run inference with KV cache + TimingInfo timing_info = {.verbosity = 0}; + size_t prefix_end = 0; + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, + *session.kv_cache, *state.env, timing_info); + + // Send final event using unified formatter + json final_event = CreateAPIResponse("", false); + final_event["usageMetadata"] = { + {"promptTokenCount", tokens.size()}, + {"candidatesTokenCount", session.abs_pos - tokens.size()}, + {"totalTokenCount", session.abs_pos} + }; + + std::string final_sse = "data: " + final_event.dump() + "\n\n"; + sink.write(final_sse.data(), final_sse.size()); + + // Send done event + sink.write("data: [DONE]\n\n", 15); + + // Ensure all data is sent + sink.done(); + + return false; // End streaming + + } catch (const 
std::exception& e) { + json error_event = {{"error", {{"message", e.what()}}}}; + std::string error_sse = "data: " + error_event.dump() + "\n\n"; + sink.write(error_sse.data(), error_sse.size()); + return false; + } + } + ); + + } catch (const json::exception& e) { + res.status = 400; + res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), + "application/json"); + } +} + +// Handle models list endpoint +void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) { + json response = { + {"models", {{ + {"name", "models/" + inference.model}, + {"version", "001"}, + {"displayName", inference.model}, + {"description", inference.model + " model running locally"}, + {"inputTokenLimit", 8192}, + {"outputTokenLimit", 8192}, + {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})}, + {"temperature", 1.0}, + {"topK", 1} + }}} + }; + + res.set_content(response.dump(), "application/json"); +} + +// void HandleShutdown(int signal) { +// std::cerr << "\nShutting down server..." << std::endl; +// server_running = false; +// } + +void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, + const InferenceArgs& inference) { + std::cerr << "Loading model..." 
<< std::endl; + + // Initialize model + ThreadingContext ctx(threading); + MatMulEnv env(ctx); + + ServerState state; + state.gemma = std::make_unique(loader, inference, ctx); + state.env = &env; + state.ctx = &ctx; + + httplib::Server server; + + // Set up routes + server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { + res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); + }); + + // API endpoints + server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { + HandleListModels(state, inference, req, res); + }); + + std::string model_endpoint = "/v1beta/models/" + inference.model; + server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentNonStreaming(state, req, res); + }); + + server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentStreaming(state, req, res); + }); + + // Periodic cleanup of old sessions + std::thread cleanup_thread([&state]() { + while (server_running) { + std::this_thread::sleep_for(std::chrono::minutes(5)); + state.CleanupOldSessions(); + } + }); + + std::cerr << "Starting API server on port " << inference.port << std::endl; + std::cerr << "Model loaded successfully" << std::endl; + std::cerr << "Endpoints:" << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; + std::cerr << " GET /v1beta/models" << std::endl; + + if (!server.listen("0.0.0.0", inference.port)) { + std::cerr << "Failed to start server on port " << inference.port << std::endl; + } + + cleanup_thread.join(); +} + +} // namespace gcpp + +int main(int argc, char** argv) { + gcpp::InternalInit(); + + gcpp::LoaderArgs 
loader(argc, argv); + gcpp::ThreadingArgs threading(argc, argv); + gcpp::InferenceArgs inference(argc, argv); + + if (gcpp::HasHelp(argc, argv)) { + std::cerr << "\n\nAPI server for gemma.cpp\n"; + std::cout << "========================\n\n"; + std::cerr << "Usage: " << argv[0] << " --weights --tokenizer [options]\n"; + std::cerr << "\nOptions:\n"; + std::cerr << " --port PORT Server port (default: 8080)\n"; + std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n"; + std::cerr << "\n"; + std::cerr << "\n*Model Loading Arguments*\n\n"; + loader.Help(); + std::cerr << "\n*Threading Arguments*\n\n"; + threading.Help(); + std::cerr << "\n*Inference Arguments*\n\n"; + inference.Help(); + std::cerr << "\n"; + return 0; + } + + // Arguments are now handled by InferenceArgs + + // // Set up signal handler + // signal(SIGINT, gcpp::HandleShutdown); + // signal(SIGTERM, gcpp::HandleShutdown); + + gcpp::RunServer(loader, threading, inference); + + return 0; +} diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 161e9a5..469ba2a 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -183,6 +183,8 @@ struct InferenceArgs : public ArgsBase { bool multiturn; Path image_file; + int port; // Server port + std::string model; // Model name for API endpoints std::string prompt; // Bypasses std::getline // For prompts longer than the Linux terminal's 4K line edit buffer. Path prompt_file; @@ -218,6 +220,10 @@ struct InferenceArgs : public ArgsBase { "resets every turn)"); visitor(image_file, "image_file", Path(), "Image file to load."); + // Since it is not used in the CLI version, the print_verbosity is set higher than others. + visitor(port, "port", 8080, "Server port (default: 8080)", 3); + visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3); + visitor(prompt, "prompt", std::string(""), "Initial prompt for non-interactive mode. 
When specified, " "generates a response and exits.", @@ -258,6 +264,34 @@ struct InferenceArgs : public ArgsBase { } }; +struct ClientArgs : public ArgsBase { + ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + ClientArgs() { Init(); }; + + std::string host; + int port; + std::string api_key; + std::string model; + std::string prompt; + bool interactive; + + template + void ForEach(const Visitor& visitor) { + visitor(host, "host", std::string("localhost"), + "Server host (default: localhost)"); + visitor(port, "port", 8080, + "Server port (default: 8080)"); + visitor(api_key, "api_key", std::string(""), + "Use public API with key (changes host to generativelanguage.googleapis.com:443)"); + visitor(model, "model", std::string("gemma3-4b"), + "Model name to use (default: gemma3-4b)"); + visitor(prompt, "prompt", std::string("Hello! How are you?"), + "Prompt for generation (default: 'Hello! How are you?')"); + visitor(interactive, "interactive", false, + "Start interactive chat mode (0 = no, 1 = yes)"); + } +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ From 73f1140dca92b42a5a5cf620ad3b6d9c0c35155e Mon Sep 17 00:00:00 2001 From: Rhett Stucki Date: Wed, 20 Aug 2025 22:59:24 -0700 Subject: [PATCH 03/65] Fix an off-by-one error after StreamAndUpdateEOS() to remove the MSAN warning about reading an uninitialized variable in the kv_cache. The logic for choosing whether or not to attend to the last token during prefill wasn't completely consistent with StreamAndUpdateEOS(), causing an off-by-one error that prevented the kv_cache from being fully populated. 
PiperOrigin-RevId: 797614310 --- gemma/gemma.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a7b8423..95f52b8 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -520,6 +520,12 @@ static void GenerateT(const ModelConfig& config, const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, runtime_config, qbatch, non_eos); + // StreamAndUpdateEOS() sets the stream position one token too far in + // autoregressive mode. + const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); + if (!attend_to_last_token) { + qbatch.MutablePos(qi) -= 1; + } } size_t max_gen_steps = runtime_config.max_generated_tokens; From 9bf0fe4e375f6ef32189834f3a4c2df6fb74c3d8 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 26 Aug 2025 04:43:26 -0700 Subject: [PATCH 04/65] Internal change PiperOrigin-RevId: 799509375 --- gemma/configs.cc | 18 ++++++++++-------- gemma/configs.h | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/gemma/configs.cc b/gemma/configs.cc index 1f71545..f19d30d 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -710,8 +710,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const { ModelConfig a = *this; ModelConfig b = other; // Called by `OverwriteWithCanonical`, so ignore the fields it will set. - a.display_name = b.display_name; - a.model = b.model; + // Order matters: overwrite `b` with `a` because that is the known-good config + // when called by `OverwriteWithCanonical`. + b.display_name = a.display_name; + b.model = a.model; // The following are not yet set by config_converter.py, so we here ignore // them for purposes of comparison, and there overwrite the converter's config @@ -719,12 +721,12 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const { // these fields will be set. 
// `vit_config` is also not yet set, but we must not ignore it because // otherwise PaliGemma models will be indistinguishable for `configs_test`. - a.pool_dim = b.pool_dim; // ViT - a.eos_id = b.eos_id; - a.secondary_eos_id = b.secondary_eos_id; - a.scale_base_names = b.scale_base_names; - for (size_t i = 0; i < a.layer_configs.size(); ++i) { - a.layer_configs[i].optimized_gating = b.layer_configs[i].optimized_gating; + b.pool_dim = a.pool_dim; // ViT + b.eos_id = a.eos_id; + b.secondary_eos_id = a.secondary_eos_id; + b.scale_base_names = a.scale_base_names; + for (size_t i = 0; i < b.layer_configs.size(); ++i) { + b.layer_configs[i].optimized_gating = a.layer_configs[i].optimized_gating; } return AllEqual(a, b, print); diff --git a/gemma/configs.h b/gemma/configs.h index a3a3114..a1cd902 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -26,8 +26,8 @@ #include #include "compression/types.h" // Type -#include "io/fields.h" // IFieldsVisitor -#include "io/io.h" // Path +#include "io/fields.h" // IFieldsVisitor +#include "io/io.h" // Path #include "util/basics.h" namespace gcpp { From ed2f0bd1b0162109c5efe70edd77dbe2709ade7f Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 26 Aug 2025 04:50:06 -0700 Subject: [PATCH 05/65] Fix pos assertions, refs #665 Ensure the streaming func pos matches the number of calls. Add two arguments that control pos+1 and pos+=1 behavior. Also cleanup/add comments. 
run: use batch_stream_func, add assert, higher verbosity for MM autotune output PiperOrigin-RevId: 799511163 --- evals/gemma_test.cc | 4 ++ gemma/gemma.cc | 115 ++++++++++++++++++++++---------------------- gemma/run.cc | 16 ++++-- 3 files changed, 73 insertions(+), 62 deletions(-) diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 12080f9..77efbae 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -138,6 +138,10 @@ TEST_F(GemmaTest, Multiturn) { // Reset the `response` string here, then check that the model actually has // access to the previous turn by asking to reproduce. response.clear(); + // -1 because our prefill does not generate KVs for the last token. Do not + // just pass abs_pos - 1 because our callback checks pos == abs_pos. + HWY_ASSERT(abs_pos > 0); + --abs_pos; model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(), s_env->MutableEnv(), timing_info); fprintf(stderr, "decoded: '%s'\n", response.c_str()); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 95f52b8..b506e75 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -127,9 +127,10 @@ static float EmbeddingScaling(size_t model_dim) { hwy::ConvertScalarTo(sqrtf(static_cast(model_dim)))); } -// `batch_idx` indicates which row of `x` to write to. -// `pos` is the *token*'s position, not the start of the batch, because this is -// called for batches of tokens in prefill, but batches of queries in decode. +// `x_row` indicates which row of `x` to write to. +// `pos` is the *token*'s position for `AddAbsolutePositionalEmbeddings`, not +// the start of the batch, because this is called for batches of tokens in +// prefill, but batches of queries in decode. // // For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3 // spec) until we run out of image tokens. This allows for a multi-image prompt @@ -137,7 +138,7 @@ static float EmbeddingScaling(size_t model_dim) { // calling application. // Returns new image_token_position. 
static HWY_NOINLINE size_t -EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, +EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, const ModelConfig& model_config, const WeightsPtrs& weights, MatStorageT& x, ThreadingContext& ctx, const ImageTokens* image_tokens = nullptr, @@ -146,14 +147,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, if (model_config.wrapping == PromptWrapping::GEMMA_VLM && image_tokens != nullptr && token == -2 && image_token_position < image_tokens->Rows()) { - hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi), + hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(x_row), x.Cols() * x.ElementBytes()); return image_token_position + 1; } if (model_config.wrapping == PromptWrapping::PALIGEMMA && image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) { - hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi), + hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(x_row), x.Cols() * x.ElementBytes()); return image_token_position; } @@ -174,14 +175,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, const auto embedding_span = MakeSpan(weights_t->Row(0), embedding_ofs + model_dim); const hn::ScalableTag df; - DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi), + DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row), model_dim); - MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, + MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim, ctx.profiler, worker); }); if (model_config.absolute_pe) { - AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos); + AddAbsolutePositionalEmbeddings(x.Row(x_row), model_dim, pos); } return image_token_position; } @@ -249,24 +250,12 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, for (size_t ti = 0; ti < tbatch_size; ++ti) { const size_t pos = qbatch_1.Pos(0) + ti; const size_t pos_in_prompt = 
tbatch_start + ti; + HWY_DASSERT(pos_in_prompt < prompt_size); const int token = qbatch_1.Prompt(0)[pos_in_prompt]; image_token_position = EmbedMMToken( token, ti, pos, pos_in_prompt, config, weights, activations.x, env.ctx, runtime_config.image_tokens, image_token_position); - } - - // Transformer with one batch of tokens from a single query. - for (size_t layer_idx = 0; layer_idx < config.layer_configs.size(); - ++layer_idx) { - TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx), - activations, qbatch_1, env); - } - - // NOTE: we unconditionally call StreamToken, even if EOS. - for (size_t ti = 0; ti < tbatch_size; ++ti) { - const size_t pos = qbatch_1.Pos(0) + ti; - const size_t pos_in_prompt = tbatch_start + ti; - const int token = qbatch_1.Prompt(0)[pos_in_prompt]; + // NOTE: we unconditionally call StreamToken, even if EOS. if (pos_in_prompt < prompt_size - 1) { runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f); } else { @@ -276,6 +265,14 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, } } + // Transformer with one batch of tokens from a single query. No need to + // set `PrevToken` because we already did the embedding above. + for (size_t layer_idx = 0; layer_idx < config.layer_configs.size(); + ++layer_idx) { + TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx), + activations, qbatch_1, env); + } + qbatch_1.MutablePos(0) += tbatch_size; } // for tbatch_start if (attend_to_last_token) { @@ -291,8 +288,8 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, } // Embeds PrevToken (one from each query) and calls each TransformerLayer. -// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the -// token-batched `PrefillTBatch`. +// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the +// token-batched `PrefillTBatch`, which supports image embedding. 
static HWY_NOINLINE void Transformer(const ModelConfig& config, const RuntimeConfig& runtime_config, const WeightsPtrs& weights, @@ -324,8 +321,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config, } } -// Populates KV cache for the batch queries, one token at a time. Only called -// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0. +// Populates KV cache for the batch queries, one token at a time. static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, const ModelConfig& config, const RuntimeConfig& runtime_config, @@ -337,6 +333,8 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, for (size_t qi = 0; qi < qbatch.Size(); ++qi) { non_eos.Set(qi); + + // Should only be called for autoregressive (non-prefix-LM) prefill. HWY_DASSERT(qbatch.PrefixEnd(qi) == 0); } @@ -358,7 +356,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, } // The input (PrevToken) is one token from each query in the batch. - // Do not call DecodeStepT because it computes logits for token + // Do not call `SampleAndStream` because it computes logits for token // probabilities, which are not required for the prompt tokens. Transformer(config, runtime_config, weights, activations, qbatch, env); } @@ -369,42 +367,40 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, } // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent -// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the +// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the // query is at the end of its sequence. static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, const ModelConfig& config, const RuntimeConfig& runtime_config, - QBatch& qbatch, hwy::BitSet4096<>& non_eos) { + QBatch& qbatch, bool pos_plus_1, bool update_pos, + hwy::BitSet4096<>& non_eos) { HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called. 
- if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi), - qbatch.Pos(qi), token, prob))) { + const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0); + if (HWY_UNLIKELY( + !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) { // User decided to stop: set token to primary EOS to trigger IsEOS below. token = config.eos_id; HWY_DASSERT(config.IsEOS(token)); } qbatch.PrevToken(qi) = token; - qbatch.MutablePos(qi) += 1; + qbatch.MutablePos(qi) += update_pos ? 1 : 0; // Primary or secondary EOS: mark query as EOS, but still increment (for // multi-turn, we should still keep the prior EOS). if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi); } -// For a batch of queries, runs Transformer, computes logits, samples and -// streams the token. -static void DecodeStepT(const ModelConfig& config, - const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, - const SampleFunc& sample_token, - Activations& activations, QBatch& qbatch, - MatMulEnv& env, hwy::BitSet4096<>& non_eos, - TimingInfo& timing_info) { +// Must be called after Transformer: either after prefill, or during decode. +// Computes logits, samples and streams the token. 
+static void SampleAndStream( + const ModelConfig& config, const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, const SampleFunc& sample_token, + Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env, + hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) { HWY_DASSERT(qbatch.Size() == activations.x.Rows()); - Transformer(config, runtime_config, weights, activations, qbatch, env); - RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx); if (HWY_UNLIKELY(runtime_config.activations_observer)) { @@ -427,8 +423,12 @@ static void DecodeStepT(const ModelConfig& config, const TokenAndProb tp = sample_token(logits, config.vocab_size); timing_info.NotifyGenerated(); + // We streamed all prefill tokens, but pos is still one behind because we + // started generation at pos = prompt.size() - 1. We want the pos argument + // to match the number of calls to `StreamToken`, as expected by the caller. + const bool pos_plus_1 = true; StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch, - non_eos); + pos_plus_1, update_pos, non_eos); }); } @@ -476,15 +476,16 @@ static void GenerateT(const ModelConfig& config, const size_t seq_len = qbatch.KV(0).SeqLen(); for (size_t qi = 0; qi < qbatch.Size(); ++qi) { const PromptTokens& prompt = qbatch.Prompt(qi); + // Sanity check: prompts should not be empty. Note that multi-turn prompts + // start with . + HWY_ASSERT(prompt.size() != 0); + max_prompt_size = HWY_MAX(max_prompt_size, prompt.size()); // Prefill stops before size - 1 because the last prompt token is the // first input token for generation. total_prefill_tokens += prompt.size() - 1; - // Sanity check: prompts should not be empty, nor start with EOS. - HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id); - all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0; // We use a single divisor, so all sequence lengths must be the same. 
@@ -518,14 +519,12 @@ static void GenerateT(const ModelConfig& config, // Stream the last prompt token from each query, fill activations.gen_tokens. for (size_t qi = 0; qi < qbatch.Size(); ++qi) { const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); + const bool pos_plus_1 = false; // during prefill, pos is still correct. + // In autoregressive mode, we have not prefilled the last token, so do + // not advance. + const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, - runtime_config, qbatch, non_eos); - // StreamAndUpdateEOS() sets the stream position one token too far in - // autoregressive mode. - const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); - if (!attend_to_last_token) { - qbatch.MutablePos(qi) -= 1; - } + runtime_config, qbatch, pos_plus_1, update_pos, non_eos); } size_t max_gen_steps = runtime_config.max_generated_tokens; @@ -540,8 +539,10 @@ static void GenerateT(const ModelConfig& config, { timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { - DecodeStepT(config, runtime_config, weights, sample_token, activations, - qbatch, env, non_eos, timing_info); + Transformer(config, runtime_config, weights, activations, qbatch, env); + SampleAndStream(config, runtime_config, weights, sample_token, + activations, qbatch, /*update_pos=*/true, env, non_eos, + timing_info); } timing_info.NotifyGenerateDone(); } diff --git a/gemma/run.cc b/gemma/run.cc index 286c6ee..3915bf8 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -132,7 +132,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, } // callback function invoked for each generated token. 
- auto stream_token = [&](int token, float) { + auto batch_stream_token = [&](size_t query_idx, size_t pos, int token, + float) { + std::string token_text; + HWY_ASSERT(gemma.Tokenizer().Decode(std::vector{token}, &token_text)); + + HWY_ASSERT(pos == abs_pos); ++abs_pos; const bool in_prompt = tokens_generated_this_turn < prompt_size; const bool first_response_token = tokens_generated_this_turn == prompt_size; @@ -148,8 +153,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, } return true; } - std::string token_text; - HWY_ASSERT(gemma.Tokenizer().Decode(std::vector{token}, &token_text)); if (first_response_token) { token_text.erase(0, token_text.find_first_not_of(" \t\n")); if (inference.verbosity >= 1) { @@ -187,7 +190,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, TimingInfo timing_info = {.verbosity = inference.verbosity}; RuntimeConfig runtime_config = {.gen = &gen, .verbosity = inference.verbosity, - .stream_token = stream_token, + .batch_stream_token = batch_stream_token, .use_spinning = threading.spin}; inference.CopyTo(runtime_config); std::vector prompt; @@ -223,6 +226,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, if (inference.verbosity >= 1) { std::cerr << "\n[ Reading prompt ] " << std::flush; } + // -1 because our prefill does not generate KVs for the last token. Do not + // just pass abs_pos - 1 because our callback checks pos == abs_pos. 
+ if (abs_pos > 0) --abs_pos; gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env, timing_info); std::cout << "\n\n"; @@ -255,7 +261,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, ThreadingContext ctx(threading); MatMulEnv env(ctx); - if (inference.verbosity >= 2) env.print_best = true; + if (inference.verbosity >= 3) env.print_best = true; const Gemma gemma(loader, inference, ctx); KVCache kv_cache(gemma.Config(), inference, ctx.allocator); From 86afd530760e76351bf2f65e35d6284122fc29e4 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 26 Aug 2025 11:54:48 -0700 Subject: [PATCH 06/65] 1.04x speedup: Parallelize SoftCap Also require opt-in constexpr flag for observer callbacks, update zones PiperOrigin-RevId: 799655163 --- gemma/gemma.cc | 49 +++++++++++++++++++++++++++++++++---------------- ops/ops-inl.h | 12 ++++++++++++ 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index b506e75..bef3a70 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -61,6 +61,9 @@ #include "hwy/base.h" #include "hwy/timer.h" +// Require opt-in to debug/introspection functions to eliminate their overhead. +HWY_INLINE_VAR constexpr bool kObserver = false; + #endif // GEMMA_CC_ONCE HWY_BEFORE_NAMESPACE(); @@ -143,6 +146,10 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, MatStorageT& x, ThreadingContext& ctx, const ImageTokens* image_tokens = nullptr, size_t image_token_position = 0) { + static const auto zone = + ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive); + PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone); + // Image tokens just need to be copied. 
if (model_config.wrapping == PromptWrapping::GEMMA_VLM && image_tokens != nullptr && token == -2 && @@ -295,11 +302,13 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env) { - if (HWY_UNLIKELY(runtime_config.layers_output)) { - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - const float token_f = qbatch.PrevToken(qi); - runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi), - "tokens", -1, &token_f, 1); + if constexpr (kObserver) { + if (HWY_UNLIKELY(runtime_config.layers_output)) { + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + const float token_f = qbatch.PrevToken(qi); + runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi), + "tokens", -1, &token_f, 1); + } } } @@ -313,10 +322,12 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config, TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx), activations, qbatch, env); - if (HWY_UNLIKELY(runtime_config.activations_observer)) { - runtime_config.activations_observer( - QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx, - activations); + if constexpr (kObserver) { + if (HWY_UNLIKELY(runtime_config.activations_observer)) { + runtime_config.activations_observer( + QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx, + activations); + } } } } @@ -403,23 +414,29 @@ static void SampleAndStream( RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx); - if (HWY_UNLIKELY(runtime_config.activations_observer)) { - runtime_config.activations_observer( - QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations); + if constexpr (kObserver) { + if (HWY_UNLIKELY(runtime_config.activations_observer)) { + runtime_config.activations_observer( + QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations); + } } { - PROFILER_ZONE("Gen.EmbeddingMatmul"); + static const auto zone = env.ctx.profiler.AddZone( + "Gen.EmbeddingMatmul", 
hwy::ProfilerFlags::kInclusive); + PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone); // Compute logits from last layer activations. CallMatMul(activations.x, weights.embedder_input_embedding, /*add=*/nullptr, env, activations.logits); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); - const size_t worker = 0; // TODO: parallelize + + MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos, + env.ctx); + + // TODO: parallelize non_eos.Foreach([&](size_t qi) { float* HWY_RESTRICT logits = activations.logits.Row(qi); - MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, - env.ctx.profiler, worker); const TokenAndProb tp = sample_token(logits, config.vocab_size); timing_info.NotifyGenerated(); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 343600a..95228f4 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -33,6 +33,7 @@ #include "util/mat.h" #include "util/threading_context.h" #include "hwy/base.h" +#include "hwy/bit_set.h" #include "hwy/contrib/sort/order.h" #include "hwy/contrib/sort/vqsort.h" #include "hwy/detect_targets.h" @@ -932,6 +933,17 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( } } +static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( + const float cap, MatPtrT& x, const hwy::BitSet4096<>& non_eos, + ThreadingContext& ctx) { + if (cap == 0.0f) return; + SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { + if (non_eos.Get(task)) { + LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker); + } + }); +} + static HWY_NOINLINE HWY_MAYBE_UNUSED size_t SampleArgmax(const float* probabilities, size_t vocab_size) { size_t max_index = 0; From 5411fd846de78d7d5d0d6f846252308d337ccf46 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 26 Aug 2025 23:32:43 -0700 Subject: [PATCH 07/65] Minor: batched NotifyGenerate, fix comment/dep PiperOrigin-RevId: 799889802 --- BUILD.bazel | 1 + gemma/gemma.cc | 3 ++- gemma/gemma.h | 9 +++++---- ops/ops-inl.h | 2 +- 4 files changed, 9 
insertions(+), 6 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 6429f6d..1f9e210 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -526,6 +526,7 @@ cc_library( "//io", "//paligemma:image", "@highway//:hwy", + "@highway//hwy/contrib/sort:vqsort", "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", diff --git a/gemma/gemma.cc b/gemma/gemma.cc index bef3a70..c6da86c 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -434,11 +434,12 @@ static void SampleAndStream( MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos, env.ctx); + timing_info.NotifyGenerated(non_eos.Count()); + // TODO: parallelize non_eos.Foreach([&](size_t qi) { float* HWY_RESTRICT logits = activations.logits.Row(qi); const TokenAndProb tp = sample_token(logits, config.vocab_size); - timing_info.NotifyGenerated(); // We streamed all prefill tokens, but pos is still one behind because we // started generation at pos = prompt.size() - 1. We want the pos argument diff --git a/gemma/gemma.h b/gemma/gemma.h index 5ebd70d..55be003 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -177,9 +177,10 @@ struct TimingInfo { // be sure to populate prefill_start and generate_start before calling // NotifyGenerated. 
- void NotifyGenerated() { - ++tokens_generated; - if (HWY_UNLIKELY(tokens_generated == 1)) { + void NotifyGenerated(size_t batch_size) { + const bool is_first = (tokens_generated == 0); + tokens_generated += batch_size; + if (HWY_UNLIKELY(is_first)) { time_to_first_token = hwy::platform::Now() - prefill_start; if (verbosity >= 1) { double prefill_tok_sec = @@ -191,7 +192,7 @@ struct TimingInfo { prefill_tok_sec, static_cast(time_to_first_token * 1000)); } } - if (verbosity >= 2 && tokens_generated % 128 == 0) { + if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) { double gen_tok_sec = static_cast(tokens_generated) / (hwy::platform::Now() - generate_start); fprintf(stderr, diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 95228f4..8a7224b 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -679,7 +679,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, } } -// Same as above, but without a separate output. Same as below without the add. +// Same as above, but with a separate output. Same as below without the add. template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, From 6128e758ff673658ccefbe59f6a247efc389d54b Mon Sep 17 00:00:00 2001 From: Marie White Date: Thu, 28 Aug 2025 00:01:01 -0700 Subject: [PATCH 08/65] Change ffw_out from B16 to F32. PiperOrigin-RevId: 800330411 --- gemma/activations.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gemma/activations.h b/gemma/activations.h index b222bd9..877afdf 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -205,7 +205,7 @@ struct Activations { // Norm may be large, so prefer to keep as f32. 
MatStorageT C1; MatStorageT C2; - MatStorageT ffw_out; + MatStorageT ffw_out; AttentionActivations attention; GriffinActivations griffin; From 98ddc166dbd107a219fa30a861a41cc1eef0381e Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 28 Aug 2025 08:31:25 -0700 Subject: [PATCH 09/65] Expand ThreadingContext comments PiperOrigin-RevId: 800479954 --- gemma/gemma.h | 8 +++++--- util/threading.h | 4 +++- util/threading_context.h | 30 +++++++++++++++++++++++++----- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/gemma/gemma.h b/gemma/gemma.h index 55be003..0f9aae2 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -227,12 +227,12 @@ struct TimingInfo { }; // After construction, all methods are const and thread-compatible if using -// separate ThreadingContext for each thread. +// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`. class Gemma { public: // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. - // `ctx` is only used to read tensors, but it is typically also referenced - // by the `MatMulEnv` passed to the Generate* methods. + // `ctx` is only used to read tensors and not stored. Calls to `Generate*` + // may reference the same, or other `ThreadingContext` via `MatMulEnv`. Gemma(const LoaderArgs& loader, const InferenceArgs& inference, ThreadingContext& ctx); ~Gemma(); @@ -248,6 +248,8 @@ class Gemma { // `pos` is the position in the KV cache. Users are responsible for // incrementing it in the `*StreamFunc`, or setting to zero for single-turn. + // All `Generate*` may be called concurrently if `env` and the + // `ThreadingContext` it references are both distinct. 
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) const { diff --git a/util/threading.h b/util/threading.h index 6c2e187..0a57ddb 100644 --- a/util/threading.h +++ b/util/threading.h @@ -68,7 +68,9 @@ class NestedPools { // `max_threads` is the maximum number of threads to divide among all // clusters. This is more intuitive than a per-cluster limit for users who - // may not be aware of the CPU topology. 0 means no limit. + // may not be aware of the CPU topology. This should be zero (meaning no + // further limits) if the caller has already set limits via `skip_*` or + // `max_*` args passed to `ThreadingContext`. // // To ensure we do not create more threads than there are HW cores, which // would cause huge slowdowns when spinning, the `BoundedSlice` arguments diff --git a/util/threading_context.h b/util/threading_context.h index 08387d0..d4fdc17 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -28,6 +28,7 @@ #include "util/basics.h" // Tristate, kMaxPackages #include "util/threading.h" #include "util/topology.h" +#include "hwy/profiler.h" // IWYU pragma: end_exports namespace gcpp { @@ -55,9 +56,10 @@ class ThreadingArgs : public ArgsBase { template void ForEach(const Visitor& visitor) { - // These can be used to partition CPU sockets/packages and their + // These can be used to partition CPU packages/sockets and their // clusters/CCXs across several program instances. The default is to use - // all available resources. + // all available resources on one package. Note that `kMaxPackages` is an + // upper bound on `max_packages`. 
visitor(skip_packages, "skip_packages", size_t{0}, "Index of the first socket to use; default 0 = unlimited.", 2); visitor(max_packages, "max_packages", size_t{1}, @@ -67,15 +69,18 @@ class ThreadingArgs : public ArgsBase { "Index of the first CCX to use; default 0 = unlimited.", 2); visitor(max_clusters, "max_clusters", size_t{0}, "Max CCXs to use; default 0 = unlimited.", 2); - // These are only used when CPU topology is unknown. + // "Logical processors" (LPs). These are used when CPU topology is unknown. visitor(skip_lps, "skip_lps", size_t{0}, "Index of the first LP to use; default 0 = unlimited.", 2); visitor(max_lps, "max_lps", size_t{0}, "Max LPs to use; default 0 = unlimited.", 2); - // The exact meaning is more subtle: see the comment at NestedPools ctor. + // DEPRECATED: superseded by the above fields. If nonzero, `NestedPools` + // will attempt to create this many threads distributed over the detected + // topology. visitor(max_threads, "num_threads", size_t{0}, "Max threads to use; default 0 = unlimited.", 2); + visitor(pin, "pin", Tristate::kDefault, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); visitor(spin, "spin", Tristate::kDefault, @@ -86,13 +91,28 @@ class ThreadingArgs : public ArgsBase { } }; +// Owns threads corresponding to a subset of the system's resources. Because +// this is passed to `Gemma::Generate` (via `MatMulEnv`) rather than defined as +// a singleton, we can have multiple concurrent `Generate` calls within the +// same process, each with their own `ThreadingContext`. Because each context +// may pin its threads, it is important that they use distinct packages, +// clusters, or LPs. For example, to use two packages, the first `args` can have +// `skip_packages` = 0 and the second `skip_packages` = 1. struct ThreadingContext { - // Expected to be called early in the program, before threading starts. explicit ThreadingContext(const ThreadingArgs& args); + // Singleton; pass around a reference to reduce overhead. 
hwy::Profiler& profiler; + + // Detects topology, subject to limits imposed by user-specified `args`. + // For example, if `args.max_packages` is 1, then `topology.NumPackages()` + // will be 1 regardless of the actual system topology. BoundedTopology topology; + + // Ctor depends on `topology` for deciding whether to enable NUMA. Allocator allocator; + + // Per-package/cluster/within cluster pools of threads, matching `topology`. NestedPools pools; }; From 31c09cca4cdc0aaea8acc98175bb7b6b5154a852 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 28 Aug 2025 08:55:15 -0700 Subject: [PATCH 10/65] f32 LoopKC: 1.37x(M=512), 1.19(M=128) single-K F32,BF16 matmul speedup on SKX Add a special case for A=F32,B=BF16, used when there is no native bf16 dot product. dot-inl: ensure bf16,f32 and f32,bf16 both get promoted to float before f64 summation matmul.cc: update autotuning to reflect actual A size matmul_test: add all combinations of bf16/f32, report all results, not just first difference, check non-vector-aligned K PiperOrigin-RevId: 800487817 --- compression/test_util-inl.h | 17 +- ops/dot-inl.h | 61 +++- ops/matmul-inl.h | 633 ++++++++++++++++++++++++++---------- ops/matmul.cc | 24 +- ops/matmul.h | 8 +- ops/matmul_test.cc | 51 ++- 6 files changed, 594 insertions(+), 200 deletions(-) diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 7c4f854..207b225 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -85,15 +85,21 @@ MatStorageT GenerateMat(const Extents2D& extents, row[c] = f; } Compress(raw.Row(r), raw.Cols(), ws.tls[thread], - MakeSpan(compressed.Row(r), compressed.Cols()), + MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0); + + // MatMul requires that A's padding be zero-initialized. + hwy::ZeroBytes( + compressed.Row(r) + extents.cols, + (compressed.Stride() - extents.cols) * compressed.ElementBytes()); }); compressed.SetScale(0.6f); // Arbitrary value, different from 1. 
return compressed; } -// Same, but `extents` describes the transposed matrix. +// Same, but `extents` describes the transposed matrix and the computation of +// `f` swaps `r` and `c`. template MatStorageT GenerateTransposedMat(const Extents2D extents, const Allocator& allocator, @@ -112,8 +118,13 @@ MatStorageT GenerateTransposedMat(const Extents2D extents, row[c] = f; } Compress(raw.Row(r), raw.Cols(), ws.tls[thread], - MakeSpan(compressed.Row(r), compressed.Cols()), + MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0); + + // MatMul requires that B's padding be zero-initialized. + hwy::ZeroBytes( + compressed.Row(r) + extents.cols, + (compressed.Stride() - extents.cols) * compressed.ElementBytes()); }); // Arbitrary value, different from 1, must match `GenerateMat`. diff --git a/ops/dot-inl.h b/ops/dot-inl.h index 48aaae9..dae2106 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -157,15 +157,16 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) { // promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA. struct DotKernelDouble { // Only `CompressTraits` can `Decompress2` to `double`, so both have - // to be `float` in order to have `Raw = double`. Note that if either type is - // smaller than `float`, we may demote the other type from `float` to `BF16`. + // to be `float` in order to have `Raw = double`. To avoid loss of accuracy, + // if either is float, we decompress both to float, otherwise `BF16`. 
template - using Raw = hwy::If() && IsF32(), double, BF16>; + using Raw = hwy::If() && IsF32(), double, + hwy::If() || IsF32(), float, BF16>>; using State = double; // Raw = double template , HWY_IF_F64_D(DRaw)> - HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2, + HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2, const VR w3, const VR v0, const VR v1, const VR v2, const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const { @@ -175,6 +176,41 @@ struct DotKernelDouble { sum3 = hn::MulAdd(w3, v3, sum3); } + // Raw = float + template , HWY_IF_F32_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2, + const VR w3, const VR v0, const VR v1, const VR v2, + const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3, + VS&, VS&, VS&, VS&) const { + const hn::Repartition dd; + using VD = hn::Vec; + VD w0d = hn::PromoteLowerTo(dd, w0); + VD w1d = hn::PromoteLowerTo(dd, w1); + VD w2d = hn::PromoteLowerTo(dd, w2); + VD w3d = hn::PromoteLowerTo(dd, w3); + VD v0d = hn::PromoteLowerTo(dd, v0); + VD v1d = hn::PromoteLowerTo(dd, v1); + VD v2d = hn::PromoteLowerTo(dd, v2); + VD v3d = hn::PromoteLowerTo(dd, v3); + sum0 = hn::MulAdd(w0d, v0d, sum0); + sum1 = hn::MulAdd(w1d, v1d, sum1); + sum2 = hn::MulAdd(w2d, v2d, sum2); + sum3 = hn::MulAdd(w3d, v3d, sum3); + w0d = hn::PromoteUpperTo(dd, w0); + w1d = hn::PromoteUpperTo(dd, w1); + w2d = hn::PromoteUpperTo(dd, w2); + w3d = hn::PromoteUpperTo(dd, w3); + v0d = hn::PromoteUpperTo(dd, v0); + v1d = hn::PromoteUpperTo(dd, v1); + v2d = hn::PromoteUpperTo(dd, v2); + v3d = hn::PromoteUpperTo(dd, v3); + sum0 = hn::MulAdd(w0d, v0d, sum0); + sum1 = hn::MulAdd(w1d, v1d, sum1); + sum2 = hn::MulAdd(w2d, v2d, sum2); + sum3 = hn::MulAdd(w3d, v3d, sum3); + } + // Raw = BF16 template , HWY_IF_BF16_D(DRaw), class DS = hn::Repartition, class VS = hn::Vec> @@ -217,11 +253,26 @@ struct DotKernelDouble { // Raw = double 
template , HWY_IF_F64_D(DRaw)> - HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0, + HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VR& sum0, VR&) const { sum0 = hn::MulAdd(w0, v0, sum0); } + // Raw = float + template , HWY_IF_F32_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VS& sum0, + VS&) const { + const hn::Repartition dd; + using VD = hn::Vec; + VD w0d = hn::PromoteLowerTo(dd, w0); + VD v0d = hn::PromoteLowerTo(dd, v0); + sum0 = hn::MulAdd(w0d, v0d, sum0); + w0d = hn::PromoteUpperTo(dd, w0); + v0d = hn::PromoteUpperTo(dd, v0); + sum0 = hn::MulAdd(w0d, v0d, sum0); + } + // Raw = BF16 template , HWY_IF_BF16_D(DRaw), class DS = hn::Repartition, class VS = hn::Vec> diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 4741759..9f279cb 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -113,7 +113,7 @@ class MMStoreHorizontalSumsIntoC { const size_t row_c, const size_t col_c, const MMArgs& args, RowPtrs C_rows) const { HWY_ALIGN float buf[16 * hn::MaxLanes(df)]; - const size_t N = hn::Lanes(df); + HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df); // Horizontal reductions (`ReduceSum`) are rather expensive, entailing // log(N) operations for vectors of length N. Because `kNR` == 4, we // instead use `StoreInterleaved4` for a vector length-agnostic @@ -230,7 +230,7 @@ class MMAddHorizontalSumsIntoPartial { const hn::Repartition dd; HWY_ALIGN double buf[16 * hn::MaxLanes(dd)]; using VD = hn::Vec; - const size_t ND = hn::Lanes(dd); + HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); VD C00 = SumOfPromotedPairs(dd, F00); VD C01 = SumOfPromotedPairs(dd, F01); VD C02 = SumOfPromotedPairs(dd, F02); @@ -351,8 +351,8 @@ class MMKernel { // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. 
- template - static HWY_INLINE void A2C0(const StridedViewBF& A_view, + template + static HWY_INLINE void A2C0(const StridedView A_view, const bool A_padded, const StridedViewBF& B_view, size_t mr, const IndexRange& range_mc, const size_t row_b, size_t kc, Tag tag, const MMArgs& args, @@ -365,8 +365,8 @@ class MMKernel { // M == 1, or x86 with 8 SIMD registers: if (HWY_UNLIKELY(mr == 1)) { for (; imc < mc; ++imc) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } return; } @@ -375,13 +375,13 @@ class MMKernel { if (HWY_UNLIKELY(mr == 2)) { if (HWY_LIKELY(mc >= 2)) { for (; imc <= mc - 2; imc += 2) { - LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } } if (HWY_UNLIKELY(imc != mc)) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } return; } @@ -389,18 +389,20 @@ class MMKernel { HWY_DASSERT(mr == 4); if (HWY_LIKELY(mc >= 4)) { for (; imc <= mc - 4; imc += 4) { - LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } } const size_t remainder_mc = mc - imc; HWY_DASSERT(remainder_mc < 4); if (HWY_UNLIKELY(remainder_mc & 2)) { - LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); + LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); imc += 2; } if (HWY_UNLIKELY(remainder_mc & 1)) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); + LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); imc += 1; } HWY_DASSERT(imc == mc); @@ -423,9 +425,10 @@ class MMKernel { // `MMAddHorizontalSumsIntoPartial`. 
template , class DF = hn::Repartition, class VF = hn::Vec> - static HWY_INLINE void ElementwiseMulAcc(DBF dbf, VBF a, VBF b0, VBF b1, - VBF b2, VBF b3, VF& C0, VF& C1, - VF& C2, VF& C3) { + static HWY_INLINE void ElementwiseMulAccNativeBF(DBF dbf, VBF a, VBF b0, + VBF b1, VBF b2, VBF b3, + VF& C0, VF& C1, VF& C2, + VF& C3) { // This handles a single row of A, so the horizontal sums of `C0..3` are the // (partial) dot products for 4 consecutive values in one row of C. static_assert(kNR == 4); @@ -443,16 +446,17 @@ class MMKernel { HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); } - // Like `ElementwiseMulAcc`, but splits BF16 inputs into odd and even f32 - // for use with FMA. Also handles two rows at a time to hide the FMA latency - // (we assume 4 cycles and dual-issue) before writing `C00` again. + // Like `ElementwiseMulAccNativeBF`, but splits BF16 inputs into odd and even + // f32 for use with FMA. Also handles two rows at a time to hide the FMA + // latency (we assume 4 cycles and dual-issue) before writing `C00` again. template , class DF = hn::Repartition, class VF = hn::Vec> - static HWY_INLINE void ElementwiseMulAcc2(DBF dbf, VBF a0, VBF a1, VF b0o, - VF b0e, VF b1o, VF b1e, VF b2o, - VF b2e, VF b3o, VF b3e, VF& C00, - VF& C01, VF& C02, VF& C03, VF& C10, - VF& C11, VF& C12, VF& C13) { + static HWY_INLINE void ElementwiseMulAccEmuBF(DBF dbf, VBF a0, VBF a1, VF b0o, + VF b0e, VF b1o, VF b1e, VF b2o, + VF b2e, VF b3o, VF b3e, VF& C00, + VF& C01, VF& C02, VF& C03, + VF& C10, VF& C11, VF& C12, + VF& C13) { const DF df; HWY_DASSERT(!HWY_NATIVE_DOT_BF16); // Avoid `ReorderWidenMulAccumulate` because it requires extra adds for @@ -491,20 +495,36 @@ class MMKernel { } } - // Innermost loop over `kc` columns (typically 1024-4096) in steps of one - // vector, for `kRowsAC` rows of `A_view` from range_mc-relative `imc` and - // `B_view` from row 0 (both at column 0). 
Updates a `kRowsAC x kNR` tile - // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be - // BF16 so we can load directly without `Decompress2`, which is expensive for - // NUQ and requires 2x unrolling, which requires more loads. - template - static HWY_INLINE void LoopKC(const StridedViewBF& A_view, + // For A=F32, B=BF16 without native BF16 dot product: one lane-crossing + // promotion is likely cheaper than AND+SHIFT for promoting odd/even BF. + // Caller already promoted B, so all inputs are F32. + template , HWY_IF_F32_D(DF)> + static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2, + VF b3, VF& C0, VF& C1, VF& C2, + VF& C3) { + HWY_DASSERT(!HWY_NATIVE_DOT_BF16); + C0 = hn::MulAdd(a, b0, C0); + C1 = hn::MulAdd(a, b1, C1); + C2 = hn::MulAdd(a, b2, C2); + C3 = hn::MulAdd(a, b3, C3); + } + + // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a + // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` + // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). + // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`. + // `B` is BF16, `A` and `C` can be F32 or BF16. + template + static HWY_INLINE void LoopKC(const StridedView A_view, + const bool A_padded, const StridedViewBF& B_view, size_t row_ac, size_t imc, size_t col_c, size_t kc, Tag tag, const MMArgs& args, RowPtrs C_rows) { const hn::ScalableTag dbf; using VBF = hn::Vec; - const size_t NBF = hn::Lanes(dbf); + + HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag()); + HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); HWY_DASSERT(kRowsAC <= kMaxMR); HWY_DASSERT(col_c % kNR == 0); @@ -512,30 +532,36 @@ class MMKernel { // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B. static_assert(kNR == 4); - const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0); - const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? 
A_view.Row(imc + 1) : nullptr; - const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr; - const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr; + const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0); + const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr; + const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr; + const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr; const BF16* HWY_RESTRICT br0 = B_view.Row(0); const BF16* HWY_RESTRICT br1 = B_view.Row(1); const BF16* HWY_RESTRICT br2 = B_view.Row(2); const BF16* HWY_RESTRICT br3 = B_view.Row(3); - // Ensure `A` and `B` were zero-padded by `DecompressAndZeroPad`. + // Ensure `A` and `B` were zero-padded. if constexpr (HWY_IS_DEBUG_BUILD) { + // Only check if `A` is padded, i.e. not packed. + if (A_padded) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) { + { + HWY_DASSERT(hwy::ConvertScalarTo(ar0[i]) == 0.0f); + } + if constexpr (kRowsAC > 1) { + HWY_DASSERT(hwy::ConvertScalarTo(ar1[i]) == 0.0f); + } + if constexpr (kRowsAC > 2) { + HWY_DASSERT(hwy::ConvertScalarTo(ar2[i]) == 0.0f); + } + if constexpr (kRowsAC > 3) { + HWY_DASSERT(hwy::ConvertScalarTo(ar3[i]) == 0.0f); + } + } + } + // B is unconditionally zero-padded by `DecompressAndZeroPad`. 
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { - { - HWY_DASSERT(hwy::ConvertScalarTo(ar0[i]) == 0.0f); - } - if constexpr (kRowsAC > 1) { - HWY_DASSERT(hwy::ConvertScalarTo(ar1[i]) == 0.0f); - } - if constexpr (kRowsAC > 2) { - HWY_DASSERT(hwy::ConvertScalarTo(ar2[i]) == 0.0f); - } - if constexpr (kRowsAC > 3) { - HWY_DASSERT(hwy::ConvertScalarTo(ar3[i]) == 0.0f); - } HWY_DASSERT(hwy::ConvertScalarTo(br0[i]) == 0.0f); HWY_DASSERT(hwy::ConvertScalarTo(br1[i]) == 0.0f); HWY_DASSERT(hwy::ConvertScalarTo(br2[i]) == 0.0f); @@ -553,60 +579,287 @@ class MMKernel { C30 = hn::Zero(df), C31 = hn::Zero(df), C32 = hn::Zero(df), C33 = hn::Zero(df); - HWY_UNROLL(1) - for (size_t ikc = 0; ikc < kc; ikc += NBF) { + size_t ikc = 0; + // The loop step is always NBF: for non-native BF16 with TA=F32, this + // entails 2x unrolling, which helps a little. + const HWY_LANES_CONSTEXPR size_t kc_step = NBF; + // If A is packed (not padded), we have to check for remainders. Otherwise, + // we only run the main loop because A's padding is zero-initialized by + // `ZeroInit` or weights.cc. + const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc; + if (kc_end >= kc_step) { + HWY_UNROLL(1) + for (; ikc <= kc_end - kc_step; ikc += kc_step) { + if constexpr (HWY_NATIVE_DOT_BF16) { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + + // Should only get here if `A` is BF16, otherwise `DecompressA` would + // convert to BF16 and `A_view` points to that. 
+ HWY_DASSERT(IsBF16()); + + { + const VBF a0 = hn::Load(dbf, ar0 + ikc); + ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VBF a1 = hn::Load(dbf, ar1 + ikc); + ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12, + C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::Load(dbf, ar2 + ikc); + ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VBF a3 = hn::Load(dbf, ar3 + ikc); + ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32, + C33); + } + } else { // !HWY_NATIVE_DOT_BF16 + if constexpr (IsBF16()) { + // When both are BF16, it is better to load promote odd/even, + // because lane-crossing promotion for both might be bottlenecked on + // shuffles. + VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; + { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + b0e = hn::PromoteEvenTo(df, b0); + b1e = hn::PromoteEvenTo(df, b1); + b2e = hn::PromoteEvenTo(df, b2); + b3e = hn::PromoteEvenTo(df, b3); + b0o = FastPromoteOddTo(df, b0); + b1o = FastPromoteOddTo(df, b1); + b2o = FastPromoteOddTo(df, b2); + b3o = FastPromoteOddTo(df, b3); + } + + // Two rows at a time so we have 8 separate dependency chains, + // sufficient for IPC=2 and 4-cycle latency. + { + const VBF a0 = hn::Load(dbf, ar0 + ikc); + const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0; + ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C00, C01, C02, C03, C10, C11, + C12, C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::Load(dbf, ar2 + ikc); + const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2; + ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C20, C21, C22, C23, C30, C31, + C32, C33); + } + } else { // IsF32(): promote BF to 2xF32, F32*F32. 
+ // Full-vector loads are a bit faster on SKX than half + PromoteTo. + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + const VF b00 = hn::PromoteLowerTo(df, b0); + const VF b10 = hn::PromoteLowerTo(df, b1); + const VF b20 = hn::PromoteLowerTo(df, b2); + const VF b30 = hn::PromoteLowerTo(df, b3); + const VF b01 = hn::PromoteUpperTo(df, b0); + const VF b11 = hn::PromoteUpperTo(df, b1); + const VF b21 = hn::PromoteUpperTo(df, b2); + const VF b31 = hn::PromoteUpperTo(df, b3); + + { + const VF a00 = hn::Load(df, ar0 + ikc); + ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a10 = hn::Load(df, ar1 + ikc); + ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12, + C13); + } + + // C00 is ready again. On SKX, this interleaved unrolling is faster + // than consuming all `b*1` at the end of the loop. 
+ { + const VF a01 = hn::Load(df, ar0 + ikc + NA); + ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a11 = hn::Load(df, ar1 + ikc + NA); + ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12, + C13); + } + + if constexpr (kRowsAC > 2) { + const VF a20 = hn::Load(df, ar2 + ikc); + ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a30 = hn::Load(df, ar3 + ikc); + ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32, + C33); + } + + if constexpr (kRowsAC > 2) { + const VF a21 = hn::Load(df, ar2 + ikc + NA); + ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a31 = hn::Load(df, ar3 + ikc + NA); + ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32, + C33); + } + } + } + } + } + + // We want the number of actual valid kc, but we may already be beyond `kc`. + const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc; + HWY_DASSERT(remaining_kc < kc_step); + HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0)); + // Last iteration: B is padded but A is not; guard its loads. + if (HWY_UNLIKELY(remaining_kc != 0)) { if constexpr (HWY_NATIVE_DOT_BF16) { const VBF b0 = hn::Load(dbf, br0 + ikc); const VBF b1 = hn::Load(dbf, br1 + ikc); const VBF b2 = hn::Load(dbf, br2 + ikc); const VBF b3 = hn::Load(dbf, br3 + ikc); + + // Should only get here if `A` is BF16, otherwise `DecompressA` would + // convert to BF16 and `A_view` points to that. 
+ HWY_DASSERT(IsBF16()); + { - const VBF a0 = hn::Load(dbf, ar0 + ikc); - ElementwiseMulAcc(dbf, a0, b0, b1, b2, b3, C00, C01, C02, C03); + const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02, + C03); } if constexpr (kRowsAC > 1) { - const VBF a1 = hn::Load(dbf, ar1 + ikc); - ElementwiseMulAcc(dbf, a1, b0, b1, b2, b3, C10, C11, C12, C13); + const VBF a1 = hn::LoadN(dbf, ar1 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12, + C13); } if constexpr (kRowsAC > 2) { - const VBF a2 = hn::Load(dbf, ar2 + ikc); - ElementwiseMulAcc(dbf, a2, b0, b1, b2, b3, C20, C21, C22, C23); + const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22, + C23); } if constexpr (kRowsAC > 3) { - const VBF a3 = hn::Load(dbf, ar3 + ikc); - ElementwiseMulAcc(dbf, a3, b0, b1, b2, b3, C30, C31, C32, C33); + const VBF a3 = hn::LoadN(dbf, ar3 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32, + C33); } - } else { - VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; - { + } else { // !HWY_NATIVE_DOT_BF16 + if constexpr (IsBF16()) { + // When both are BF16, it is better to load promote odd/even, because + // lane-crossing promotion for both might be bottlenecked on shuffles. + VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; + { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + b0e = hn::PromoteEvenTo(df, b0); + b1e = hn::PromoteEvenTo(df, b1); + b2e = hn::PromoteEvenTo(df, b2); + b3e = hn::PromoteEvenTo(df, b3); + b0o = FastPromoteOddTo(df, b0); + b1o = FastPromoteOddTo(df, b1); + b2o = FastPromoteOddTo(df, b2); + b3o = FastPromoteOddTo(df, b3); + } + + // Two rows at a time so we have 8 separate dependency chains, + // sufficient for IPC=2 and 4-cycle latency. 
+ { + const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc); + const VBF a1 = + kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0; + ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C00, C01, C02, C03, C10, C11, C12, + C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc); + const VBF a3 = + kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2; + ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C20, C21, C22, C23, C30, C31, C32, + C33); + } + } else { // IsF32(): promote half-B to F32, F32*F32. const VBF b0 = hn::Load(dbf, br0 + ikc); const VBF b1 = hn::Load(dbf, br1 + ikc); const VBF b2 = hn::Load(dbf, br2 + ikc); const VBF b3 = hn::Load(dbf, br3 + ikc); - b0e = hn::PromoteEvenTo(df, b0); - b1e = hn::PromoteEvenTo(df, b1); - b2e = hn::PromoteEvenTo(df, b2); - b3e = hn::PromoteEvenTo(df, b3); - b0o = FastPromoteOddTo(df, b0); - b1o = FastPromoteOddTo(df, b1); - b2o = FastPromoteOddTo(df, b2); - b3o = FastPromoteOddTo(df, b3); - } + const VF b00 = hn::PromoteLowerTo(df, b0); + const VF b10 = hn::PromoteLowerTo(df, b1); + const VF b20 = hn::PromoteLowerTo(df, b2); + const VF b30 = hn::PromoteLowerTo(df, b3); + const VF b01 = hn::PromoteUpperTo(df, b0); + const VF b11 = hn::PromoteUpperTo(df, b1); + const VF b21 = hn::PromoteUpperTo(df, b2); + const VF b31 = hn::PromoteUpperTo(df, b3); - { - const VBF a0 = hn::Load(dbf, ar0 + ikc); - const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0; - ElementwiseMulAcc2(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o, - b3e, C00, C01, C02, C03, C10, C11, C12, C13); - } - if constexpr (kRowsAC > 2) { - const VBF a2 = hn::Load(dbf, ar2 + ikc); - const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2; - ElementwiseMulAcc2(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o, - b3e, C20, C21, C22, C23, C30, C31, C32, C33); + const size_t remaining2 = remaining_kc <= NA ? 
0 : remaining_kc - NA; + + { + const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12, + C13); + } + + // C00 is ready again. On SKX, this interleaved unrolling is faster + // than consuming all `b*1` at the end of the loop. + { + const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12, + C13); + } + + if constexpr (kRowsAC > 2) { + const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32, + C33); + } + + if constexpr (kRowsAC > 2) { + const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32, + C33); + } } } - } + } // remaining_kc != 0 // This is a substantial fraction (about 1/3) of the total time, but is // called frequently, so do not add a profiler zone. 
@@ -678,7 +931,7 @@ class MMScaleDemoteAdd { const hn::Rebind dc; using VD = hn::Vec; using VF = hn::Vec; - const size_t ND = hn::Lanes(dd); + HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); const VD vscale = hn::Set(dd, args.scale); const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); @@ -796,7 +1049,7 @@ class MMScaleDemoteAdd { const hn::Rebind dc; using VD = hn::Vec; using VF = hn::Vec; - const size_t ND = hn::Lanes(dd); + HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); const VD vscale = hn::Set(dd, args.scale); const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); TC* HWY_RESTRICT cr0 = C_rows[row_c + 0]; @@ -858,41 +1111,51 @@ class MMScaleDemoteAdd { // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. // Its member variables avoid long argument lists in Do*(). class MMPerPackage { - public: + // Decompression is only required for F32 A and native BF16 dot products. + // If A is already BF16, we can use a view. Padding is not required + // because `LoopKC` can handle non-vector multiples. `LoopKC` also contains + // a special case for F32 `A` and non-native BF16 dot products. template - MMPerPackage(const MatPtrT& A, const MMArgs& args, const MMConfig& config, + static constexpr bool WantDecompressA() { + return HWY_NATIVE_DOT_BF16 && IsF32(); + } + + public: + MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config, size_t pkg_idx, const IndexRange& range_np) : args_(args), pkg_idx_(pkg_idx), - // May be overwritten with a view of A, if already BF16. 
- A_(args_.env->storage.A(pkg_idx, A.Extents())), range_np_(range_np), mr_(config.MR()), - ranges_mc_(config.RangesOfMC(A.Rows())), - ranges_kc_(config.RangesOfKC(A.Cols())), + ranges_mc_(config.RangesOfMC(A.rows)), + ranges_kc_(config.RangesOfKC(A.cols)), ranges_nc_(config.RangesOfNC(range_np)), order_(config.Order()), inner_tasks_(config.InnerTasks()), out_(config.Out()), - line_bytes_(args.env->ctx.allocator.LineBytes()) { - A_ = DecompressA(A); + line_bytes_(args.env->ctx.allocator.LineBytes()) {} + + // The size of `A` that will actually be used, for purposes of choosing the + // autotuning candidates. Keep in sync with the `operator()` logic below. + template + static constexpr size_t ABytes() { + return WantDecompressA() ? sizeof(BF16) : sizeof(TA); } - // B is decompressed several call layers lower, but not all member functions - // depend on TB, so pass it as an argument instead of templating the class. - template - HWY_NOINLINE void operator()(const MatPtrT& B, RowPtrs C_rows) const { - switch (order_) { - case MMOrder::kNT: - return DoNT(B, C_rows); - case MMOrder::kNT_K: - return DoNT_K(B, C_rows); - case MMOrder::kNT_MT: - return DoNT_MT(B, C_rows); - case MMOrder::kNT_MT_K: - return DoNT_MT_K(B, C_rows); - default: - HWY_UNREACHABLE; + // B and maybe A are decompressed several call layers lower, but not all + // member functions depend on TA/TB, so pass them as an argument instead of + // templating the class. + template + HWY_NOINLINE void operator()(const MatPtrT& A, const MatPtrT& B, + RowPtrs C_rows) const { + if constexpr (WantDecompressA()) { + const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); + DecompressA(A, A_view); + constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded. 
+ DispatchOrder(A_view, A_padded, B, C_rows); + } else { + const bool A_padded = HasPadding(A); + DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows); } } @@ -909,16 +1172,57 @@ class MMPerPackage { return HWY_MAX(kNR, line_bytes_ / sizeof_TC); } + // Use instead of `MatPtr::IsPacked` because that returns true for single + // rows, but we want to know whether there is padding. + static bool HasPadding(const MatPtr& mat) { + return mat.Stride() > mat.Cols(); + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. Both `A`` + // and `B` are const, but StridedView is also used for non-const `partial`. + template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag()); + (void)N; + // If `AB` is padded, then `LoopKC` expects the view is either a vector + // multiple, or all columns and thus also padded. + HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols())); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + + // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`. + template + HWY_INLINE void DispatchOrder(const StridedView A, const bool A_padded, + const MatPtrT& B, + RowPtrs C_rows) const { + switch (order_) { + case MMOrder::kNT: + return DoNT(A, A_padded, B, C_rows); + case MMOrder::kNT_K: + return DoNT_K(A, A_padded, B, C_rows); + case MMOrder::kNT_MT: + return DoNT_MT(A, A_padded, B, C_rows); + case MMOrder::kNT_MT_K: + return DoNT_MT_K(A, A_padded, B, C_rows); + default: + HWY_UNREACHABLE; + } + } + // Single M and K ranges, parallel N. Fills all of C directly. 
- template - HWY_INLINE void DoNT(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_M = ranges_mc_.Range(0); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); - const StridedViewBF& A_view = A_.View(range_M.begin(), 0, K); + const StridedView A_view = A.View(range_M.begin(), 0, K); const size_t B_stride = Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); @@ -936,8 +1240,8 @@ class MMPerPackage { row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), - args_, C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K, + MMSetC(), args_, C_rows); } }); @@ -945,8 +1249,9 @@ class MMPerPackage { } // Single M range, parallel N, sequential K. Fills all of partial. 
- template - HWY_INLINE void DoNT_K(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_K(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); const IndexRange& range_mc = ranges_mc_.Range(0); @@ -958,8 +1263,8 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedViewBF& A_view = - A_.View(range_mc.begin(), range_kc.begin(), kc); + const StridedView A_view = + A.View(range_mc.begin(), range_kc.begin(), kc); const StridedViewBF B_storage_view( B_storage, kc, Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_)); @@ -967,8 +1272,8 @@ class MMPerPackage { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, - C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc, + out_tag, args_, C_rows); } }; @@ -1013,8 +1318,9 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, single K. // Fills `mc x nc` sections of C directly, in parallel. 
- template - HWY_INLINE void DoNT_MT(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_MT(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_K = ranges_kc_.Range(0); @@ -1031,7 +1337,7 @@ class MMPerPackage { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); - const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K); + const StridedView A_view = A.View(range_mc.begin(), 0, K); HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS const StridedViewBF B_storage_view(B_storage, K, B_stride); @@ -1039,8 +1345,8 @@ class MMPerPackage { row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), - args_, C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K, + MMSetC(), args_, C_rows); } }); @@ -1049,8 +1355,9 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, sequential K. // Fills `mc x nc` sections of `partial`, then `C`, in parallel. 
- template - HWY_INLINE void DoNT_MT_K(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_MT_K(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); static const auto fill_zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K.FillC"); @@ -1067,14 +1374,14 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedViewBF& A_view = - A_.View(range_mc.begin(), range_kc.begin(), kc); + const StridedView A_view = + A.View(range_mc.begin(), range_kc.begin(), kc); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, - C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc, + out_tag, args_, C_rows); } }; // loop_nc args_.env->parallel.ForRangesMC_NC( @@ -1107,17 +1414,16 @@ class MMPerPackage { }); } - // Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a - // seekable type (i.e., not NUQ) so we can use pointer arithmetic. - template - HWY_NOINLINE void DoDecompressA(const MatPtrT& A, MMParA par_a) const { + // Decompresses all `M x K` from `A` into padded BF16 `A_view`. + HWY_NOINLINE void DoDecompressA(const MatPtrT& A, + const StridedViewBF A_view, + MMParA par_a) const { const IndexRange all_M(0, A.Rows()); const IndexRange all_K(0, A.Cols()); - HWY_DASSERT(all_K.Num() == A_.Cols()); + HWY_DASSERT(all_K.Num() == A_view.Cols()); const hn::ScalableTag dbf; const size_t NBF = hn::Lanes(dbf); - static_assert(hwy::IsSameEither(), "Can seek"); static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA"); @@ -1133,8 +1439,9 @@ class MMPerPackage { // otherwise `DecompressAndZeroPad` overwrites neighbors. 
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); for (size_t row_a : range_M) { - const PackedSpan from = MakeSpan(A.Row(row_a) + col0, cols); - BF16* HWY_RESTRICT to = A_.Row(row_a) + col0; + const PackedSpan from = + MakeSpan(A.Row(row_a) + col0, cols); + BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; DecompressAndZeroPad(dbf, from, 0, to, cols); // Verify that we zero-padded. if constexpr (HWY_IS_DEBUG_BUILD) { @@ -1175,23 +1482,12 @@ class MMPerPackage { } // Autotuning wrapper for `DoDecompressA`. - template - HWY_INLINE StridedViewBF DecompressA(const MatPtrT& A) const { + HWY_INLINE void DecompressA(const MatPtrT& A, + const StridedViewBF A_view) const { MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; - // If already BF16, maybe return a view: - if constexpr (hwy::IsSame()) { - // Only if vector multiple and padded (see `DoDecompressA`). - const size_t NBF = hn::Lanes(hn::ScalableTag()); - if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) { - // Const, but cast because StridedView is also used for `partial` which - // is non-const. - return StridedViewBF(const_cast(A.Row(0)), A.Cols(), A.Stride()); - } - } if (HWY_LIKELY(autotune.Best())) { - DoDecompressA(A, *autotune.Best()); - return A_; + return DoDecompressA(A, A_view, *autotune.Best()); } // First call: generate candidates. @@ -1204,7 +1500,7 @@ class MMPerPackage { const MMParA& par_a = autotune.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - DoDecompressA(A, par_a); + DoDecompressA(A, A_view, par_a); const uint64_t t1 = args_.env->have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); @@ -1213,7 +1509,6 @@ class MMPerPackage { static_cast(min_elapsed) / hwy::platform::InvariantTicksPerSecond() * 1E6); } - return A_; } // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, @@ -1223,12 +1518,17 @@ class MMPerPackage { HWY_INLINE StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, const IndexRange& range_kc, const StridedViewBF& B_view) const { + const hn::ScalableTag dbf; + HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); + + // View() is safe if vector multiple, or padded: for the latter, `ZeroInit` + // and weights.cc zero-initialize the padding. if constexpr (hwy::IsSame()) { - return StridedViewBF(const_cast(B.Row(row_b)) + range_kc.begin(), - range_kc.Num(), B.Stride()); + if (B.Cols() % NBF == 0 || HasPadding(B)) { + return View(B, row_b, range_kc.begin(), range_kc.Num()); + } } - const hn::ScalableTag dbf; const PackedSpan B_span = B.PaddedSpan(); const size_t kc = range_kc.Num(); @@ -1240,7 +1540,7 @@ class MMPerPackage { DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc); // Verify that we zero-padded. if constexpr (HWY_IS_DEBUG_BUILD) { - for (size_t i = kc; i < hwy::RoundUpTo(kc, hn::Lanes(dbf)); ++i) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); } } @@ -1250,7 +1550,6 @@ class MMPerPackage { const MMArgs args_; // copy for locality const size_t pkg_idx_; - StridedViewBF A_; // view into A or pkg_A_, both of which are padded. 
const IndexRange range_np_; // From MMConfig: @@ -1293,13 +1592,14 @@ struct MMImpl { MMZone mm_zone; mm_zone.MaybeEnter(pkg_idx, zone, args); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows); + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, + C_rows); }); } else { const size_t pkg_idx = 0; HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows); + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows); } } }; @@ -1310,7 +1610,7 @@ struct MMImpl { // `K = B.Cols()`, which must match `A.Cols()`, is the number // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There // are no other restrictions on shape, though performance is better when `M % 4 -// == 0` or `M <= 4`, and when A is padded (`!A.IsPacked()`). +// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()). 
// // NOTE: if A and/or B are BF16 and padded, the interval `[Cols(), // hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match @@ -1376,8 +1676,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(N <= MMStorage::kMaxN); HWY_ASSERT(N % kNR == 0); - tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, - kNR, per_key.ranges_np, env.print_config)); + tuner.SetCandidates( + MMCandidates(allocator, M, K, N, MMPerPackage::ABytes(), sizeof(TC), + kMaxMR, kNR, per_key.ranges_np, env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); diff --git a/ops/matmul.cc b/ops/matmul.cc index c51acbd..c9ddfb6 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -64,19 +64,21 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, class GenerateCandidates { public: GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, - size_t sizeof_TC, size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, bool print_config) + size_t sizeof_TA, size_t sizeof_TC, size_t max_mr, + size_t nr, const IndexRangePartition& ranges_np, + bool print_config) : allocator_(allocator), M_(M), K_(K), N_(N), + sizeof_TA_(sizeof_TA), sizeof_TC_(sizeof_TC), max_mr_(max_mr), nr_(nr), // These influence kc/nc, but are also stored in `MMConfig` for // `RangesOf*`. Must be a vector multiple. The previous/next cache line // is likely still in L1, but we expect K > 1000 and might as well round - // up to the line size. + // up to the line size. Use BF16, not sizeof_TA, because B is BF16. kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))), nc_multiple_(allocator.StepBytes() / sizeof_TC), ranges_np_(ranges_np), @@ -176,8 +178,9 @@ class GenerateCandidates { // subtract the output and buf, and allow using more than the actual L1 // size. This results in an overestimate, and the loop below will propose // the next few smaller values for the autotuner to evaluate. 
- const size_t bytes_ab = allocator_.L1Bytes() * 3; - const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); + const size_t bytes_ab = + allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream)); + const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16); size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); kc_max = RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_); @@ -224,7 +227,7 @@ class GenerateCandidates { // packed B. We want `mc * kc` elements of A to fit in L2, alongside // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of // partial. - const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes(); + const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes(); size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc); mc_max = HWY_MIN(mc_max, MMStorage::kMaxM); HWY_DASSERT(mc_max != 0); @@ -359,6 +362,7 @@ class GenerateCandidates { const size_t M_; const size_t K_; const size_t N_; + const size_t sizeof_TA_; const size_t sizeof_TC_; const size_t max_mr_; @@ -376,12 +380,12 @@ class GenerateCandidates { // Facade to avoid exposing `GenerateCandidates` in the header. std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, + size_t K, size_t N, size_t sizeof_TA, + size_t sizeof_TC, size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config) { - return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr, - ranges_np, print_config)(); + return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr, + nr, ranges_np, print_config)(); } // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote diff --git a/ops/matmul.h b/ops/matmul.h index de8ef8c..99290c1 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -281,7 +281,9 @@ class MMStorage { BindC(partial_storage_, parallel); } - // Returns per-package matrix view. 
+ // Returns per-package matrix view. Converting A=F32 to BF16 up-front is + // faster than on-the-fly when native BF16 is available: it only happens once, + // not per B tile row, and the cache footprint is smaller. StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.cols <= kMaxK); @@ -475,8 +477,8 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, + size_t K, size_t N, size_t sizeof_TA, + size_t sizeof_TC, size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config); diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 122012e..aadbc56 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -118,17 +118,26 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch); const double eps_bf16 = hwy::ConvertScalarTo(hwy::Epsilon()); const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); + // Dot() uses double-precision summation. double tolerance = 12 * norm * eps_f32; - // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the - // tolerance there. - if (IsF32() && IsF32()) { - tolerance += 4 * max_abs * eps_bf16; + // If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to + // BF16, so add extra tolerance. 
+ if (IsF32()) { + tolerance += 2 * max_abs * eps_bf16; } + if (tolerance > 500.0) { HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs); } - const double max_rel = 1.0 + hwy::ConvertScalarTo(hwy::Epsilon()); + const double rel_tolerance = + 1.0 + hwy::ConvertScalarTo(hwy::Epsilon()); + double max_rel = 0.0; + size_t worst_r = 0; + size_t worst_c = 0; + double worst_actual = 0.0; + double worst_expected = 0.0; + size_t num_outside = 0; for (size_t r = 0; r < A.Rows(); r++) { const float* expected_row = c_slow_batch.Row(r); const float* actual_row = c_batch.Row(r); @@ -143,15 +152,24 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, const double min = HWY_MIN(expected_value, actual_value); const double rel = max / HWY_MAX(min, 1E-6); if (rel > max_rel) { - hwy::Abort(__FILE__, line, - "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " - "tolerance %f rel %E max_rel %E\n", - r, c, expected_value, actual_value, norm, max_abs, - tolerance, rel, max_rel); + worst_expected = expected_value; + worst_actual = actual_value; + worst_r = r; + worst_c = c; + max_rel = rel; + ++num_outside; } } } } + + if (max_rel > rel_tolerance) { + hwy::Abort(__FILE__, line, + "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " + "tolerance %f rel %E max_rel %E num_outside %zu\n", + worst_r, worst_c, worst_expected, worst_actual, norm, max_abs, + tolerance, max_rel, rel_tolerance, num_outside); + } } // B is already transposed. @@ -188,9 +206,9 @@ HWY_INLINE void MatMulSlow(const MatPtrT A, const MatPtrT B, TC* HWY_RESTRICT C_row = C.Row(r); for (size_t c : cols_c) { const float add = add_row ? 
add_row[c] : 0.0f; - C_row[c] = hwy::ConvertScalarTo( - add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r), - A.Cols())); + const float dot = + Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols()); + C_row[c] = hwy::ConvertScalarTo(add + scale * dot); } } }); @@ -279,6 +297,9 @@ void TestTiny() { for (size_t K = 1; K <= 64; K *= 2) { for (size_t N = 4; N <= 64; N += max_packages * 4) { TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); } } } @@ -334,6 +355,10 @@ void TestAllMatMul() { TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + // Non-vector-multiple K. + TestMatMul(128, 258, 128, /*add=*/true, env, __LINE__); + TestMatMul(128, 258, 128, /*add=*/true, env, __LINE__); + // minimal non-square test. kColsARowsB must be at least 2 vectors. TestMatMul(35, 128, 32, /*add=*/false, env, __LINE__); TestMatMul(34, 128, 32, /*add=*/true, env, __LINE__); From 72888914391c95d8453807d0d35b8a34f955686b Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 29 Aug 2025 00:11:31 -0700 Subject: [PATCH 11/65] Remove F64 partial storage in matmul. 
Also remove no longer used kMaxN; row_ptrs only used for C PiperOrigin-RevId: 800774757 --- gemma/attention.cc | 4 +- ops/matmul-inl.h | 488 ++++----------------------------------------- ops/matmul.cc | 50 ++--- ops/matmul.h | 75 ++----- 4 files changed, 69 insertions(+), 548 deletions(-) diff --git a/gemma/attention.cc b/gemma/attention.cc index a04b868..c73abcb 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -275,10 +275,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t batch_idx = div_qbatch.Divide(interleaved_idx); const size_t cache_pos = activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx); - env.row_ptrs[2][interleaved_idx] = reinterpret_cast( + env.row_ptrs[0][interleaved_idx] = reinterpret_cast( qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size); } - kv_rows.AttachRowPtrs(env.row_ptrs[2].get()); + kv_rows.AttachRowPtrs(env.row_ptrs[0].get()); CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, /*add=*/nullptr, env, kv_rows); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 9f279cb..56cb06f 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -71,23 +71,29 @@ static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { #endif } -// Converts from float intermediate to MatMul output type `TC`. -template , HWY_IF_F32_D(DC)> -hn::Vec TCFromF32(DC /*dc*/, hn::Vec vf) { +// Converts from float intermediate to/from MatMul output type `TC`. 
+template +hn::Vec TCFromF32(DC /*dc*/, hn::Vec vf) { return vf; } template , HWY_IF_BF16_D(DC)> hn::Vec TCFromF32(DC dc, hn::Vec vf) { return hn::DemoteTo(dc, vf); } +template +hn::Vec F32FromTC(DC /*dc*/, hn::Vec vc) { + return vc; +} +template , HWY_IF_BF16_D(DC)> +hn::Vec F32FromTC(DC dc, hn::Vec vc) { + return hn::PromoteTo(DF(), vc); +} // Tag classes, passed to `MMKernel::A2C0` to choose between writing one -// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the -// first kc result to partial, or accumulating the next kc result into partial -// via `MMAddHorizontalSumsIntoPartial`. +// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or accumulating the +// next kc result into `C`. struct MMSetC {}; -struct MMSetPartial {}; -struct MMAddPartial {}; +struct MMAddC {}; // Stores horizontal sums of up to 16 vectors via transpose. template @@ -143,10 +149,8 @@ class MMStoreHorizontalSumsIntoC { sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane); } const V4 vscale = hn::Set(d4, args.scale); - V4 vadd = hn::Zero(d4); - if constexpr (kAdd) { - vadd = hn::Load(d4, args.add + col_c); - } + HWY_ALIGN static constexpr float kZero[4] = {}; + const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero); MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c); MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c); MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c); @@ -195,156 +199,16 @@ class MMStoreHorizontalSumsIntoC { if constexpr (kRow < kRowsAC) { TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c; const hn::Rebind dc4; + if constexpr (kAdd) { + vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value + } // else: add bias (only once, the first time we store to C) + const VF4 out = hn::MulAdd(sum, vscale, vadd); hn::Store(TCFromF32(dc4, out), dc4, pos); } } }; // MMStoreHorizontalSumsIntoC -// Accumulates horizontal sums of up to 16 vectors via transpose. 
-template -class MMAddHorizontalSumsIntoPartial { - public: - static_assert(kNR == 4); // for `StoreInterleaved4` - - // Computes horizontal sums of `kRowsAC x kNR` vectors and accumulates - // into `partial` starting at `(row_c, col_c)`. - // - // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a - // transposed B row vector indexed by `c`. Their elements are thus a subset - // of the terms of the dot product constituting the final `C[r, c]` result. - // Thus we compute the horizontal sums of each `Crc`. The elements may be - // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but - // this does not change their horizontal sum. - template > - HWY_INLINE void operator()(DF df, // - VF F00, VF F01, VF F02, VF F03, // - VF F10, VF F11, VF F12, VF F13, // - VF F20, VF F21, VF F22, VF F23, // - VF F30, VF F31, VF F32, VF F33, // - const size_t row_c, const size_t col_c, - const StridedViewD& partial) const { - // We accumulate in 64-bit to avoid loss of precision. 
- static_assert(HWY_HAVE_FLOAT64, "Disable Armv7 NEON: we require fp64"); - - const hn::Repartition dd; - HWY_ALIGN double buf[16 * hn::MaxLanes(dd)]; - using VD = hn::Vec; - HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); - VD C00 = SumOfPromotedPairs(dd, F00); - VD C01 = SumOfPromotedPairs(dd, F01); - VD C02 = SumOfPromotedPairs(dd, F02); - VD C03 = SumOfPromotedPairs(dd, F03); - VD C10 = SumOfPromotedPairs(dd, F10); - VD C11 = SumOfPromotedPairs(dd, F11); - VD C12 = SumOfPromotedPairs(dd, F12); - VD C13 = SumOfPromotedPairs(dd, F13); - VD C20 = SumOfPromotedPairs(dd, F20); - VD C21 = SumOfPromotedPairs(dd, F21); - VD C22 = SumOfPromotedPairs(dd, F22); - VD C23 = SumOfPromotedPairs(dd, F23); - VD C30 = SumOfPromotedPairs(dd, F30); - VD C31 = SumOfPromotedPairs(dd, F31); - VD C32 = SumOfPromotedPairs(dd, F32); - VD C33 = SumOfPromotedPairs(dd, F33); - - // Horizontal reductions (`ReduceSum`) are rather expensive, entailing - // log(N) operations for vectors of length N. Because `kNR` == 4, we - // instead use `StoreInterleaved4` for a vector length-agnostic - // 'transpose': `buf[0, 4 * N)` holds `C00[0], C01[0], C02[0], C03[0], - // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], - // C03[N-1]`. - MaybeStoreInterleaved4<0>(dd, ND, C00, C01, C02, C03, buf); - MaybeStoreInterleaved4<1>(dd, ND, C10, C11, C12, C13, buf); - MaybeStoreInterleaved4<2>(dd, ND, C20, C21, C22, C23, buf); - MaybeStoreInterleaved4<3>(dd, ND, C30, C31, C32, C33, buf); - // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in - // the elements of one V4. We have four independent rows `r`, hence the - // code is effectively unrolled, which increases throughput. - const hn::CappedTag d4; - using V4 = hn::Vec; - // Store to four elements per row of `partial`. - // Loop is required because vectors may be smaller than 4*64 bits. 
- for (size_t c = 0; c < kNR; c += hn::Lanes(d4)) { - V4 sum0 = MaybeLoad<0>(d4, ND, buf + c); - V4 sum1 = MaybeLoad<1>(d4, ND, buf + c); - V4 sum2 = MaybeLoad<2>(d4, ND, buf + c); - V4 sum3 = MaybeLoad<3>(d4, ND, buf + c); - - for (size_t lane = 1; lane < ND; ++lane) { - sum0 = MaybeAdd<0>(d4, ND, sum0, buf + c + kNR * lane); - sum1 = MaybeAdd<1>(d4, ND, sum1, buf + c + kNR * lane); - sum2 = MaybeAdd<2>(d4, ND, sum2, buf + c + kNR * lane); - sum3 = MaybeAdd<3>(d4, ND, sum3, buf + c + kNR * lane); - } - MaybeAddStore<0>(d4, sum0, partial, row_c, col_c + c); - MaybeAddStore<1>(d4, sum1, partial, row_c, col_c + c); - MaybeAddStore<2>(d4, sum2, partial, row_c, col_c + c); - MaybeAddStore<3>(d4, sum3, partial, row_c, col_c + c); - } - } - - private: - // Converts lanes to double and adds pairs of them to obtain a vector with the - // same horizontal sum, but element type double. - template , - class DF = hn::Repartition, class VF = hn::Vec> - static HWY_INLINE VD SumOfPromotedPairs(DD dd, VF f) { - // TODO: SVE could PromoteEvenTo. - const VD d0 = hn::PromoteLowerTo(dd, f); - const VD d1 = hn::PromoteUpperTo(dd, f); - return hn::Add(d0, d1); - } - - // These helper functions hoist if() out of the main code below. They have - // no effect if kRow >= kRowsAC. - template > - static HWY_INLINE void MaybeStoreInterleaved4(DD dd, size_t N, VD Cr0, VD Cr1, - VD Cr2, VD Cr3, - double* HWY_RESTRICT buf) { - if constexpr (kRow < kRowsAC) { - hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, dd, buf + 4 * kRow * N); - } - } - - // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4. 
- template > - static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N, - const double* HWY_RESTRICT buf) { - if constexpr (kRow < kRowsAC) { - return hn::Load(d4, buf + 4 * kRow * N); - } else { - return hn::Zero(d4); - } - } - - template > - static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum, - const double* HWY_RESTRICT buf) { - if constexpr (kRow < kRowsAC) { - return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N)); - } else { - return sum; - } - } - - template > - static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum, - const StridedViewD& partial, - const size_t row_c, const size_t col_c) { - if constexpr (kRow < kRowsAC) { - double* HWY_RESTRICT pos = partial.Row(row_c + kRow) + col_c; - if constexpr (hwy::IsSame()) { - hn::Store(sum, d4, pos); - } else { - static_assert(hwy::IsSame()); - const V4 prev = hn::Load(d4, pos); - hn::Store(hn::Add(sum, prev), d4, pos); - } - } - } -}; // MMAddHorizontalSumsIntoPartial - // Stateless, wraps member functions. class MMKernel { public: @@ -865,247 +729,18 @@ class MMKernel { // called frequently, so do not add a profiler zone. if constexpr (hwy::IsSame()) { - if (args.add) { - MMStoreHorizontalSumsIntoC()( - df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args, C_rows); - } else { - MMStoreHorizontalSumsIntoC()( - df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args, C_rows); - } - } else { - MMAddHorizontalSumsIntoPartial()( + MMStoreHorizontalSumsIntoC()( df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args.partial); + C31, C32, C33, row_ac, col_c, args, C_rows); + } else { + static_assert(hwy::IsSame()); + MMStoreHorizontalSumsIntoC()( + df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, + C31, C32, C33, row_ac, col_c, args, C_rows); } } }; -// Multiply partial by scale, add bias if present, demote and store to f32 `C`. 
-// Stateless, wraps member functions. -class MMScaleDemoteAdd { - public: - // Fills the `range_mc/range_nc` region of `outputs.C` by multiplying the - // same region of `outputs.partial` by `outputs.scale`, which is the product - // of the scales of A and B, demoting from f64 to f32, then if `outputs.add` - // is nonzero, adding it to each row. - // TODO: fuse with subsequent operations - function pointer? - // Although this region in `outputs.C` is not touched again, streaming stores - // do not help on SKX and Zen4. TODO: re-check this. - template - static HWY_INLINE void FillC(const IndexRange& range_mc, - const IndexRange& range_nc, const MMArgs& args, - RowPtrs C_rows) { - size_t row_c = range_mc.begin(); - if (args.add) { - constexpr bool kAdd = true; - if (range_mc.Num() >= 4) { - for (; row_c <= range_mc.end() - 4; row_c += 4) { - Do4Rows(row_c, range_nc, args, C_rows); - } - } - for (; row_c < range_mc.end(); ++row_c) { - Do1Row(row_c, range_nc, args, C_rows); - } - } else { - constexpr bool kAdd = false; - if (range_mc.Num() >= 4) { - for (; row_c <= range_mc.end() - 4; row_c += 4) { - Do4Rows(row_c, range_nc, args, C_rows); - } - } - for (; row_c < range_mc.end(); ++row_c) { - Do1Row(row_c, range_nc, args, C_rows); - } - } - } - - private: - // Unrolled for 4 rows to reduce the number of loads from `add`. 
- template - static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc, - const MMArgs& args, RowPtrs C_rows) { - const hn::ScalableTag dd; - const hn::Rebind df; // result of DemoteTo - const hn::Rebind dc; - using VD = hn::Vec; - using VF = hn::Vec; - HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); - const VD vscale = hn::Set(dd, args.scale); - - const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); - const double* HWY_RESTRICT pr1 = args.partial.Row(row_c + 1); - const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2); - const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3); - - TC* HWY_RESTRICT cr0 = C_rows[row_c + 0]; - TC* HWY_RESTRICT cr1 = C_rows[row_c + 1]; - TC* HWY_RESTRICT cr2 = C_rows[row_c + 2]; - TC* HWY_RESTRICT cr3 = C_rows[row_c + 3]; - - // We manually unroll 2x for higher IPC in batch=1. - size_t col_c = range_nc.begin(); - if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { - for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { - VD a0, a1; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. 
- a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c)); - a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND)); - } - const VD d00 = hn::Load(dd, pr0 + col_c); - const VD d01 = hn::Load(dd, pr0 + col_c + ND); - const VD d10 = hn::Load(dd, pr1 + col_c); - const VD d11 = hn::Load(dd, pr1 + col_c + ND); - const VD d20 = hn::Load(dd, pr2 + col_c); - const VD d21 = hn::Load(dd, pr2 + col_c + ND); - const VD d30 = hn::Load(dd, pr3 + col_c); - const VD d31 = hn::Load(dd, pr3 + col_c + ND); - VD m00, m01, m10, m11, m20, m21, m30, m31; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - m01 = hn::MulAdd(d01, vscale, a1); - m10 = hn::MulAdd(d10, vscale, a0); - m11 = hn::MulAdd(d11, vscale, a1); - m20 = hn::MulAdd(d20, vscale, a0); - m21 = hn::MulAdd(d21, vscale, a1); - m30 = hn::MulAdd(d30, vscale, a0); - m31 = hn::MulAdd(d31, vscale, a1); - } else { - m00 = hn::Mul(d00, vscale); - m01 = hn::Mul(d01, vscale); - m10 = hn::Mul(d10, vscale); - m11 = hn::Mul(d11, vscale); - m20 = hn::Mul(d20, vscale); - m21 = hn::Mul(d21, vscale); - m30 = hn::Mul(d30, vscale); - m31 = hn::Mul(d31, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - const VF f01 = hn::DemoteTo(df, m01); - const VF f10 = hn::DemoteTo(df, m10); - const VF f11 = hn::DemoteTo(df, m11); - const VF f20 = hn::DemoteTo(df, m20); - const VF f21 = hn::DemoteTo(df, m21); - const VF f30 = hn::DemoteTo(df, m30); - const VF f31 = hn::DemoteTo(df, m31); - // Note that Stream is neutral on SKX and harmful on Zen4. 
- hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c); - hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND); - hn::Store(TCFromF32(dc, f10), dc, cr1 + col_c); - hn::Store(TCFromF32(dc, f11), dc, cr1 + col_c + ND); - hn::Store(TCFromF32(dc, f20), dc, cr2 + col_c); - hn::Store(TCFromF32(dc, f21), dc, cr2 + col_c + ND); - hn::Store(TCFromF32(dc, f30), dc, cr3 + col_c); - hn::Store(TCFromF32(dc, f31), dc, cr3 + col_c + ND); - } - } - - for (; col_c < range_nc.end(); col_c += ND) { - const size_t remaining = range_nc.end() - col_c; - HWY_DASSERT(remaining < 2 * ND); - - VD a0; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. - a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining)); - } - const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining); - const VD d10 = hn::LoadN(dd, pr1 + col_c, remaining); - const VD d20 = hn::LoadN(dd, pr2 + col_c, remaining); - const VD d30 = hn::LoadN(dd, pr3 + col_c, remaining); - VD m00, m10, m20, m30; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - m10 = hn::MulAdd(d10, vscale, a0); - m20 = hn::MulAdd(d20, vscale, a0); - m30 = hn::MulAdd(d30, vscale, a0); - } else { - m00 = hn::Mul(d00, vscale); - m10 = hn::Mul(d10, vscale); - m20 = hn::Mul(d20, vscale); - m30 = hn::Mul(d30, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - const VF f10 = hn::DemoteTo(df, m10); - const VF f20 = hn::DemoteTo(df, m20); - const VF f30 = hn::DemoteTo(df, m30); - hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining); - hn::StoreN(TCFromF32(dc, f10), dc, cr1 + col_c, remaining); - hn::StoreN(TCFromF32(dc, f20), dc, cr2 + col_c, remaining); - hn::StoreN(TCFromF32(dc, f30), dc, cr3 + col_c, remaining); - } - } - - // Same as above but handles a single row (for remainder rows). 
- template - static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc, - const MMArgs& args, RowPtrs C_rows) { - const hn::ScalableTag dd; - const hn::Rebind df; // result of DemoteTo - const hn::Rebind dc; - using VD = hn::Vec; - using VF = hn::Vec; - HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); - const VD vscale = hn::Set(dd, args.scale); - const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); - TC* HWY_RESTRICT cr0 = C_rows[row_c + 0]; - - // We manually unroll 2x for higher IPC in batch=1. - size_t col_c = range_nc.begin(); - if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { - for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { - VD a0, a1; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. - a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c)); - a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND)); - } - const VD d00 = hn::Load(dd, pr0 + col_c); - const VD d01 = hn::Load(dd, pr0 + col_c + ND); - VD m00, m01; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - m01 = hn::MulAdd(d01, vscale, a1); - } else { - m00 = hn::Mul(d00, vscale); - m01 = hn::Mul(d01, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - const VF f01 = hn::DemoteTo(df, m01); - // Note that Stream is neutral on SKX and harmful on Zen4. - hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c); - hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND); - } - } - - for (; col_c < range_nc.end(); col_c += ND) { - const size_t remaining = range_nc.end() - col_c; - HWY_DASSERT(remaining < 2 * ND); - - VD a0; // unused if !kAdd - if constexpr (kAdd) { - // Promoting to double lets us fuse the Add into MulAdd. 
- a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining)); - } - const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining); - VD m00; - if constexpr (kAdd) { - m00 = hn::MulAdd(d00, vscale, a0); - } else { - m00 = hn::Mul(d00, vscale); - } - // First convert f64 to f32. - const VF f00 = hn::DemoteTo(df, m00); - hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining); - } - } -}; // MMScaleDemoteAdd - // Called on the main thread with the entire N range, or by each package with // a static partition of N. This class contains several variants of the // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. @@ -1132,7 +767,6 @@ class MMPerPackage { ranges_nc_(config.RangesOfNC(range_np)), order_(config.Order()), inner_tasks_(config.InnerTasks()), - out_(config.Out()), line_bytes_(args.env->ctx.allocator.LineBytes()) {} // The size of `A` that will actually be used, for purposes of choosing the @@ -1244,11 +878,9 @@ class MMPerPackage { MMSetC(), args_, C_rows); } }); - - HWY_DASSERT(out_ == MMOut::kDirect); // already filled C } - // Single M range, parallel N, sequential K. Fills all of partial. + // Single M range, parallel N, sequential K. Sets C, then accumulates. template HWY_INLINE void DoNT_K(const StridedView A, const bool A_padded, const MatPtrT& B, RowPtrs C_rows) const { @@ -1288,32 +920,12 @@ class MMPerPackage { // Peel off the first iteration of the kc loop: avoid // zero-initializing `partial` by writing into it. 
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { - loop_nc(B_storage, range_kc, range_nc, MMSetPartial()); + loop_nc(B_storage, range_kc, range_nc, MMSetC()); }); ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { - loop_nc(B_storage, range_kc, range_nc, MMAddPartial()); + loop_nc(B_storage, range_kc, range_nc, MMAddC()); }); }); - - if (out_ == MMOut::kCopy) { - static const auto zone = - args_.env->ctx.profiler.AddZone("MM.NT_K.FillC.Copy"); - MMZone fill_zone; - fill_zone.MaybeEnter(0, zone, args_); - MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows); - } else if (out_ == MMOut::kParM) { - static const auto zone = - args_.env->ctx.profiler.AddZone("MM.NT_K.FillC.ParM"); - args_.env->parallel.ForRangeMC( - range_mc, pkg_idx_, [&](size_t row_a, size_t worker) HWY_ATTR { - MMZone fill_zone; - fill_zone.MaybeEnter(worker, zone, args_); - MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_, - args_, C_rows); - }); - } else { - HWY_UNREACHABLE; // kDirect is only used with kNT. - } } // Parallel loops over mc/nc blocks of M/range_np, single K. @@ -1343,14 +955,12 @@ class MMPerPackage { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - StridedViewBF B_view = + const StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K, MMSetC(), args_, C_rows); } }); - - HWY_DASSERT(out_ == MMOut::kDirect); // already filled C } // Parallel loops over mc/nc blocks of M/range_np, sequential K. 
@@ -1359,8 +969,6 @@ class MMPerPackage { HWY_INLINE void DoNT_MT_K(const StridedView A, const bool A_padded, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); - static const auto fill_zone = - args_.env->ctx.profiler.AddZone("MM.NT_MT_K.FillC"); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); const size_t B_stride = @@ -1395,22 +1003,13 @@ class MMPerPackage { const StridedViewBF B_storage_view(B_storage, kc_max, B_stride); // Peel off the first iteration of the kc loop: avoid - // zero-initializing `partial` by writing into it. + // zero-initializing `C` by writing into it. ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { - loop_nc(B_storage_view, range_mc, range_kc, range_nc, - MMSetPartial()); + loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMSetC()); }); ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { - loop_nc(B_storage_view, range_mc, range_kc, range_nc, - MMAddPartial()); + loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMAddC()); }); - - // Already in parallel section, hence no `kParM`, and - // `kDirect` is only used with `kNT_MT`. 
- HWY_DASSERT(out_ == MMOut::kCopy); - MMZone fill_mm_zone; - fill_mm_zone.MaybeEnter(worker, fill_zone, args_); - MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows); }); } @@ -1559,7 +1158,6 @@ class MMPerPackage { const IndexRangePartition ranges_nc_; const MMOrder order_; const size_t inner_tasks_; - const MMOut out_; const size_t line_bytes_; }; // MMPerPackage @@ -1632,7 +1230,7 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C) { - RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[2]); + RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]); const Allocator& allocator = env.ctx.allocator; const size_t M = A.Rows(); @@ -1659,7 +1257,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), - add, env.storage.Partial()); + add); if (HWY_LIKELY(tuner.Best())) { MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best()); return &per_key; @@ -1673,7 +1271,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(K == B.Cols()); HWY_ASSERT(M <= MMStorage::kMaxM); HWY_ASSERT(K <= MMStorage::kMaxK); - HWY_ASSERT(N <= MMStorage::kMaxN); HWY_ASSERT(N % kNR == 0); tuner.SetCandidates( @@ -1690,10 +1287,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, hwy::platform::InvariantTicksPerSecond(); const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { - fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s\n", flops * 1E-9, + fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9, min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(), - StringFromOrder(cfg.Order()), cfg.InnerTasks(), - StringFromOut(cfg.Out())); + StringFromOrder(cfg.Order()), cfg.InnerTasks()); } if (HWY_UNLIKELY(env.print_best && tuner.Best())) { const auto 
ratio = [per_key](uint64_t ticks) -> double { @@ -1702,11 +1298,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, }; const MMConfig& best = *tuner.Best(); fprintf(stderr, - "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s,%.2f,%.2f\n", - M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), + "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", M, + K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), best.KC(), best.NC(), StringFromOrder(best.Order()), - best.InnerTasks(), StringFromOut(best.Out()), - ratio(tuner.WorstMinTicks()), ratio(tuner.FirstConfigTicks())); + best.InnerTasks(), ratio(tuner.WorstMinTicks()), + ratio(tuner.FirstConfigTicks())); } return &per_key; diff --git a/ops/matmul.cc b/ops/matmul.cc index c9ddfb6..71f2efe 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -91,24 +91,21 @@ class GenerateCandidates { for (size_t mr : MR()) { for (MMOrder order : Orders(mr)) { const std::vector& all_inner_tasks = InnerTasks(order); - const std::vector& all_outs = Outs(order); for (size_t kc : KC(mr, order)) { for (size_t mc : MC(mr, kc, order)) { for (size_t nc : NC(mr, mc, kc, order)) { for (int inner_tasks : all_inner_tasks) { - for (MMOut out : all_outs) { - const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_, - nc_multiple_, order, out, inner_tasks); - const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); - const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); + const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_, + nc_multiple_, order, inner_tasks); + const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); + const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); - // Blocks only make sense when there are multiple M tasks. - if (IsBlock(order) != (M_tasks > 1)) continue; - // Single KC only makes sense when there is a single K task. - if (IsOneKC(order) != (K_tasks == 1)) continue; + // Blocks only make sense when there are multiple M tasks. 
+ if (IsBlock(order) != (M_tasks > 1)) continue; + // Single KC only makes sense when there is a single K task. + if (IsOneKC(order) != (K_tasks == 1)) continue; - candidates.push_back(config); - } + candidates.push_back(config); } } } @@ -265,14 +262,13 @@ class GenerateCandidates { SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { const size_t np_max = ranges_np_.TaskSize(); size_t nc_max = np_max; - const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double); // Only if there will be reuse of B: choose the largest `nc_max` (C cols) // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. // Otherwise, leave it unbounded. if (M_ > mr) { - const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes); - nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc); - nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max); + const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); + nc_max = + HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), np_max); } HWY_DASSERT(nc_max != 0); nc_max = RoundDownWithFloor(nc_max, nc_multiple_); @@ -340,24 +336,6 @@ class GenerateCandidates { return inner_tasks; } - // Whether to parallelize FillC or enable direct writes to C. - std::vector Outs(MMOrder order) const { - std::vector outs; - for (size_t out_idx = 0;; ++out_idx) { - const MMOut out = static_cast(out_idx); - if (StringFromOut(out) == nullptr) return outs; // done - // kParM only makes sense if we have more than one row of A. - if (out == MMOut::kParM && M_ == 1) continue; - // Blocks are already parallelized. - if (out == MMOut::kParM && IsBlock(order)) continue; - // Direct only works for a single kc range. - if ((out == MMOut::kDirect) != IsOneKC(order)) continue; - // For non-block, kCopy does not beat kDirect. 
- if (out == MMOut::kCopy && IsOneKC(order) && !IsBlock(order)) continue; - outs.push_back(out); - } - } - const Allocator& allocator_; const size_t M_; const size_t K_; @@ -432,8 +410,6 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); - row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // A - row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxN)); // B row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // C } @@ -461,7 +437,7 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { } } -// C is BF16/float, or double for partial +// C is BF16/float void BindC(MatPtr& C, MMParallel& parallel) { Allocator& allocator = parallel.allocator(); if (!allocator.ShouldBind()) return; diff --git a/ops/matmul.h b/ops/matmul.h index 99290c1..e4c436f 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -43,7 +43,7 @@ namespace gcpp { // at least the product of the FMA latency (3..5) times the throughput (2). // This and `mr` are limited by the number of registers, which is generally // 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in -// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`. +// `MMStoreHorizontalSumsIntoC`. We ensure `C.Cols() % kNR == 0`. constexpr size_t kNR = 4; // Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because @@ -195,7 +195,7 @@ class MMParallel { }; void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel); -// C is BF16/float, or double for partial. +// C is BF16/float. void BindC(MatPtr& C, MMParallel& parallel); // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. @@ -236,8 +236,7 @@ class StridedView { using StridedViewBF = StridedView; using StridedViewD = StridedView; -// Per-package storage for packed A, and one global C-shaped `partial` for -// accumulating partial dot products (sections of K). +// Per-package storage for packed A. 
class MMStorage { public: // Compile-time bounds on matrix dimensions to enable pre-allocating storage @@ -245,21 +244,13 @@ class MMStorage { // per package and 512 MiB, respectively. static constexpr size_t kMaxM = 4096; static constexpr size_t kMaxK = 64 * 1024; - static constexpr size_t kMaxN = 256 * 1024; // Upper bound for per-worker B storage on the stack. Chosen such that one row // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. static constexpr size_t kMaxKC = 8 * 1024; // Internally threaded; must not be called concurrently with the same // `ThreadingContext` (used via `parallel`). - MMStorage(const Allocator& allocator, MMParallel& parallel) - : // Per-worker copies of `partial` would be wasteful. We instead - // allocate one instance of the maximum matrix extents because threads - // write at false-sharing-free granularity. - partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), allocator, - MatPadding::kOdd), - // Same stride independent of the actual C.Cols() so we can pre-bind. - partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) { + MMStorage(const Allocator& allocator, MMParallel& parallel) { // Per-package allocation so each can decompress A into its own copy. // Must be padded, see `DoDecompressA`. parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) { @@ -276,9 +267,6 @@ class MMStorage { } } }); - - // Avoid cross-package accesses. - BindC(partial_storage_, parallel); } // Returns per-package matrix view. Converting A=F32 to BF16 up-front is @@ -291,12 +279,8 @@ class MMStorage { extents.cols, pkg_A_[pkg_idx]->Stride()); } - StridedViewD Partial() const { return partial_; } - private: std::unique_ptr> pkg_A_[kMaxPackages]; - MatStorageT partial_storage_; - StridedViewD partial_; }; //------------------------------------------------------------------------------ @@ -349,29 +333,6 @@ static inline const char* StringFromOrder(MMOrder order) { } } -// How/where to write the A2C0 result. 
This determines the `tag` argument to -// that function, which governs whether we call `MMStoreHorizontalSumsIntoC` or -// `MMAddHorizontalSumsIntoPartial`. -enum class MMOut : uint8_t { - kCopy, // accumulate into partial, scale/add to C - kDirect, // single kc task, write directly to C - kParM // kCopy but parallel over M - // kParN is not better on SKX/Zen4. -}; - -static inline const char* StringFromOut(MMOut out) { - switch (out) { - case MMOut::kDirect: - return "Direct"; - case MMOut::kCopy: - return "Copy"; - case MMOut::kParM: - return "ParM"; - default: - return nullptr; - } -} - // How to parallelize the per-package `DecompressA`. To reduce combinatorial // explosion, we tune this separately from `MMConfig`. enum class MMParA : uint8_t { kNone, kK1 = 1, kK2 = 2, kK4 = 4, kM }; @@ -405,10 +366,9 @@ class MMConfig { MMConfig() = default; // for std::vector // `mr` is the number of A rows per call to `MMKernel::LoopKC`. // `MMOrder` is how to parallelize the outer loops. - // `MMOut` is how/whether to parallelize filling the C result. // `inner_tasks` chooses the within-cluster task granularity in `ForNP`. 
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc, - size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out, + size_t kc_multiple, size_t nc_multiple, MMOrder order, int inner_tasks) : mr_(static_cast(mr)), mc_(static_cast(mc)), @@ -417,7 +377,6 @@ class MMConfig { nc_multiple_(static_cast(nc_multiple)), kc_multiple_(static_cast(kc_multiple)), order_(order), - out_(out), inner_tasks_(static_cast(inner_tasks)), reserved_{} { HWY_DASSERT(mr == 1 || mr == 2 || mr == 4); @@ -433,7 +392,6 @@ class MMConfig { HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple); } HWY_DASSERT(StringFromOrder(order_) != nullptr); - HWY_DASSERT(StringFromOut(out_) != nullptr); HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); } @@ -450,7 +408,6 @@ class MMConfig { } MMOrder Order() const { return order_; } - MMOut Out() const { return out_; } // No `OuterTasks` because static partitioning across clusters is sufficient. size_t InnerTasks() const { return static_cast(inner_tasks_); } @@ -469,9 +426,8 @@ class MMConfig { uint32_t nc_multiple_; uint32_t kc_multiple_; MMOrder order_; - MMOut out_; uint8_t inner_tasks_; - HWY_MAYBE_UNUSED uint8_t reserved_[5]; + HWY_MAYBE_UNUSED uint8_t reserved_[6]; }; static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) @@ -691,11 +647,10 @@ struct MatMulEnv { // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV // writes to differing KV positions per query / output row. - // The first three allocations are sufficient for any A, B, C, respectively, - // but also potentially overwritten by each MatMul. Subsequent entries are - // precomputed for tensors and not overwritten. Per-tensor allocations make - // it likelier that asan detects bugs such as use after free, overrun, and - // dangling references. 
+ // The first entry is sufficient for any C argument, but also potentially + // overwritten by each MatMul. Subsequent entries are precomputed for tensors + // and not overwritten. Per-tensor allocations make it likelier that asan + // detects bugs such as use after free, overrun, and dangling references. std::vector> row_ptrs; }; @@ -703,20 +658,14 @@ struct MatMulEnv { // Reduces register pressure compared to individual values/references. struct MMArgs { MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, - const float* HWY_RESTRICT add, const StridedViewD& partial) - : env(&env), - per_key(&per_key), - scale(scale), - add(add), - partial(partial) {} + const float* HWY_RESTRICT add) + : env(&env), per_key(&per_key), scale(scale), add(add) {} MatMulEnv* env; MMPerKey* per_key; double scale; const float* HWY_RESTRICT add; - // Same size as C, threads write at false-sharing-free granularity. - StridedViewD partial; }; // Wrapper over hwy::Zone that is only enabled when autotuning finished. From 6c39a2dea417cde51895265f4a0d53e02fa7b48a Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 29 Aug 2025 03:18:28 -0700 Subject: [PATCH 12/65] 1.01x speedup: More bf16 activations to reduce DecompressA. Also move observer call into function, format gemma_args. 
PiperOrigin-RevId: 800827400 --- gemma/activations.h | 4 ++++ gemma/gemma.cc | 32 +++++++++++++++++--------------- gemma/gemma_args.h | 9 ++++++--- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/gemma/activations.h b/gemma/activations.h index 877afdf..175ddc9 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -155,6 +155,7 @@ struct Activations { : layer_config(config.layer_configs[0]), x(MatFactory("x", batch_size, config.model_dim, allocator)), + x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)), logits(MatFactory("logits", batch_size, config.vocab_size, allocator)), pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, @@ -173,6 +174,7 @@ struct Activations { // If we forget any MatMul outputs here, debug builds print a warning but // fill them in each MatMul call. x.AllocateAndAttachRowPtrs(row_ptrs); + x_bf.AllocateAndAttachRowPtrs(row_ptrs); logits.AllocateAndAttachRowPtrs(row_ptrs); C1.AllocateAndAttachRowPtrs(row_ptrs); C2.AllocateAndAttachRowPtrs(row_ptrs); @@ -184,6 +186,7 @@ struct Activations { // Negligible CPU time. 
void SetBatchSize(size_t batch_size) { x.OverrideRows(batch_size); + x_bf.OverrideRows(batch_size); logits.OverrideRows(batch_size); pre_ffw_rms_out.OverrideRows(batch_size); @@ -198,6 +201,7 @@ struct Activations { const LayerConfig& layer_config; MatStorageT x; // input + MatStorageT x_bf; // output of final RMSNorm, input to EmbeddingMatmul MatStorageT logits; // Gated FFW diff --git a/gemma/gemma.cc b/gemma/gemma.cc index c6da86c..fc1f238 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -294,6 +294,18 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, } } +static void MaybeObserve(const RuntimeConfig& runtime_config, + Activations& activations, QBatch& qbatch, + int layer_idx) { + if constexpr (kObserver) { + if (HWY_UNLIKELY(runtime_config.activations_observer)) { + runtime_config.activations_observer( + QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx, + activations); + } + } +} + // Embeds PrevToken (one from each query) and calls each TransformerLayer. // Called by query-batched `PrefillQBatch` and `GenerateT`, but not the // token-batched `PrefillTBatch`, which supports image embedding. 
@@ -322,13 +334,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config, TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx), activations, qbatch, env); - if constexpr (kObserver) { - if (HWY_UNLIKELY(runtime_config.activations_observer)) { - runtime_config.activations_observer( - QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx, - activations); - } - } + MaybeObserve(runtime_config, activations, qbatch, layer_idx); } } @@ -412,21 +418,17 @@ static void SampleAndStream( hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) { HWY_DASSERT(qbatch.Size() == activations.x.Rows()); - RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx); + RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf, + env.ctx); - if constexpr (kObserver) { - if (HWY_UNLIKELY(runtime_config.activations_observer)) { - runtime_config.activations_observer( - QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations); - } - } + MaybeObserve(runtime_config, activations, qbatch, -1); { static const auto zone = env.ctx.profiler.AddZone( "Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive); PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone); // Compute logits from last layer activations. - CallMatMul(activations.x, weights.embedder_input_embedding, + CallMatMul(activations.x_bf, weights.embedder_input_embedding, /*add=*/nullptr, env, activations.logits); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 469ba2a..16c9595 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -220,9 +220,11 @@ struct InferenceArgs : public ArgsBase { "resets every turn)"); visitor(image_file, "image_file", Path(), "Image file to load."); - // Since it is not used in the CLI version, the print_verbosity is set higher than others. + // Since it is not used in the CLI version, the print_verbosity is set + // higher than others. 
visitor(port, "port", 8080, "Server port (default: 8080)", 3); - visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3); + visitor(model, "model", std::string("gemma3-4b"), + "Model name for API endpoints (default: gemma3-4b)", 3); visitor(prompt, "prompt", std::string(""), "Initial prompt for non-interactive mode. When specified, " @@ -282,7 +284,8 @@ struct ClientArgs : public ArgsBase { visitor(port, "port", 8080, "Server port (default: 8080)"); visitor(api_key, "api_key", std::string(""), - "Use public API with key (changes host to generativelanguage.googleapis.com:443)"); + "Use public API with key (changes host to " + "generativelanguage.googleapis.com:443)"); visitor(model, "model", std::string("gemma3-4b"), "Model name to use (default: gemma3-4b)"); visitor(prompt, "prompt", std::string("Hello! How are you?"), From 973e284ed6ee0595de90a399a52a869542563fc7 Mon Sep 17 00:00:00 2001 From: Marie White Date: Fri, 29 Aug 2025 05:40:06 -0700 Subject: [PATCH 13/65] Refactor Matmul to use a policy class for parallelization. PiperOrigin-RevId: 800864489 --- gemma/weights.cc | 4 +- ops/matmul-inl.h | 108 ++++++++++++++++++++++++++--------------------- ops/matmul.cc | 27 ++++++------ ops/matmul.h | 107 +++++++++++++++++++++++----------------------- 4 files changed, 125 insertions(+), 121 deletions(-) diff --git a/gemma/weights.cc b/gemma/weights.cc index 4124247..3425a60 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -320,8 +320,6 @@ static void AllocateAndBindAll(std::vector& tensors, const size_t start = owners.size(); owners.resize(start + tensors.size()); - MMParallel parallel(ctx); - // Allocate in parallel because faulting in large tensors is slow. 
ctx.pools.Pool().Run( 0, tensors.size(), [&](uint64_t task, size_t /*thread*/) { @@ -339,7 +337,7 @@ static void AllocateAndBindAll(std::vector& tensors, owners[start + task].AllocateFor(*tensor.mat, ctx.allocator, tensor.padding); - BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel); + BindB(ctx, *tensor.mat, tensor.mat->ElementBytes()); }); } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 56cb06f..29be665 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -779,17 +779,19 @@ class MMPerPackage { // B and maybe A are decompressed several call layers lower, but not all // member functions depend on TA/TB, so pass them as an argument instead of // templating the class. - template - HWY_NOINLINE void operator()(const MatPtrT& A, const MatPtrT& B, + template + HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy, + const MatPtrT& A, const MatPtrT& B, RowPtrs C_rows) const { if constexpr (WantDecompressA()) { const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); - DecompressA(A, A_view); + DecompressA(A, A_view); constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded. - DispatchOrder(A_view, A_padded, B, C_rows); + DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows); } else { const bool A_padded = HasPadding(A); - DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows); + DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B, + C_rows); } } @@ -828,28 +830,30 @@ class MMPerPackage { } // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`. 
- template - HWY_INLINE void DispatchOrder(const StridedView A, const bool A_padded, + template + HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy, + const StridedView A, const bool A_padded, const MatPtrT& B, RowPtrs C_rows) const { switch (order_) { case MMOrder::kNT: - return DoNT(A, A_padded, B, C_rows); + return DoNT(parallel_policy, A, A_padded, B, C_rows); case MMOrder::kNT_K: - return DoNT_K(A, A_padded, B, C_rows); + return DoNT_K(parallel_policy, A, A_padded, B, C_rows); case MMOrder::kNT_MT: - return DoNT_MT(A, A_padded, B, C_rows); + return DoNT_MT(parallel_policy, A, A_padded, B, C_rows); case MMOrder::kNT_MT_K: - return DoNT_MT_K(A, A_padded, B, C_rows); + return DoNT_MT_K(parallel_policy, A, A_padded, B, C_rows); default: HWY_UNREACHABLE; } } // Single M and K ranges, parallel N. Fills all of C directly. - template - HWY_INLINE void DoNT(const StridedView A, const bool A_padded, - const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView A, + const bool A_padded, const MatPtrT& B, + RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_kc_.NumTasks() == 1); @@ -861,9 +865,9 @@ class MMPerPackage { Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); // Similar to `loop_nc` below, but here we hoisted `A_view`. - args_.env->parallel.ForNP( - range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, - [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { + MMParallelPolicyT::ForNP( + args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, + pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -881,9 +885,10 @@ class MMPerPackage { } // Single M range, parallel N, sequential K. Sets C, then accumulates. 
- template - HWY_INLINE void DoNT_K(const StridedView A, const bool A_padded, - const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView A, + const bool A_padded, const MatPtrT& B, + RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); const IndexRange& range_mc = ranges_mc_.Range(0); @@ -909,9 +914,9 @@ class MMPerPackage { } }; - args_.env->parallel.ForNP( - range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, - [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { + MMParallelPolicyT::ForNP( + args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, + pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -930,9 +935,10 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, single K. // Fills `mc x nc` sections of C directly, in parallel. - template - HWY_INLINE void DoNT_MT(const StridedView A, const bool A_padded, - const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView A, + const bool A_padded, const MatPtrT& B, + RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_K = ranges_kc_.Range(0); @@ -942,8 +948,8 @@ class MMPerPackage { // Sequential loop over NC/MC/KC, similar to `loop_nc` below // except for the profiler strings and `out_tag`. - args_.env->parallel.ForRangesMC_NC( - ranges_mc_, ranges_nc_, pkg_idx_, + MMParallelPolicyT::ForRangesMC_NC( + args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -965,9 +971,10 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, sequential K. 
// Fills `mc x nc` sections of `partial`, then `C`, in parallel. - template - HWY_INLINE void DoNT_MT_K(const StridedView A, const bool A_padded, - const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView A, + const bool A_padded, const MatPtrT& B, + RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); @@ -992,8 +999,8 @@ class MMPerPackage { out_tag, args_, C_rows); } }; // loop_nc - args_.env->parallel.ForRangesMC_NC( - ranges_mc_, ranges_nc_, pkg_idx_, + MMParallelPolicyT::ForRangesMC_NC( + args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -1014,6 +1021,7 @@ class MMPerPackage { } // Decompresses all `M x K` from `A` into padded BF16 `A_view`. + template HWY_NOINLINE void DoDecompressA(const MatPtrT& A, const StridedViewBF A_view, MMParA par_a) const { @@ -1064,16 +1072,16 @@ class MMPerPackage { // line to avoid false sharing. const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16)); - args_.env->parallel.ForNP( - all_K, multiple_K, inner_tasks, pkg_idx_, - [&](const IndexRange& range_K, size_t worker) { - do_range(all_M, range_K, worker); - }); + MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks, + pkg_idx_, + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); + }); break; } case MMParA::kM: - args_.env->parallel.ForRangeMC( - all_M, pkg_idx_, [&](size_t row_a, size_t worker) { + MMParallelPolicyT::ForRangeMC( + args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) { do_range(IndexRange(row_a, row_a + 1), all_K, worker); }); break; @@ -1081,12 +1089,13 @@ class MMPerPackage { } // Autotuning wrapper for `DoDecompressA`. 
+ template HWY_INLINE void DecompressA(const MatPtrT& A, const StridedViewBF A_view) const { MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; if (HWY_LIKELY(autotune.Best())) { - return DoDecompressA(A, A_view, *autotune.Best()); + return DoDecompressA(A, A_view, *autotune.Best()); } // First call: generate candidates. @@ -1099,7 +1108,7 @@ class MMPerPackage { const MMParA& par_a = autotune.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - DoDecompressA(A, A_view, par_a); + DoDecompressA(A, A_view, par_a); const uint64_t t1 = args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); @@ -1185,19 +1194,21 @@ struct MMImpl { if constexpr (kMaxPackages > 1) { // Outermost loop: static NUMA-aware partition of B rows across packages. - args.env->parallel.ForPkg( - args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) { + MMNestedParallelPolicy::ForPkg( + args.env->ctx, args.per_key->ranges_np.NumTasks(), + [&](size_t pkg_idx) { MMZone mm_zone; mm_zone.MaybeEnter(pkg_idx, zone, args); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, - C_rows); + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( + MMNestedParallelPolicy(), A, B, C_rows); }); } else { const size_t pkg_idx = 0; HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows); + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( + MMNestedParallelPolicy(), A, B, C_rows); } } }; @@ -1250,8 +1261,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, // invalidates `MMAutoTune::Best()` index = env.per_key.size(); - env.per_key.push_back( - MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel)); + env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, 
sizeof(TC), kNR)); } MMPerKey& per_key = env.per_key[index]; MMAutoTune& tuner = per_key.autotune; diff --git a/ops/matmul.cc b/ops/matmul.cc index 71f2efe..711eac1 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -397,34 +397,33 @@ static size_t NPMultiple(const Allocator& allocator, size_t N, return np_multiple; } -IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N, - size_t sizeof_TC, size_t nr) const { - const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages()); +IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, + size_t N, size_t sizeof_TC, size_t nr) { + const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages()); return StaticPartition( IndexRange(0, N), num_packages, - NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages)); + NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages)); } -MatMulEnv::MatMulEnv(ThreadingContext& ctx) - : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) { +MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) { char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // C } -void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { - Allocator& allocator = parallel.allocator(); +void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) { + Allocator& allocator = ctx.allocator; if (!allocator.ShouldBind()) return; if (B.Rows() == 1) return; PROFILER_ZONE("Startup.BindB"); const IndexRangePartition ranges_np = - parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR); + MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR); for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { const IndexRange& rows_b = ranges_np.Range(pkg_idx); - const size_t node = parallel.Node(pkg_idx); + const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); uintptr_t begin = reinterpret_cast(B.RowBytes(rows_b.begin())); 
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes(); // B row padding is less than the page size, so only bind the subset that @@ -438,14 +437,14 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { } // C is BF16/float -void BindC(MatPtr& C, MMParallel& parallel) { - Allocator& allocator = parallel.allocator(); +void BindC(ThreadingContext& ctx, MatPtr& C) { + Allocator& allocator = ctx.allocator; if (!allocator.ShouldBind()) return; PROFILER_ZONE("Startup.BindC"); const IndexRangePartition ranges_np = - parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR); + MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR); bool ok = true; for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { const IndexRange& cols_c = ranges_np.Range(pkg_idx); @@ -455,7 +454,7 @@ void BindC(MatPtr& C, MMParallel& parallel) { const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), allocator.BasePageBytes()); - const size_t node = parallel.Node(pkg_idx); + const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); for (size_t im = 0; im < C.Rows(); ++im) { ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node); } diff --git a/ops/matmul.h b/ops/matmul.h index e4c436f..16028f3 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -53,35 +53,31 @@ constexpr size_t kNR = 4; // or less on ISAs with fewer registers, or for the last few rows of A. static constexpr size_t kMaxMR = 4; -// Mostly stateless, can be constructed on the fly by weights.cc. Captures the -// the ThreadingContext to shorten call sites. -class MMParallel { - public: - // `ctx` must outlive this object. 
- MMParallel(ThreadingContext& ctx) : ctx_(ctx) { - if (ctx_.pools.NumPackages() > kMaxPackages) { - HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.", - ctx_.pools.NumPackages(), kMaxPackages); - } - } +IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, + size_t N, size_t sizeof_TC, size_t nr); - Allocator& allocator() const { return ctx_.allocator; } +enum class ParallelismType : uint8_t { + kNone, + // No parallelism. + kSequential, + // Parallelism at cluster level. + kCluster, + // Parallelism at package level. + kNested, +}; - // Initial static partitioning of B rows across packages. - IndexRangePartition RangesOfNP(size_t max_packages, size_t N, - size_t sizeof_TC, size_t nr) const; +struct MMOptions { + ParallelismType parallelism_type_ = ParallelismType::kNested; + uint8_t cluster_idx_ = 0; +}; - // For `BindB` and `BindC`. - size_t Node(size_t pkg_idx) const { - return ctx_.topology.GetCluster(pkg_idx, 0).Node(); - } - - // Calls `func(pkg_idx)` for each package in parallel. +struct MMNestedParallelPolicy { template - void ForPkg(const size_t max_packages, const Func& func) { + static void ForPkg(ThreadingContext& ctx, const size_t max_packages, + const Func& func) { if constexpr (kMaxPackages > 1) { - ctx_.pools.AllPackages().Run( - 0, HWY_MIN(max_packages, ctx_.pools.NumPackages()), + ctx.pools.AllPackages().Run( + 0, HWY_MIN(max_packages, ctx.pools.NumPackages()), [&](uint64_t task, size_t pkg_idx) { HWY_DASSERT(task == pkg_idx); (void)task; @@ -95,16 +91,17 @@ class MMParallel { // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. 
template - void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks, - size_t pkg_idx, const Func& func) { + static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, + const Func& func) { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage(); + const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); // Single cluster: parallel-for over static partition of `range_np`. - hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx); + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); const size_t num_clusters = all_clusters.NumWorkers(); if (num_clusters == 1) { - hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0); + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0); const IndexRangePartition worker_ranges = StaticPartition( range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); return ParallelizeOneRange( @@ -120,9 +117,9 @@ class MMParallel { ParallelizeOneRange( nx_ranges, all_clusters, [&](const IndexRange& nx_range, const size_t cluster_idx) { - hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t cluster_base = - pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster(); + pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster(); // Parallel-for over sub-ranges of `cluster_range` within the cluster. const IndexRangePartition worker_ranges = StaticPartition( nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); @@ -137,18 +134,19 @@ class MMParallel { // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B // rows). Calls `func(range_mc, range_nc, worker)`. 
template - void ForRangesMC_NC(const IndexRangePartition& ranges_mc, - const IndexRangePartition& ranges_nc, size_t pkg_idx, - const Func& func) { - const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage(); - hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx); + static void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, + size_t pkg_idx, const Func& func) { + const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); // `all_clusters` is a pool with one worker per cluster in a package. const size_t num_clusters = all_clusters.NumWorkers(); // Single (big) cluster: collapse two range indices into one parallel-for // to reduce the number of fork-joins. if (num_clusters == 1) { const size_t cluster_idx = 0; - hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); // Low-batch: avoid Divide/Remainder. if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { return ParallelizeOneRange( @@ -171,8 +169,8 @@ class MMParallel { ranges_nc, all_clusters, [&](const IndexRange range_nc, size_t cluster_idx) { const size_t cluster_base = - pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster(); - hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); + pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); ParallelizeOneRange(ranges_mc, cluster, [&](const IndexRange& range_mc, size_t thread) { func(range_mc, range_nc, cluster_base + thread); @@ -182,21 +180,18 @@ class MMParallel { // Calls `func(row_a, worker)` in parallel. 
template - void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx, - const Func& func) { - const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage(); - ctx_.pools.Pool(pkg_idx).Run( + static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t pkg_idx, const Func& func) { + const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + ctx.pools.Pool(pkg_idx).Run( range_mc.begin(), range_mc.end(), [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); }); } - - private: - ThreadingContext& ctx_; }; -void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel); +void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC); // C is BF16/float. -void BindC(MatPtr& C, MMParallel& parallel); +void BindC(ThreadingContext& ctx, MatPtr& C); // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. #pragma pack(push, 1) // power of two size @@ -250,15 +245,18 @@ class MMStorage { // Internally threaded; must not be called concurrently with the same // `ThreadingContext` (used via `parallel`). - MMStorage(const Allocator& allocator, MMParallel& parallel) { + MMStorage(ThreadingContext& ctx) { // Per-package allocation so each can decompress A into its own copy. // Must be padded, see `DoDecompressA`. - parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) { + // Default to nested parallel policy. + MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) { + Allocator& allocator = ctx.allocator; + pkg_A_[pkg_idx].reset(new MatStorageT( "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd)); if (allocator.ShouldBind()) { - const size_t node = parallel.Node(pkg_idx); + const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * pkg_A_[pkg_idx]->ElementBytes(); bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); @@ -607,9 +605,9 @@ class MMKeys { // Per-MatMul-shape state. 
struct MMPerKey { - MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr, - MMParallel& parallel) - : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) { + MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N, + size_t sizeof_TC, size_t nr) + : ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) { HWY_DASSERT(ranges_np.NumTasks() <= max_packages); } @@ -639,7 +637,6 @@ struct MatMulEnv { // Whether to print the best config immediately after autotuning finished. bool print_best = false; - MMParallel parallel; MMStorage storage; MMKeys keys; std::vector per_key; From 0ae8646731e3c80c56035083118c0d7310be20da Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 29 Aug 2025 07:25:14 -0700 Subject: [PATCH 14/65] Fix remainder handling for Paligemma No longer attempt to skip the remainder handling because B might also be a non-padded view. PiperOrigin-RevId: 800890805 --- compression/test_util-inl.h | 10 -- gemma/weights.cc | 3 - ops/matmul-inl.h | 202 ++++++++++++++---------------------- ops/matmul.h | 1 + ops/matmul_test.cc | 1 + 5 files changed, 79 insertions(+), 138 deletions(-) diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 207b225..e5b1fe0 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -87,11 +87,6 @@ MatStorageT GenerateMat(const Extents2D& extents, Compress(raw.Row(r), raw.Cols(), ws.tls[thread], MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0); - - // MatMul requires that A's padding be zero-initialized. - hwy::ZeroBytes( - compressed.Row(r) + extents.cols, - (compressed.Stride() - extents.cols) * compressed.ElementBytes()); }); compressed.SetScale(0.6f); // Arbitrary value, different from 1. @@ -120,11 +115,6 @@ MatStorageT GenerateTransposedMat(const Extents2D extents, Compress(raw.Row(r), raw.Cols(), ws.tls[thread], MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0); - - // MatMul requires that B's padding be zero-initialized. 
- hwy::ZeroBytes( - compressed.Row(r) + extents.cols, - (compressed.Stride() - extents.cols) * compressed.ElementBytes()); }); // Arbitrary value, different from 1, must match `GenerateMat`. diff --git a/gemma/weights.cc b/gemma/weights.cc index 3425a60..ca1cebc 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -444,9 +444,6 @@ static std::vector MakeBatches( HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row)); } offset += file_bytes_per_row; - // Must zero-initialize the in-memory row padding, see MatMul. - hwy::ZeroBytes(row_bytes + file_bytes_per_row, - mem_stride_bytes - file_bytes_per_row); row_bytes += mem_stride_bytes; } HWY_ASSERT(offset == range.End()); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 29be665..b54ce05 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -216,7 +216,7 @@ class MMKernel { // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. template - static HWY_INLINE void A2C0(const StridedView A_view, const bool A_padded, + static HWY_INLINE void A2C0(const StridedView A_view, const StridedViewBF& B_view, size_t mr, const IndexRange& range_mc, const size_t row_b, size_t kc, Tag tag, const MMArgs& args, @@ -229,8 +229,8 @@ class MMKernel { // M == 1, or x86 with 8 SIMD registers: if (HWY_UNLIKELY(mr == 1)) { for (; imc < mc; ++imc) { - LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, - args, C_rows); + LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); } return; } @@ -239,13 +239,13 @@ class MMKernel { if (HWY_UNLIKELY(mr == 2)) { if (HWY_LIKELY(mc >= 2)) { for (; imc <= mc - 2; imc += 2) { - LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, - args, C_rows); + LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); } } if (HWY_UNLIKELY(imc != mc)) { - LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, - args, C_rows); + 
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); } return; } @@ -253,20 +253,18 @@ class MMKernel { HWY_DASSERT(mr == 4); if (HWY_LIKELY(mc >= 4)) { for (; imc <= mc - 4; imc += 4) { - LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, - args, C_rows); + LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); } } const size_t remainder_mc = mc - imc; HWY_DASSERT(remainder_mc < 4); if (HWY_UNLIKELY(remainder_mc & 2)) { - LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); imc += 2; } if (HWY_UNLIKELY(remainder_mc & 1)) { - LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); imc += 1; } HWY_DASSERT(imc == mc); @@ -380,7 +378,6 @@ class MMKernel { // `B` is BF16, `A` and `C` can be F32 or BF16. template static HWY_INLINE void LoopKC(const StridedView A_view, - const bool A_padded, const StridedViewBF& B_view, size_t row_ac, size_t imc, size_t col_c, size_t kc, Tag tag, const MMArgs& args, RowPtrs C_rows) { @@ -405,33 +402,8 @@ class MMKernel { const BF16* HWY_RESTRICT br2 = B_view.Row(2); const BF16* HWY_RESTRICT br3 = B_view.Row(3); - // Ensure `A` and `B` were zero-padded. - if constexpr (HWY_IS_DEBUG_BUILD) { - // Only check if `A` is padded, i.e. not packed. - if (A_padded) { - for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) { - { - HWY_DASSERT(hwy::ConvertScalarTo(ar0[i]) == 0.0f); - } - if constexpr (kRowsAC > 1) { - HWY_DASSERT(hwy::ConvertScalarTo(ar1[i]) == 0.0f); - } - if constexpr (kRowsAC > 2) { - HWY_DASSERT(hwy::ConvertScalarTo(ar2[i]) == 0.0f); - } - if constexpr (kRowsAC > 3) { - HWY_DASSERT(hwy::ConvertScalarTo(ar3[i]) == 0.0f); - } - } - } - // B is unconditionally zero-padded by `DecompressAndZeroPad`. 
- for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { - HWY_DASSERT(hwy::ConvertScalarTo(br0[i]) == 0.0f); - HWY_DASSERT(hwy::ConvertScalarTo(br1[i]) == 0.0f); - HWY_DASSERT(hwy::ConvertScalarTo(br2[i]) == 0.0f); - HWY_DASSERT(hwy::ConvertScalarTo(br3[i]) == 0.0f); - } - } + // Neither A nor B are guaranteed to be zero-padded: they might be a view + // into the left half. // Accumulate into f32. const hn::Repartition df; @@ -447,18 +419,18 @@ class MMKernel { // The loop step is always NBF: for non-native BF16 with TA=F32, this // entails 2x unrolling, which helps a little. const HWY_LANES_CONSTEXPR size_t kc_step = NBF; - // If A is packed (not padded), we have to check for remainders. Otherwise, - // we only run the main loop because A's padding is zero-initialized by - // `ZeroInit` or weights.cc. - const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc; - if (kc_end >= kc_step) { + if (kc >= kc_step) { HWY_UNROLL(1) - for (; ikc <= kc_end - kc_step; ikc += kc_step) { + for (; ikc <= kc - kc_step; ikc += kc_step) { if constexpr (HWY_NATIVE_DOT_BF16) { - const VBF b0 = hn::Load(dbf, br0 + ikc); - const VBF b1 = hn::Load(dbf, br1 + ikc); - const VBF b2 = hn::Load(dbf, br2 + ikc); - const VBF b3 = hn::Load(dbf, br3 + ikc); + // NOTE: matmul_test has packed B so that it can call Span. The test + // cases with non-vector-multiple K require unaligned loads here. + // However, in actual usage, we should always have padded and thus + // aligned A and B. + const VBF b0 = hn::LoadU(dbf, br0 + ikc); + const VBF b1 = hn::LoadU(dbf, br1 + ikc); + const VBF b2 = hn::LoadU(dbf, br2 + ikc); + const VBF b3 = hn::LoadU(dbf, br3 + ikc); // Should only get here if `A` is BF16, otherwise `DecompressA` would // convert to BF16 and `A_view` points to that. @@ -491,10 +463,10 @@ class MMKernel { // shuffles. 
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; { - const VBF b0 = hn::Load(dbf, br0 + ikc); - const VBF b1 = hn::Load(dbf, br1 + ikc); - const VBF b2 = hn::Load(dbf, br2 + ikc); - const VBF b3 = hn::Load(dbf, br3 + ikc); + const VBF b0 = hn::LoadU(dbf, br0 + ikc); + const VBF b1 = hn::LoadU(dbf, br1 + ikc); + const VBF b2 = hn::LoadU(dbf, br2 + ikc); + const VBF b3 = hn::LoadU(dbf, br3 + ikc); b0e = hn::PromoteEvenTo(df, b0); b1e = hn::PromoteEvenTo(df, b1); b2e = hn::PromoteEvenTo(df, b2); @@ -523,10 +495,10 @@ class MMKernel { } } else { // IsF32(): promote BF to 2xF32, F32*F32. // Full-vector loads are a bit faster on SKX than half + PromoteTo. - const VBF b0 = hn::Load(dbf, br0 + ikc); - const VBF b1 = hn::Load(dbf, br1 + ikc); - const VBF b2 = hn::Load(dbf, br2 + ikc); - const VBF b3 = hn::Load(dbf, br3 + ikc); + const VBF b0 = hn::LoadU(dbf, br0 + ikc); + const VBF b1 = hn::LoadU(dbf, br1 + ikc); + const VBF b2 = hn::LoadU(dbf, br2 + ikc); + const VBF b3 = hn::LoadU(dbf, br3 + ikc); const VF b00 = hn::PromoteLowerTo(df, b0); const VF b10 = hn::PromoteLowerTo(df, b1); const VF b20 = hn::PromoteLowerTo(df, b2); @@ -586,17 +558,16 @@ class MMKernel { } } - // We want the number of actual valid kc, but we may already be beyond `kc`. - const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc; + // Always handle remainders: even though A and B are generally padded, we + // might have a view into the left half of A and/or B. + const size_t remaining_kc = kc - ikc; HWY_DASSERT(remaining_kc < kc_step); - HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0)); - // Last iteration: B is padded but A is not; guard its loads. 
if (HWY_UNLIKELY(remaining_kc != 0)) { if constexpr (HWY_NATIVE_DOT_BF16) { - const VBF b0 = hn::Load(dbf, br0 + ikc); - const VBF b1 = hn::Load(dbf, br1 + ikc); - const VBF b2 = hn::Load(dbf, br2 + ikc); - const VBF b3 = hn::Load(dbf, br3 + ikc); + const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc); + const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc); + const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc); + const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc); // Should only get here if `A` is BF16, otherwise `DecompressA` would // convert to BF16 and `A_view` points to that. @@ -628,10 +599,10 @@ class MMKernel { // lane-crossing promotion for both might be bottlenecked on shuffles. VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; { - const VBF b0 = hn::Load(dbf, br0 + ikc); - const VBF b1 = hn::Load(dbf, br1 + ikc); - const VBF b2 = hn::Load(dbf, br2 + ikc); - const VBF b3 = hn::Load(dbf, br3 + ikc); + const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc); + const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc); + const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc); + const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc); b0e = hn::PromoteEvenTo(df, b0); b1e = hn::PromoteEvenTo(df, b1); b2e = hn::PromoteEvenTo(df, b2); @@ -661,10 +632,10 @@ class MMKernel { C33); } } else { // IsF32(): promote half-B to F32, F32*F32. 
- const VBF b0 = hn::Load(dbf, br0 + ikc); - const VBF b1 = hn::Load(dbf, br1 + ikc); - const VBF b2 = hn::Load(dbf, br2 + ikc); - const VBF b3 = hn::Load(dbf, br3 + ikc); + const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc); + const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc); + const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc); + const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc); const VF b00 = hn::PromoteLowerTo(df, b0); const VF b10 = hn::PromoteLowerTo(df, b1); const VF b20 = hn::PromoteLowerTo(df, b2); @@ -786,12 +757,9 @@ class MMPerPackage { if constexpr (WantDecompressA()) { const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); DecompressA(A, A_view); - constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded. - DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows); + DispatchOrder(parallel_policy, A_view, B, C_rows); } else { - const bool A_padded = HasPadding(A); - DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B, - C_rows); + DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows); } } @@ -808,42 +776,29 @@ class MMPerPackage { return HWY_MAX(kNR, line_bytes_ / sizeof_TC); } - // Use instead of `MatPtr::IsPacked` because that returns true for single - // rows, but we want to know whether there is padding. - static bool HasPadding(const MatPtr& mat) { - return mat.Stride() > mat.Cols(); - } - - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. Both `A`` - // and `B` are const, but StridedView is also used for non-const `partial`. + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. template static StridedView View(const MatPtrT& AB, size_t r, size_t c, size_t cols) { HWY_DASSERT(c < AB.Cols()); HWY_DASSERT(cols <= AB.Cols() - c); - HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag()); - (void)N; - // If `AB` is padded, then `LoopKC` expects the view is either a vector - // multiple, or all columns and thus also padded. 
- HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols())); return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); } // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`. template HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy, - const StridedView A, const bool A_padded, - const MatPtrT& B, + const StridedView A, const MatPtrT& B, RowPtrs C_rows) const { switch (order_) { case MMOrder::kNT: - return DoNT(parallel_policy, A, A_padded, B, C_rows); + return DoNT(parallel_policy, A, B, C_rows); case MMOrder::kNT_K: - return DoNT_K(parallel_policy, A, A_padded, B, C_rows); + return DoNT_K(parallel_policy, A, B, C_rows); case MMOrder::kNT_MT: - return DoNT_MT(parallel_policy, A, A_padded, B, C_rows); + return DoNT_MT(parallel_policy, A, B, C_rows); case MMOrder::kNT_MT_K: - return DoNT_MT_K(parallel_policy, A, A_padded, B, C_rows); + return DoNT_MT_K(parallel_policy, A, B, C_rows); default: HWY_UNREACHABLE; } @@ -852,8 +807,7 @@ class MMPerPackage { // Single M and K ranges, parallel N. Fills all of C directly. template HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView A, - const bool A_padded, const MatPtrT& B, - RowPtrs C_rows) const { + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_kc_.NumTasks() == 1); @@ -878,8 +832,8 @@ class MMPerPackage { row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K, - MMSetC(), args_, C_rows); + MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), + args_, C_rows); } }); } @@ -887,8 +841,7 @@ class MMPerPackage { // Single M range, parallel N, sequential K. Sets C, then accumulates. 
template HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView A, - const bool A_padded, const MatPtrT& B, - RowPtrs C_rows) const { + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); const IndexRange& range_mc = ranges_mc_.Range(0); @@ -909,8 +862,8 @@ class MMPerPackage { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc, - out_tag, args_, C_rows); + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, + C_rows); } }; @@ -937,8 +890,7 @@ class MMPerPackage { // Fills `mc x nc` sections of C directly, in parallel. template HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView A, - const bool A_padded, const MatPtrT& B, - RowPtrs C_rows) const { + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_K = ranges_kc_.Range(0); @@ -963,8 +915,8 @@ class MMPerPackage { row_b += kNR) { const StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K, - MMSetC(), args_, C_rows); + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), + args_, C_rows); } }); } @@ -973,8 +925,7 @@ class MMPerPackage { // Fills `mc x nc` sections of `partial`, then `C`, in parallel. 
template HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView A, - const bool A_padded, const MatPtrT& B, - RowPtrs C_rows) const { + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); @@ -995,8 +946,8 @@ class MMPerPackage { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc, - out_tag, args_, C_rows); + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, + C_rows); } }; // loop_nc MMParallelPolicyT::ForRangesMC_NC( @@ -1129,12 +1080,9 @@ class MMPerPackage { const hn::ScalableTag dbf; HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); - // View() is safe if vector multiple, or padded: for the latter, `ZeroInit` - // and weights.cc zero-initialize the padding. + // Neither A nor B require padding because `LoopKC` handles remainders. if constexpr (hwy::IsSame()) { - if (B.Cols() % NBF == 0 || HasPadding(B)) { - return View(B, row_b, range_kc.begin(), range_kc.Num()); - } + return View(B, row_b, range_kc.begin(), range_kc.Num()); } const PackedSpan B_span = B.PaddedSpan(); @@ -1219,11 +1167,7 @@ struct MMImpl { // `K = B.Cols()`, which must match `A.Cols()`, is the number // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There // are no other restrictions on shape, though performance is better when `M % 4 -// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()). -// -// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(), -// hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match -// the behavior of `DecompressAndZeroPad`. We check this in debug builds. +// == 0` or `M <= 4`. 
// // If `add` is non-null, the row-vector `add` is added to each of the `M` rows // of `C`, which is a row-major matrix with arbitrary stride. A scale for @@ -1282,6 +1226,14 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(M <= MMStorage::kMaxM); HWY_ASSERT(K <= MMStorage::kMaxK); HWY_ASSERT(N % kNR == 0); + // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are + // reliable: the latter returns true for single rows, and the former may + // match `Cols` if the width matches the padding. + // Note that B is packed in matmul_test, but otherwise generally padded. + HWY_ASSERT(hwy::IsAligned(A.Row(0), env.ctx.allocator.LineBytes())); + if (A.Rows() > 1) { + HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes())); + } tuner.SetCandidates( MMCandidates(allocator, M, K, N, MMPerPackage::ABytes(), sizeof(TC), diff --git a/ops/matmul.h b/ops/matmul.h index 16028f3..752bad1 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -194,6 +194,7 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC); void BindC(ThreadingContext& ctx, MatPtr& C); // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. +// Also used to decompress B, hence non-const. #pragma pack(push, 1) // power of two size template class StridedView { diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index aadbc56..14913a1 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -318,6 +318,7 @@ void TestAllMatMul() { ThreadingArgs threading_args; threading_args.bind = Tristate::kTrue; + ThreadingContext ctx(threading_args); MatMulEnv env(ctx); NestedPools& pools = env.ctx.pools; From 00b70f69c5785c416b825e3ffd10e9036c057811 Mon Sep 17 00:00:00 2001 From: Marie White Date: Fri, 29 Aug 2025 08:04:05 -0700 Subject: [PATCH 15/65] Include parallelism type in DoMatMul. Also remove package handling. 
PiperOrigin-RevId: 800902568 --- ops/matmul-inl.h | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index b54ce05..8f91114 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1135,28 +1135,23 @@ struct MMImpl { template static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const MatPtrT& B, RowPtrs C_rows, const MMArgs& args, - const MMConfig& config) { + const MMConfig& config, + ParallelismType parallelism_type) { PROFILER_ZONE("MM.DoMatMul"); - static const auto zone = - args.env->ctx.profiler.AddZone("MM.DoMatMul.PerPkg"); + const size_t pkg_idx = 0; + HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); + const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - if constexpr (kMaxPackages > 1) { - // Outermost loop: static NUMA-aware partition of B rows across packages. - MMNestedParallelPolicy::ForPkg( - args.env->ctx, args.per_key->ranges_np.NumTasks(), - [&](size_t pkg_idx) { - MMZone mm_zone; - mm_zone.MaybeEnter(pkg_idx, zone, args); - const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( - MMNestedParallelPolicy(), A, B, C_rows); - }); - } else { - const size_t pkg_idx = 0; - HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); - const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( - MMNestedParallelPolicy(), A, B, C_rows); + switch (parallelism_type) { + case ParallelismType::kNested: + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( + MMNestedParallelPolicy(), A, B, C_rows); + break; + case ParallelismType::kNone: + case ParallelismType::kSequential: + case ParallelismType::kCluster: + HWY_ABORT("Parallelism type not implemented."); + break; } } }; @@ -1210,10 +1205,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = env.per_key[index]; 
MMAutoTune& tuner = per_key.autotune; + // Default to nested parallelism. + const ParallelismType parallelism_type = ParallelismType::kNested; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best()); + MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type); return &per_key; } @@ -1242,7 +1239,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, C_rows, args, cfg); + MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / From bc0c0bac8b62f858fe9c28b4e6f3e39e4a07c5c2 Mon Sep 17 00:00:00 2001 From: Marie White Date: Fri, 29 Aug 2025 08:38:19 -0700 Subject: [PATCH 16/65] Add non-threading parallel policy. 
PiperOrigin-RevId: 800913294 --- ops/matmul.h | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/ops/matmul.h b/ops/matmul.h index 752bad1..dda9673 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -71,6 +71,48 @@ struct MMOptions { uint8_t cluster_idx_ = 0; }; +struct MMSequentialPolicy { + template + static void ForPkg(ThreadingContext& ctx, const size_t max_packages, + const Func& func) { + func(/*pkg_idx=*/0); + } + + template + static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, + const Func& func) { + HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + func(range_np, base_idx); + } + + template + static void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, + size_t pkg_idx, const Func& func) { + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + + for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) { + const IndexRange range_mc = ranges_mc.Range(i); + for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) { + const IndexRange range_nc = ranges_nc.Range(j); + func(range_mc, range_nc, base_idx); + } + } + } + + template + static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t pkg_idx, const Func& func) { + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) { + func(row_a, base_idx); + } + } +}; + struct MMNestedParallelPolicy { template static void ForPkg(ThreadingContext& ctx, const size_t max_packages, From 229bd078a14e0c33afb5bfe8bff9005479e52938 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 1 Sep 2025 06:32:24 -0700 Subject: [PATCH 17/65] 1.29x speedup: bf16 C1/C2. Extend most ops to any type, expand test coverage. 
Also increase dot_test.cc range for Zen4, and matmul_test tolerance (failing in some configs) PiperOrigin-RevId: 801789922 --- BUILD.bazel | 1 + compression/compress-inl.h | 236 +++++++++++++++++ compression/compress_test.cc | 87 ++++++- compression/test_util-inl.h | 29 +++ compression/types.h | 5 + gemma/activations.h | 5 +- gemma/attention.cc | 6 +- gemma/gemma-inl.h | 27 +- gemma/vit.cc | 2 +- ops/dot_test.cc | 2 +- ops/matmul_test.cc | 2 +- ops/ops-inl.h | 336 +++++++----------------- ops/ops_test.cc | 479 ++++++++++++++++++++++------------- 13 files changed, 761 insertions(+), 456 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 1f9e210..e141e95 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -376,6 +376,7 @@ cc_test( ":test_util", ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep + "//compression:test_util", "//compression:types", "@highway//:hwy", "@highway//:hwy_test_util", diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 512f8fa..e2b3bcc 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -709,6 +709,242 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan v, comp3); } +// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h. +// `DF` is the decompressed type, typically `float`. 
+template +HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout, + size_t num, Func&& func) { + const auto packed_inout = MakeSpan(inout, num); + + using VF = hn::Vec; + HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF v0, v1; + Decompress2(df, packed_inout, i, v0, v1); + const VF out0 = func(df, v0); + const VF out1 = func(df, v1); + Compress2(df, out0, out1, packed_inout, i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. + hn::Store(hn::Zero(df), df, buf_inout + NF); + DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining); + const VF v0 = hn::Load(df, buf_inout); + const VF v1 = hn::Load(df, buf_inout + NF); + const VF out0 = func(df, v0); + const VF out1 = func(df, v1); + Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0); + // Clang generates incorrect code for CopyBytes if num = 2. + for (size_t j = 0; j < remaining; ++j) { + inout[i + j] = hwy::ConvertScalarTo(buf_inout[j]); + } + } +} + +// One extra argument. `DF` is the decompressed type, typically `float`. 
+template +HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, + size_t num, + const T1* HWY_RESTRICT p1, + Func&& func) { + const auto packed_inout = MakeSpan(inout, num); + const auto packed1 = MakeSpan(p1, num); + + using VF = hn::Vec; + HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF v0, v1; + Decompress2(df, packed_inout, i, v0, v1); + VF v10, v11; + Decompress2(df, packed1, i, v10, v11); + const VF out0 = func(df, v0, v10); + const VF out1 = func(df, v1, v11); + Compress2(df, out0, out1, packed_inout, i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf1[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. + hn::Store(hn::Zero(df), df, buf_inout + NF); + hn::Store(hn::Zero(df), df, buf1 + NF); + DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining); + DecompressAndZeroPad(df, packed1, i, buf1, remaining); + const VF v0 = hn::Load(df, buf_inout); + const VF v1 = hn::Load(df, buf_inout + NF); + const VF v10 = hn::Load(df, buf1); + const VF v11 = hn::Load(df, buf1 + NF); + const VF out0 = func(df, v0, v10); + const VF out1 = func(df, v1, v11); + Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0); + // Clang generates incorrect code for CopyBytes if num = 2. + for (size_t j = 0; j < remaining; ++j) { + inout[i + j] = hwy::ConvertScalarTo(buf_inout[j]); + } + } +} + +// Single input, separate output. `DF` is the decompressed type, typically +// `float`. 
+template +HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, + const T1* HWY_RESTRICT p1, + Func&& func) { + const auto packed_out = MakeSpan(out, num); + const auto packed1 = MakeSpan(p1, num); + + using VF = hn::Vec; + HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF v10, v11; + Decompress2(df, packed1, i, v10, v11); + const VF out0 = func(df, v10); + const VF out1 = func(df, v11); + Compress2(df, out0, out1, packed_out, i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf1[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. + hn::Store(hn::Zero(df), df, buf1 + NF); + DecompressAndZeroPad(df, packed1, i, buf1, remaining); + const VF v10 = hn::Load(df, buf1); + const VF v11 = hn::Load(df, buf1 + NF); + const VF out0 = func(df, v10); + const VF out1 = func(df, v11); + Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0); + // Clang generates incorrect code for CopyBytes if num = 2. + for (size_t j = 0; j < remaining; ++j) { + out[i + j] = hwy::ConvertScalarTo(buf_out[j]); + } + } +} + +// Two inputs. `DF` is the decompressed type, typically `float`. 
+template +HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, + const T1* HWY_RESTRICT p1, + const T2* HWY_RESTRICT p2, + Func&& func) { + const auto packed_out = MakeSpan(out, num); + const auto packed1 = MakeSpan(p1, num); + const auto packed2 = MakeSpan(p2, num); + + using VF = hn::Vec; + HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF v10, v11, v20, v21; + Decompress2(df, packed1, i, v10, v11); + Decompress2(df, packed2, i, v20, v21); + const VF out0 = func(df, v10, v20); + const VF out1 = func(df, v11, v21); + Compress2(df, out0, out1, packed_out, i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf1[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf2[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. + hn::Store(hn::Zero(df), df, buf1 + NF); + hn::Store(hn::Zero(df), df, buf2 + NF); + DecompressAndZeroPad(df, packed1, i, buf1, remaining); + DecompressAndZeroPad(df, packed2, i, buf2, remaining); + const VF v10 = hn::Load(df, buf1); + const VF v11 = hn::Load(df, buf1 + NF); + const VF v20 = hn::Load(df, buf2); + const VF v21 = hn::Load(df, buf2 + NF); + const VF out0 = func(df, v10, v20); + const VF out1 = func(df, v11, v21); + Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0); + // Clang generates incorrect code for CopyBytes if num = 2. + for (size_t j = 0; j < remaining; ++j) { + out[i + j] = hwy::ConvertScalarTo(buf_out[j]); + } + } +} + +// Three inputs. `DF` is the decompressed type, typically `float`. 
+template +HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, + const T1* HWY_RESTRICT p1, + const T2* HWY_RESTRICT p2, + const T3* HWY_RESTRICT p3, + Func&& func) { + const auto packed_out = MakeSpan(out, num); + const auto packed1 = MakeSpan(p1, num); + const auto packed2 = MakeSpan(p2, num); + const auto packed3 = MakeSpan(p3, num); + + using VF = hn::Vec; + HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF v10, v11, v20, v21, v30, v31; + Decompress2(df, packed1, i, v10, v11); + Decompress2(df, packed2, i, v20, v21); + Decompress2(df, packed3, i, v30, v31); + const VF out0 = func(df, v10, v20, v30); + const VF out1 = func(df, v11, v21, v31); + Compress2(df, out0, out1, packed_out, i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf1[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf2[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf3[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. + hn::Store(hn::Zero(df), df, buf1 + NF); + hn::Store(hn::Zero(df), df, buf2 + NF); + hn::Store(hn::Zero(df), df, buf3 + NF); + DecompressAndZeroPad(df, packed1, i, buf1, remaining); + DecompressAndZeroPad(df, packed2, i, buf2, remaining); + DecompressAndZeroPad(df, packed3, i, buf3, remaining); + const VF v10 = hn::Load(df, buf1); + const VF v11 = hn::Load(df, buf1 + NF); + const VF v20 = hn::Load(df, buf2); + const VF v21 = hn::Load(df, buf2 + NF); + const VF v30 = hn::Load(df, buf3); + const VF v31 = hn::Load(df, buf3 + NF); + const VF out0 = func(df, v10, v20, v30); + const VF out1 = func(df, v11, v21, v31); + Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0); + // Clang generates incorrect code for CopyBytes if num = 2. 
+ for (size_t j = 0; j < remaining; ++j) { + out[i + j] = hwy::ConvertScalarTo(buf_out[j]); + } + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 2270689..5455b1d 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -18,11 +18,10 @@ #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS -#include "compression/compress.h" - #include #include +#include "compression/compress.h" #include "compression/distortion.h" #include "util/test_util.h" #include "hwy/aligned_allocator.h" @@ -45,7 +44,7 @@ namespace hn = hwy::HWY_NAMESPACE; // Calls Compress and Decompress2 and verifies the distortion/error. template -struct TestDecompress2T { +struct TestDecompress2 { template HWY_INLINE void operator()(T /*unused*/, D d) { const size_t N = hn::Lanes(d); @@ -120,12 +119,12 @@ struct TestDecompress2T { } }; -void TestAllDecompress2() { ForeachPackedAndRawType(); } +void TestAllDecompress2() { ForeachPackedAndRawType(); } // Calls Compress and DecompressAndZeroPad for all short lengths and verifies // the distortion/error. template -struct TestShortLengthsT { +struct TestShortLengths { template HWY_INLINE void operator()(T /*unused*/, D d) { const size_t N = hn::Lanes(d); @@ -196,7 +195,82 @@ struct TestShortLengthsT { } }; -void TestAllShortLengths() { ForeachPackedAndRawType(); } +void TestAllShortLengths() { ForeachPackedAndRawType(); } + +// Verifies the arguments and remainder handling of `DecompressAndCompress*`. 
+class TestDecompressAndCompress { + public: + template + HWY_INLINE void operator()(T /*unused*/, D d) { + ForeachActivationType3(d); + } + + private: + struct Test { + template + void operator()(T1, T2, T3, D d) { + hwy::RandomState rng; + using DF = hn::Repartition; + using VF = hn::Vec; + const DF df; + + for (size_t num = 1; num < 7 * hn::Lanes(d); ++num) { + auto p = hwy::AllocateAligned(num); + auto p1 = hwy::AllocateAligned(num); + auto p2 = hwy::AllocateAligned(num); + auto out = hwy::AllocateAligned(num); + auto expected1 = hwy::AllocateAligned(num); + auto expected2 = hwy::AllocateAligned(num); + auto expected3 = hwy::AllocateAligned(num); + HWY_ASSERT(p && p1 && p2 && out && expected1 && expected2 && expected3); + // Two bits each, totalling 6 bits which fit in the BF16 mantissa. + for (size_t i = 0; i < num; ++i) { + const size_t mod = i & 3; + p[i] = hwy::ConvertScalarTo(mod); + p1[i] = hwy::ConvertScalarTo(mod << 2); + p2[i] = hwy::ConvertScalarTo(mod << 4); + // For `Decompress1AndCompressInplace` to not overwrite `p`. + out[i] = p[i]; + expected1[i] = hwy::ConvertScalarTo(mod); + expected2[i] = hwy::ConvertScalarTo((mod << 2) | mod); + expected3[i] = + hwy::ConvertScalarTo((mod << 4) | (mod << 2) | mod); + } + + DecompressAndCompressInplace(df, p.get(), num, + [](DF, VF v) HWY_ATTR -> VF { return v; }); + HWY_ASSERT_ARRAY_EQ(expected1.get(), p.get(), num); + + // Uses `out` so as not to overwrite `p`. 
+ Decompress1AndCompressInplace( + df, out.get(), num, p1.get(), + [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); }); + HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); + + Decompress1AndCompressTo(df, out.get(), num, p.get(), + [](DF, VF v) HWY_ATTR -> VF { return v; }); + HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num); + + Decompress2AndCompressTo(df, out.get(), num, p.get(), p1.get(), + [](DF, VF v, VF v1) + HWY_ATTR -> VF { return hn::Add(v, v1); }); + HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); + + Decompress3AndCompressTo( + df, out.get(), num, p.get(), p1.get(), p2.get(), + [](DF, VF v, VF v1, VF v2) + HWY_ATTR -> VF { return hn::Add(hn::Add(v, v1), v2); }); + HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num); + } + } + }; +}; + +void TestAllDecompressAndCompress() { + // The Highway Test interface (`ForGE128Vectors`) only supports a single type. + // We hard-code one here, and use `ForeachActivationType` internally. + hn::ForGE128Vectors()(float()); +} // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -208,6 +282,7 @@ namespace gcpp { HWY_BEFORE_TEST(CompressTest); HWY_EXPORT_AND_TEST_P(CompressTest, TestAllDecompress2); HWY_EXPORT_AND_TEST_P(CompressTest, TestAllShortLengths); +HWY_EXPORT_AND_TEST_P(CompressTest, TestAllDecompressAndCompress); HWY_AFTER_TEST(); } // namespace gcpp #endif // HWY_ONCE diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index e5b1fe0..1c72b32 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -67,6 +67,35 @@ void ForeachPackedAndRawType() { } } +template +void ForeachActivationType1(D d) { + Test test; + test(float(), d); + test(BF16(), d); +} + +template +void ForeachActivationType2(D d) { + Test test; + test(float(), float(), d); + test(float(), BF16(), d); + test(BF16(), float(), d); + test(BF16(), BF16(), d); +} + +template +void ForeachActivationType3(D d) { + Test test; + test(float(), float(), 
float(), d); + test(float(), float(), BF16(), d); + test(float(), BF16(), float(), d); + test(float(), BF16(), BF16(), d); + test(BF16(), float(), float(), d); + test(BF16(), float(), BF16(), d); + test(BF16(), BF16(), float(), d); + test(BF16(), BF16(), BF16(), d); +} + // Generates inputs: deterministic, within max SfpStream range. template MatStorageT GenerateMat(const Extents2D& extents, diff --git a/compression/types.h b/compression/types.h index dc10676..667265a 100644 --- a/compression/types.h +++ b/compression/types.h @@ -186,6 +186,11 @@ constexpr bool IsNuqStream() { return hwy::IsSame, NuqStream>(); } +template +constexpr bool SupportsPointerArithmetic() { + return !IsNuqStream(); +} + // Tensor types for loading weights. enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 }; // These are used in `ModelConfig.Specifier`, hence the strings will not diff --git a/gemma/activations.h b/gemma/activations.h index 175ddc9..14994d3 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -206,9 +206,8 @@ struct Activations { // Gated FFW MatStorageT pre_ffw_rms_out; - // Norm may be large, so prefer to keep as f32. - MatStorageT C1; - MatStorageT C2; + MatStorageT C1; + MatStorageT C2; MatStorageT ffw_out; AttentionActivations attention; diff --git a/gemma/attention.cc b/gemma/attention.cc index c73abcb..13681d0 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -144,8 +144,8 @@ void SingleDotSoftmaxWeightedSum( // Apply rope and scaling to Q. if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), 0, q, - layer.layer_config.qkv_dim, p, worker); + RMSNormInplace(weights_t->PackedScale1(), q, layer.layer_config.qkv_dim, + p, worker); }); } @@ -307,7 +307,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Apply further processing to K. 
if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim, + RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim, env.ctx.profiler, thread); }); } diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index c1ff722..ed0750a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -43,14 +43,14 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template -void Activation(ActivationType activation, T* HWY_RESTRICT c1, - const T* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, +template +void Activation(ActivationType activation, T1* HWY_RESTRICT c1, + const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, const size_t worker) { static const auto zone = p.AddZone("Gen.Activation"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - using DF = hn::ScalableTag; + using DF = hn::ScalableTag; using VF = hn::Vec; // ActivationType::Gelu if (c2 == nullptr) { // No multiplier, just Gelu. @@ -58,9 +58,10 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1, return; }; // Has multiplier, Gelu(c1) * c2. - hn::Transform1(DF(), c1, count, c2, [](DF df, VF v, VF mul) HWY_ATTR { - return hn::Mul(mul, Gelu(df, v)); - }); + Decompress1AndCompressInplace(DF(), c1, count, c2, + [](DF df, VF v1, VF v2) HWY_ATTR -> VF { + return hn::Mul(v2, Gelu(df, v1)); + }); } // No C2 multiplier. 
@@ -75,10 +76,9 @@ void ActivationBatched(ActivationType activation, Mat& c1, }); } -template -HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1, - const Mat* c2, ThreadingContext& ctx) { - using T = typename Mat::T; +template +HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat1& c1, + const Mat2* c2, ThreadingContext& ctx) { HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { @@ -87,8 +87,9 @@ HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1, }); } else { // No multiplier SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { - Activation(activation, c1.Row(task), static_cast(nullptr), - c1.Cols(), ctx.profiler, worker); + Activation(activation, c1.Row(task), + static_cast(nullptr), c1.Cols(), + ctx.profiler, worker); }); } } diff --git a/gemma/vit.cc b/gemma/vit.cc index a694187..96d6d7f 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -335,7 +335,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, // Apply soft embedding norm before input projection. CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), + RMSNormInplace(weights_t->PackedScale1(), activations.x.Row(0), vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread()); }); } diff --git a/ops/dot_test.cc b/ops/dot_test.cc index a461614..2c0ae3a 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -750,7 +750,7 @@ class DotStats { void CheckMuls() const { // Comp2 is between Compensated and Kahan. ASSERT_INSIDE(kComp2, 1.001, s_muls[kComp2].Mean(), 1.4); - ASSERT_INSIDE(kComp2, 1.001f, s_muls[kComp2].Max(), 2.4f); + ASSERT_INSIDE(kComp2, 1.001f, s_muls[kComp2].Max(), 6.8f); ASSERT_INSIDE(kComp2, 1.0, s_muls[kComp2].GeometricMean(), 1.2); // Compensated and Double are very accurate. 
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 14913a1..3a8528a 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -119,7 +119,7 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, const double eps_bf16 = hwy::ConvertScalarTo(hwy::Epsilon()); const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); // Dot() uses double-precision summation. - double tolerance = 12 * norm * eps_f32; + double tolerance = 20 * norm * eps_f32; // If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to // BF16, so add extra tolerance. if (IsF32()) { diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 8a7224b..caf8041 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -127,12 +127,13 @@ HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { } // Activation already has a profiler zone. -static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, - size_t size) { +template +static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(T* HWY_RESTRICT x, size_t size) { namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - hn::Transform(D(), x, size, - [](D d, hn::Vec v) HWY_ATTR { return Gelu(d, v); }); + using DF = hn::ScalableTag; + using VF = hn::Vec; + DecompressAndCompressInplace( + DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF { return Gelu(d, v); }); } template @@ -179,13 +180,15 @@ HWY_INLINE hn::Vec Sigmoid(D d, hn::Vec v) { } // Sigmoid using the logistic function 1 / (1 + exp(-x[i])) -static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x, +template +static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(T* HWY_RESTRICT x, size_t size) { PROFILER_ZONE("ops.Sigmoid"); namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - hn::Transform(D(), x, size, - [](D d, hn::Vec v) HWY_ATTR { return Sigmoid(d, v); }); + using DF = hn::ScalableTag; + using VF = hn::Vec; + DecompressAndCompressInplace( + DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF { return Sigmoid(d, v); }); } namespace detail { @@ -205,71 +208,53 @@ float 
RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, } // namespace detail -// `x_ofs` is the offset within `x`, required for NuqStream. template HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, - size_t w_ofs, OT* HWY_RESTRICT out, + OT* HWY_RESTRICT out, const size_t size, hwy::Profiler& p, const size_t worker) { static const auto zone = p.AddZone("Ops.RMSNorm"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - using VF = hn::Vec; - const size_t NF = hn::Lanes(df); + using DF = hn::ScalableTag; + using VF = hn::Vec; - const VF mul = hn::Set(df, detail::RMSNormMul(x, size, p, worker)); + const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker)); + const VF* HWY_RESTRICT pmul = &mul; - const auto packed_x = MakeSpan(x, size); - const auto packed_w = MakeSpan(weight, w_ofs + size); - const auto packed_out = MakeSpan(out, size); - - HWY_DASSERT(size % (2 * NF) == 0); - for (size_t i = 0; i < size; i += 2 * NF) { - VF x0, x1, w0, w1; - Decompress2(df, packed_x, i, x0, x1); - Decompress2(df, packed_w, w_ofs + i, w0, w1); - const VF m0 = hn::Mul(mul, x0); - const VF m1 = hn::Mul(mul, x1); - // (1+weight) * m = m + weight*m = one FMA. - const VF out0 = hn::MulAdd(m0, w0, m0); - const VF out1 = hn::MulAdd(m1, w1, m1); - Compress2(df, out0, out1, packed_out, i); - } + Decompress2AndCompressTo(DF(), out, size, x, weight, + [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF { + const VF m = hn::Mul(*pmul, vx); + // (1+weight) * m = m + weight*m = one FMA. + return hn::MulAdd(m, vw, m); + }); } // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer. 
template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( - const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout, - const size_t size, hwy::Profiler& p, const size_t worker) { +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight, + XT* HWY_RESTRICT inout, + const size_t size, + hwy::Profiler& p, + const size_t worker) { static const auto zone = p.AddZone("Ops.RMSNormInplace"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - using VF = hn::Vec; - const size_t NF = hn::Lanes(df); + using DF = hn::ScalableTag; + using VF = hn::Vec; - const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, p, worker)); + const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker)); + const VF* HWY_RESTRICT pmul = &mul; - const auto packed_w = MakeSpan(weight, w_ofs + size); - const auto packed_x = MakeSpan(inout, size); - - HWY_DASSERT(size % (2 * NF) == 0); - for (size_t i = 0; i < size; i += 2 * NF) { - VF x0, x1, w0, w1; - Decompress2(df, packed_x, i, x0, x1); - Decompress2(df, packed_w, w_ofs + i, w0, w1); - const VF m0 = hn::Mul(mul, x0); - const VF m1 = hn::Mul(mul, x1); - // (1+weight) * m = m + weight*m = one FMA. - const VF out0 = hn::MulAdd(m0, w0, m0); - const VF out1 = hn::MulAdd(m1, w1, m1); - Compress2(df, out0, out1, packed_x, i); - } + Decompress1AndCompressInplace(DF(), inout, size, weight, + [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF { + const VF m = hn::Mul(*pmul, vx); + // (1+weight) * m = m + weight*m = one FMA. + return hn::MulAdd(m, vw, m); + }); } // Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm. 
@@ -301,9 +286,9 @@ HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale, PROFILER_ZONE("ops.LayerNorm"); namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - using VF = hn::Vec; - const size_t NF = hn::Lanes(df); + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; double mu, mu2; ComputeMoments(x, size, mu, mu2); @@ -315,56 +300,13 @@ HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale, const VF* HWY_RESTRICT pmu = &vmu; const VF* HWY_RESTRICT pvar = &vvar; - const auto packed_x = MakeSpan(x, size); - const auto packed_scale = MakeSpan(scale, size); - const auto packed_bias = MakeSpan(bias, size); - const auto packed_out = MakeSpan(out, size); - - // Loop body for one vector, called from main loop and remainder loop. - const auto norm = [pmu, pvar](VF x, VF s, VF add) HWY_ATTR -> VF { - const VF centered = hn::Sub(x, *pmu); - const VF mul = hn::Mul(s, *pvar); - return hn::MulAdd(centered, mul, add); - }; - - size_t i = 0; - if (size >= 2 * NF) { - for (; i <= size - 2 * NF; i += 2 * NF) { - VF x0, x1, s0, s1, add0, add1; - Decompress2(df, packed_x, i, x0, x1); - Decompress2(df, packed_scale, i, s0, s1); - Decompress2(df, packed_bias, i, add0, add1); - const VF n0 = norm(x0, s0, add0); - const VF n1 = norm(x1, s1, add1); - Compress2(df, n0, n1, packed_out, i); - } - } - - const size_t remaining = size - i; - HWY_DASSERT(remaining < 2 * NF); - if (HWY_UNLIKELY(remaining != 0)) { - HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)]; - HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)]; - HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)]; - // Ensure the second vectors are zeroed even if remaining <= NF. 
- hn::Store(hn::Zero(df), df, buf_x + NF); - hn::Store(hn::Zero(df), df, buf_scale + NF); - hn::Store(hn::Zero(df), df, buf_bias + NF); - HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)]; - DecompressAndZeroPad(df, packed_x, i, buf_x, remaining); - DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining); - DecompressAndZeroPad(df, packed_bias, i, buf_bias, remaining); - const VF x0 = hn::Load(df, buf_x); - const VF x1 = hn::Load(df, buf_x + NF); - const VF s0 = hn::Load(df, buf_scale); - const VF s1 = hn::Load(df, buf_scale + NF); - const VF add0 = hn::Load(df, buf_bias); - const VF add1 = hn::Load(df, buf_bias + NF); - const VF n0 = norm(x0, s0, add0); - const VF n1 = norm(x1, s1, add1); - Compress2(df, n0, n1, MakeSpan(buf_out, 2 * NF), 0); - hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT)); - } + Decompress3AndCompressTo(DF(), out, size, x, scale, bias, + [pmu, pvar](DF /*df*/, VF x, VF s, VF add) + HWY_ATTR -> VF { + const VF centered = hn::Sub(x, *pmu); + const VF mul = hn::Mul(s, *pvar); + return hn::MulAdd(centered, mul, add); + }); } static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( @@ -541,40 +483,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - const size_t NF = hn::Lanes(df); - using VF = hn::Vec; - - const auto packed_x = MakeSpan(x, size); - - size_t i = 0; - if (size >= 2 * NF) { - for (; i <= size - 2 * NF; i += 2 * NF) { - VF x0, x1; - Decompress2(df, packed_x, i, x0, x1); - VF out0 = hn::Load(df, out + i); - VF out1 = hn::Load(df, out + i + NF); - hn::Store(hn::Add(x0, out0), df, out + i); - hn::Store(hn::Add(x1, out1), df, out + i + NF); - } - } - - const size_t remaining = size - i; - const size_t remaining1 = remaining - HWY_MIN(remaining, NF); - HWY_DASSERT(remaining < 2 * NF); - HWY_DASSERT(remaining1 < NF); - if (HWY_UNLIKELY(remaining != 0)) { - HWY_ALIGN float buf_x[2 * 
hn::MaxLanes(df)]; - // Ensure the second vector is zeroed even if remaining <= NF. - hn::Store(hn::Zero(df), df, buf_x + NF); - DecompressAndZeroPad(df, packed_x, i, buf_x, remaining); - const VF x0 = hn::Load(df, buf_x); - const VF x1 = hn::Load(df, buf_x + NF); - const VF out0 = hn::LoadN(df, out + i, remaining); - const VF out1 = hn::LoadN(df, out + i + NF, remaining1); - hn::StoreN(hn::Add(x0, out0), df, out + i, remaining); - hn::StoreN(hn::Add(x1, out1), df, out + i + NF, remaining1); - } + using DF = hn::ScalableTag; + using VF = hn::Vec; + Decompress1AndCompressInplace(DF(), out, size, x, + [&](DF /*df*/, VF out, VF x) + HWY_ATTR -> VF { return hn::Add(x, out); }); } // Simple loops unless/until batch sizes are large enough to parallelize. @@ -588,7 +501,7 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, CallUpcasted(&weights, [&](const auto* weights_t) { SmallParallelFor( activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { - RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0, + RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), out.Row(token_idx), activations.Cols(), ctx.profiler, worker); }); }); @@ -603,7 +516,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, CallUpcasted(&weights, [&](const auto* weights_t) { SmallParallelFor( inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { - RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx), + RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx), inout.Cols(), ctx.profiler, worker); }); }); @@ -645,38 +558,15 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, static const auto zone = p.AddZone("Ops.MulByConst"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - const size_t NF = hn::Lanes(df); - using VF = hn::Vec; + using DF = hn::ScalableTag; + using VF = hn::Vec; - const VF v_c = hn::Set(df, c); - 
const auto packed_x = MakeSpan(x, size); + const VF vc = hn::Set(DF(), c); + const VF* HWY_RESTRICT pc = &vc; - size_t i = 0; - if (size >= 2 * NF) { - for (; i <= size - 2 * NF; i += 2 * NF) { - VF x0, x1; - Decompress2(df, packed_x, i, x0, x1); - x0 = hn::Mul(x0, v_c); - x1 = hn::Mul(x1, v_c); - Compress2(df, x0, x1, packed_x, i); - } - } - - const size_t remaining = size - i; - HWY_DASSERT(remaining < 2 * NF); - if (HWY_UNLIKELY(remaining != 0)) { - HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)]; - // Ensure the second vector is zeroed even if remaining <= NF. - hn::Store(hn::Zero(df), df, buf_x + NF); - DecompressAndZeroPad(df, packed_x, i, buf_x, remaining); - VF x0 = hn::Load(df, buf_x); - VF x1 = hn::Load(df, buf_x + NF); - x0 = hn::Mul(x0, v_c); - x1 = hn::Mul(x1, v_c); - Compress2(df, x0, x1, MakeSpan(buf_x, 2 * NF), 0); - hwy::CopyBytes(buf_x, x + i, remaining * sizeof(XT)); - } + DecompressAndCompressInplace(DF(), x, size, + [pc](DF /*df*/, VF x) + HWY_ATTR -> VF { return hn::Mul(x, *pc); }); } // Same as above, but with a separate output. Same as below without the add. 
@@ -687,42 +577,18 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( static const auto zone = p.AddZone("Ops.MulByConstTo"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - const size_t NF = hn::Lanes(df); - using VF = hn::Vec; + using DF = hn::ScalableTag; + using VF = hn::Vec; - const VF v_c = hn::Set(df, c); - const auto packed_x = MakeSpan(x, size); - const auto packed_out = MakeSpan(out, size); + const VF vc = hn::Set(DF(), c); + const VF* HWY_RESTRICT pc = &vc; - size_t i = 0; - if (size >= 2 * NF) { - for (; i <= size - 2 * NF; i += 2 * NF) { - VF x0, x1; - Decompress2(df, packed_x, i, x0, x1); - const VF out0 = hn::Mul(x0, v_c); - const VF out1 = hn::Mul(x1, v_c); - Compress2(df, out0, out1, packed_out, i); - } - } - - const size_t remaining = size - i; - HWY_DASSERT(remaining < 2 * NF); - if (HWY_UNLIKELY(remaining != 0)) { - HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)]; - HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)]; - // Ensure the second vector is zeroed even if remaining <= NF. - hn::Store(hn::Zero(df), df, buf_x + NF); - DecompressAndZeroPad(df, packed_x, i, buf_x, remaining); - const VF x0 = hn::Load(df, buf_x); - const VF x1 = hn::Load(df, buf_x + NF); - const VF out0 = hn::Mul(x0, v_c); - const VF out1 = hn::Mul(x1, v_c); - Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0); - hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT)); - } + Decompress1AndCompressTo(DF(), out, size, x, + [pc](DF /*df*/, VF x) + HWY_ATTR -> VF { return hn::Mul(x, *pc); }); } +// out[i] += x[i] * c. 
template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, @@ -730,45 +596,16 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - const size_t NF = hn::Lanes(df); - using VF = hn::Vec; + using DF = hn::ScalableTag; + using VF = hn::Vec; - const VF v_c = hn::Set(df, c); - const auto packed_x = MakeSpan(x, size); - const auto packed_out = MakeSpan(out, size); + const VF vc = hn::Set(DF(), c); + const VF* HWY_RESTRICT pc = &vc; - size_t i = 0; - if (size >= 2 * NF) { - for (; i <= size - 2 * NF; i += 2 * NF) { - VF x0, x1, out0, out1; - Decompress2(df, packed_x, i, x0, x1); - Decompress2(df, packed_out, i, out0, out1); - out0 = hn::MulAdd(x0, v_c, out0); - out1 = hn::MulAdd(x1, v_c, out1); - Compress2(df, out0, out1, packed_out, i); - } - } - - const size_t remaining = size - i; - HWY_DASSERT(remaining < 2 * NF); - if (HWY_UNLIKELY(remaining != 0)) { - HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)]; - HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)]; - // Ensure the second vectors are zeroed even if remaining <= NF. 
- hn::Store(hn::Zero(df), df, buf_x + NF); - hn::Store(hn::Zero(df), df, buf_out + NF); - DecompressAndZeroPad(df, packed_x, i, buf_x, remaining); - DecompressAndZeroPad(df, packed_out, i, buf_out, remaining); - const VF x0 = hn::Load(df, buf_x); - const VF x1 = hn::Load(df, buf_x + NF); - VF out0 = hn::Load(df, buf_out); - VF out1 = hn::Load(df, buf_out + NF); - out0 = hn::MulAdd(x0, v_c, out0); - out1 = hn::MulAdd(x1, v_c, out1); - Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0); - hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT)); - } + Decompress1AndCompressInplace(DF(), out, size, x, + [&](DF /*df*/, VF out, VF x) HWY_ATTR -> VF { + return hn::MulAdd(x, *pc, out); + }); } // See below for a specialized version for top-1 sampling. @@ -913,15 +750,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - using V = hn::Vec; + using DF = hn::ScalableTag; + using VF = hn::Vec; - const float inv_cap = 1.0f / cap; + const VF vcap = hn::Set(DF(), cap); + const VF vinv_cap = hn::Set(DF(), 1.0f / cap); + const VF* HWY_RESTRICT pcap = &vcap; + const VF* HWY_RESTRICT pinv_cap = &vinv_cap; - hn::Transform(D(), x, size, [cap, inv_cap](D d, V v) HWY_ATTR { - return hn::Mul(hn::Set(d, cap), - hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); - }); + DecompressAndCompressInplace( + DF(), x, size, [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF { + return hn::Mul(*pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap))); + }); } // Calls LogitsSoftCap if cap != 0.0f. 
diff --git a/ops/ops_test.cc b/ops/ops_test.cc index e935ccf..7e63482 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -47,6 +47,7 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" // After highway.h +#include "compression/test_util-inl.h" #include "ops/ops-inl.h" #include "hwy/tests/test_util-inl.h" @@ -83,48 +84,6 @@ T Random(hwy::RandomState& rng) { HWY_MAX(hwy::ConvertScalarTo(hwy::LowestValue()), val)); } -HWY_NOINLINE void SimpleAddFrom(const float* HWY_RESTRICT other, - float* HWY_RESTRICT x, size_t size) { - for (size_t i = 0; i < size; ++i) { - x[i] += other[i]; - } -} - -HWY_NOINLINE void SimpleMulBy(const float* HWY_RESTRICT other, - float* HWY_RESTRICT x, size_t size) { - for (size_t i = 0; i < size; ++i) { - x[i] *= other[i]; - } -} - -HWY_NOINLINE void SimpleMulByConst(float c, float* HWY_RESTRICT x, - size_t size) { - for (size_t i = 0; i < size; ++i) { - x[i] *= c; - } -} - -HWY_NOINLINE void SimpleMulByConstAndAdd(float c, const float* HWY_RESTRICT x, - float* HWY_RESTRICT out, size_t size) { - for (size_t i = 0; i < size; ++i) { - out[i] += x[i] * c; - } -} - -HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) { - HWY_DASSERT(size != 0); - float sum = 0.0; - const float maxval = *std::max_element(x, x + size); - for (size_t i = 0; i < size; ++i) { - x[i] = std::exp(x[i] - maxval); - sum += x[i]; - } - const float scale = 1.0f / sum; - for (size_t i = 0; i < size; ++i) { - x[i] *= scale; - } -} - template HWY_NOINLINE std::discrete_distribution SourceCreateDistribution( std::array& top_k, float temperature) { @@ -141,7 +100,8 @@ HWY_NOINLINE std::discrete_distribution SourceCreateDistribution( return std::discrete_distribution(std::begin(top_k), std::end(top_k)); } -struct TestAddFrom { +class TestAddFrom { + public: template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { @@ -171,9 +131,24 @@ struct TestAddFrom { hwy::AssertArraySimilar(e, x, 
count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } + + private: + template + static HWY_NOINLINE void SimpleAddFrom(const T1* HWY_RESTRICT other, + T2* HWY_RESTRICT x, size_t size) { + for (size_t i = 0; i < size; ++i) { + x[i] = hwy::ConvertScalarTo(hwy::ConvertScalarTo(x[i]) + + hwy::ConvertScalarTo(other[i])); + } + } }; -struct TestMulByConstAndAdd { +void TestAllAddFrom() { + hn::ForPartialVectors>()(float()); +} + +class TestMulByConstAndAdd { + public: template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { @@ -204,9 +179,27 @@ struct TestMulByConstAndAdd { hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } + + private: + template + static HWY_NOINLINE void SimpleMulByConstAndAdd(float c, + const T1* HWY_RESTRICT x, + T2* HWY_RESTRICT out, + size_t size) { + for (size_t i = 0; i < size; ++i) { + out[i] = hwy::ConvertScalarTo(hwy::ConvertScalarTo(out[i]) + + hwy::ConvertScalarTo(x[i]) * c); + } + } }; -struct TestMulByConst { +void TestAllMulByConstAndAdd() { + hn::ForPartialVectors>()( + float()); +} + +class TestMulByConst { + public: template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { @@ -234,9 +227,61 @@ struct TestMulByConst { hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); } + + private: + template + static HWY_NOINLINE void SimpleMulByConst(float c, T1* HWY_RESTRICT x, + size_t size) { + for (size_t i = 0; i < size; ++i) { + x[i] = hwy::ConvertScalarTo(hwy::ConvertScalarTo(x[i]) * c); + } + } }; -struct TestSoftmax { +void TestAllMulByConst() { + hn::ForPartialVectors>()(float()); +} + +struct TestMulByConstTo { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + if (misalign_b == 0) return; + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + 
count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pactual = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + HWY_ASSERT(px && pe && pactual); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + T* actual = pe.get() + misalign_a; + + T constant = Random(rng); + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = hwy::ConvertScalarTo(hwy::ConvertScalarTo(x[i]) * + hwy::ConvertScalarTo(constant)); + } + + MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(), + /*worker=*/0); + + hwy::AssertArraySimilar(e, actual, count, hwy::TargetName(HWY_TARGET), + __FILE__, __LINE__); + } +}; + +void TestAllMulByConstTo() { + hn::ForPartialVectors>()(float()); +} + +class TestSoftmax { + public: template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { @@ -270,8 +315,27 @@ struct TestSoftmax { } ASSERT_NEAR(sum, 1.0, 1e-6); } + + private: + static HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) { + HWY_DASSERT(size != 0); + float sum = 0.0; + const float maxval = *std::max_element(x, x + size); + for (size_t i = 0; i < size; ++i) { + x[i] = std::exp(x[i] - maxval); + sum += x[i]; + } + const float scale = 1.0f / sum; + for (size_t i = 0; i < size; ++i) { + x[i] *= scale; + } + } }; +void TestAllSoftmax() { + hn::ForPartialVectors>()(float()); +} + template struct TestCreateDistribution { void operator()(hwy::RandomState& rng) { @@ -291,43 +355,60 @@ struct TestCreateDistribution { } }; -void TestAllAddFrom() { - hn::ForPartialVectors>()(float()); -} - -void TestAllMulByConst() { - hn::ForPartialVectors>()(float()); -} - -void TestAllMulByConstAndAdd() { - hn::ForPartialVectors>()( - float()); -} - -void TestAllSoftmax() { - hn::ForPartialVectors>()(float()); -} - void TestAllCreateDistribution() { TestCreateDistribution<2048>(); TestCreateDistribution<5000>(); } -void TestSigmoid() { - 
std::vector values; - for (int i = -150; i <= 150; ++i) { - values.push_back(.1f * i); - } - std::vector result = values; - Sigmoid(result.data(), result.size()); +struct TestSigmoid { + template + void operator()(T, D) const { + std::vector values; + for (int i = -150; i <= 150; ++i) { + values.push_back(hwy::ConvertScalarTo(.1f * i)); + } + std::vector result = values; + Sigmoid(result.data(), result.size()); - for (size_t i = 0; i < values.size(); i++) { - const float max_error = 0.00007; - float value = values[i]; - float approx = result[i]; - float expected = (1 / (1 + std::exp(-values[i]))); - EXPECT_NEAR(approx, expected, max_error) << "Input: " << value; + for (size_t i = 0; i < values.size(); i++) { + const float max_error = IsBF16() ? 0.2f : 0.00007f; + const float value = hwy::ConvertScalarTo(values[i]); + const float actual = hwy::ConvertScalarTo(result[i]); + const float expected = (1 / (1 + std::exp(-value))); + EXPECT_NEAR(expected, actual, max_error) + << (IsBF16() ? "bf16" : "float"); + } } +}; + +static HWY_NOINLINE void TestAllSigmoid() { + ForeachActivationType1(hn::ScalableTag()); +} + +struct TestGelu { + template + void operator()(T, D) const { + std::vector values; + for (int i = -150; i <= 150; ++i) { + values.push_back(hwy::ConvertScalarTo(.1f * i)); + } + std::vector result = values; + Gelu(result.data(), result.size()); + + for (size_t i = 0; i < values.size(); i++) { + const float max_error = IsBF16() ? 0.2f : 0.00007f; + const float x = hwy::ConvertScalarTo(values[i]); + const float actual = hwy::ConvertScalarTo(result[i]); + const float expected = + x * (0.5f + 0.5f * tanh(x * (0.79788f + 0.035677f * x * x))); + EXPECT_NEAR(expected, actual, max_error) + << (IsBF16() ? 
"bf16" : "float"); + } + } +}; + +static HWY_NOINLINE void TestAllGelu() { + ForeachActivationType1(hn::ScalableTag()); } static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy( @@ -421,7 +502,8 @@ void TestRopeAndMulBy() { } template -HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) { +static HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, + size_t size) { double sum = 0.0; for (size_t i = 0; i < size; ++i) { const float f = hwy::ConvertScalarTo(a[i]); @@ -431,9 +513,11 @@ HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) { } // Supports bf16 and f32 inputs/outputs, which can be in-place. +// Shared between TestRMSNorm and TestRMSNormInplace. template -HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight, - OT* out, size_t size) { +static HWY_NOINLINE void ScalarRMSNorm(const XT* x, + const WT* HWY_RESTRICT weight, OT* out, + size_t size) { constexpr float kEps = 1e-6f; float ss = ScalarSquaredL2(x, size); ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); @@ -445,42 +529,73 @@ HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight, } } -template -void TestRMSNorm(hwy::RandomState& rng) { - constexpr size_t kSize = 128; - HWY_ALIGN XT vec[kSize]; - HWY_ALIGN WT weight[kSize]; - HWY_ALIGN OT expected[kSize]; - HWY_ALIGN OT actual[kSize]; +struct TestRMSNorm { + template + void operator()(XT, WT, OT, D) const { + hwy::RandomState rng; - for (size_t i = 0; i < kSize; ++i) { - vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - } + constexpr size_t kSize = 128; + HWY_ALIGN XT vec[kSize]; + HWY_ALIGN WT weight[kSize]; + HWY_ALIGN OT expected[kSize]; + HWY_ALIGN OT actual[kSize]; - ScalarRMSNorm(vec, weight, expected, kSize); - RMSNorm(vec, weight, 0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); + for (size_t i = 0; i < kSize; ++i) { + vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + weight[i] = 
hwy::ConvertScalarTo(RandomGaussian(rng)); + } - for (size_t i = 0; i < kSize; i++) { - const float e = hwy::ConvertScalarTo(expected[i]); - const float a = hwy::ConvertScalarTo(actual[i]); - if (!IsNear(e, a, 1e-5f)) { - HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), - TypeName(), TypeName(), i, e, a); + ScalarRMSNorm(vec, weight, expected, kSize); + RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); + + for (size_t i = 0; i < kSize; i++) { + const float e = hwy::ConvertScalarTo(expected[i]); + const float a = hwy::ConvertScalarTo(actual[i]); + if (!IsNear(e, a, 1e-5f)) { + HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), + TypeName(), TypeName(), i, e, a); + } } } -} +}; void TestAllRMSNorm() { - hwy::RandomState rng; - TestRMSNorm(rng); - TestRMSNorm(rng); - TestRMSNorm(rng); - TestRMSNorm(rng); - TestRMSNorm(rng); - TestRMSNorm(rng); - TestRMSNorm(rng); - TestRMSNorm(rng); + ForeachActivationType3(hn::ScalableTag()); +} + +struct TestRMSNormInplace { + template + void operator()(XT, WT, D) const { + hwy::RandomState rng; + + constexpr size_t kSize = 128; + HWY_ALIGN XT expected[kSize]; + HWY_ALIGN XT actual[kSize]; + HWY_ALIGN WT weight[kSize]; + + for (size_t i = 0; i < kSize; ++i) { + expected[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + actual[i] = expected[i]; + weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + } + + ScalarRMSNorm(expected, weight, expected, kSize); + RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(), + /*worker=*/0); + + for (size_t i = 0; i < kSize; i++) { + const float e = hwy::ConvertScalarTo(expected[i]); + const float a = hwy::ConvertScalarTo(actual[i]); + if (!IsNear(e, a, 1e-5f)) { + HWY_ABORT("RMSNormInplace %s %s mismatch at %zu: %E %E\n", + TypeName(), TypeName(), i, e, a); + } + } + } +}; + +void TestAllRMSNormInplace() { + ForeachActivationType2(hn::ScalableTag()); } void TestLayerNormSimple() { @@ -497,91 +612,92 @@ void TestLayerNormSimple() { 
for (size_t i = 0; i < kSize; i++) { const float max_error = 1e-6f; - float value = values[i]; float res = result[i]; // out = (x - 0.0) * 1.2 * 0.9999995 + 0.1 = 1.2999994 / -1.0999994; float expected = (i % 2 == 0) ? 1.2999994f : -1.0999994f; - EXPECT_NEAR(res, expected, max_error) << "Input: " << value; + EXPECT_NEAR(res, expected, max_error); } } -// Computes mean mu and mean of squares mu2 of a vector. Used in -// ScalarLayerNorm. -template -HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, double& mu, - double& mu2) { - HWY_ASSERT(size > 0); - double sum = 0.0; - double sum2 = 0.0; - for (size_t i = 0; i < size; ++i) { - const float f = hwy::ConvertScalarTo(a[i]); - sum += f; - sum2 += f * f; - } - mu = sum / size; - mu2 = sum2 / size; -} +class TestLayerNorm { + public: + template + void operator()(XT, WT, OT, D) const { + hwy::RandomState rng; + constexpr size_t kSize = 128; + XT vec[kSize]; + WT weight[kSize]; + WT bias[kSize]; + OT expected[kSize]; + OT actual[kSize]; -// Compare py/flax/linen/normalization.py. 
-// out = (x - mean) * scale * rsqrt(var + epsilon) + bias -template -HWY_NOINLINE void ScalarLayerNorm(const XT* x, const WT* HWY_RESTRICT scale, - const WT* HWY_RESTRICT bias, OT* out, - size_t size) { - constexpr double kEps = 1e-6; - double mu, mu2; - ScalarMus(x, size, mu, mu2); - double var = mu2 - mu * mu; - constexpr double kZero = 0.0; - var = HWY_MAX(var, kZero); - var = 1.0 / sqrt(var + kEps); - for (size_t j = 0; j < size; j++) { - const float v = hwy::ConvertScalarTo(x[j]); - const float s = hwy::ConvertScalarTo(scale[j]); - const float b = hwy::ConvertScalarTo(bias[j]); - out[j] = hwy::ConvertScalarTo((v - mu) * s * var + b); - } -} + for (size_t i = 0; i < kSize; ++i) { + vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + bias[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + } -template -void TestLayerNorm(hwy::RandomState& rng) { - constexpr size_t kSize = 128; - XT vec[kSize]; - WT weight[kSize]; - WT bias[kSize]; - OT expected[kSize]; - OT actual[kSize]; + double expected_mu, expected_mu2; + ScalarMus(vec, kSize, expected_mu, expected_mu2); + double actual_mu, actual_mu2; + ComputeMoments(vec, kSize, actual_mu, actual_mu2); - for (size_t i = 0; i < kSize; ++i) { - vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - bias[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - } + ScalarLayerNorm(vec, weight, bias, expected, kSize); + LayerNorm(vec, weight, bias, actual, kSize); - double expected_mu, expected_mu2; - ScalarMus(vec, kSize, expected_mu, expected_mu2); - double actual_mu, actual_mu2; - ComputeMoments(vec, kSize, actual_mu, actual_mu2); - - ScalarLayerNorm(vec, weight, bias, expected, kSize); - LayerNorm(vec, weight, bias, actual, kSize); - - for (size_t i = 0; i < kSize; i++) { - const float e = hwy::ConvertScalarTo(expected[i]); - const float a = hwy::ConvertScalarTo(actual[i]); - if (!IsNear(e, a, 1e-5f)) { - 
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), - TypeName(), TypeName(), i, e, a); + for (size_t i = 0; i < kSize; i++) { + const float e = hwy::ConvertScalarTo(expected[i]); + const float a = hwy::ConvertScalarTo(actual[i]); + if (!IsNear(e, a, 1e-5f)) { + HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), + TypeName(), TypeName(), i, e, a); + } } } -} + + private: + // Computes mean mu and mean of squares mu2 of a vector. Used in + // ScalarLayerNorm. + template + static HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, + double& mu, double& mu2) { + HWY_ASSERT(size > 0); + double sum = 0.0; + double sum2 = 0.0; + for (size_t i = 0; i < size; ++i) { + const float f = hwy::ConvertScalarTo(a[i]); + sum += f; + sum2 += f * f; + } + mu = sum / size; + mu2 = sum2 / size; + } + + // Compare py/flax/linen/normalization.py. + // out = (x - mean) * scale * rsqrt(var + epsilon) + bias + template + static HWY_NOINLINE void ScalarLayerNorm(const XT* x, + const WT* HWY_RESTRICT scale, + const WT* HWY_RESTRICT bias, OT* out, + size_t size) { + constexpr double kEps = 1e-6; + double mu, mu2; + ScalarMus(x, size, mu, mu2); + double var = mu2 - mu * mu; + constexpr double kZero = 0.0; + var = HWY_MAX(var, kZero); + var = 1.0 / sqrt(var + kEps); + for (size_t j = 0; j < size; j++) { + const float v = hwy::ConvertScalarTo(x[j]); + const float s = hwy::ConvertScalarTo(scale[j]); + const float b = hwy::ConvertScalarTo(bias[j]); + out[j] = hwy::ConvertScalarTo((v - mu) * s * var + b); + } + } +}; void TestAllLayerNorm() { - hwy::RandomState rng; - TestLayerNorm(rng); - TestLayerNorm(rng); - TestLayerNorm(rng); - TestLayerNorm(rng); + ForeachActivationType3(hn::ScalableTag()); } void TestSampleTopK() { @@ -646,12 +762,15 @@ namespace gcpp { HWY_BEFORE_TEST(OpsTest); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo); 
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); -HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu); HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNormInplace); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm); HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple); HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK); From 0d2e74d74aaecb1314f85d9848691d7ede5cb5b1 Mon Sep 17 00:00:00 2001 From: Marie White Date: Mon, 1 Sep 2025 23:46:07 -0700 Subject: [PATCH 18/65] Add MMOptions as an argument to Matmul. PiperOrigin-RevId: 802008198 --- ops/matmul-inl.h | 9 ++++----- ops/matmul.h | 4 ++-- ops/matmul_static-inl.h | 4 ++-- ops/matmul_static.h | 2 +- ops/matmul_test.cc | 3 ++- ops/ops-inl.h | 7 ++++--- 6 files changed, 15 insertions(+), 14 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 8f91114..95277db 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1179,7 +1179,7 @@ struct MMImpl { template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, - MatPtrT& C) { + MatPtrT& C, MMOptions options) { RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]); const Allocator& allocator = env.ctx.allocator; @@ -1205,12 +1205,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = env.per_key[index]; MMAutoTune& tuner = per_key.autotune; - // Default to nested parallelism. 
- const ParallelismType parallelism_type = ParallelismType::kNested; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type); + MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), + options.parallelism_type); return &per_key; } @@ -1239,7 +1238,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type); + MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / diff --git a/ops/matmul.h b/ops/matmul.h index dda9673..620d382 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -67,8 +67,8 @@ enum class ParallelismType : uint8_t { }; struct MMOptions { - ParallelismType parallelism_type_ = ParallelismType::kNested; - uint8_t cluster_idx_ = 0; + ParallelismType parallelism_type = ParallelismType::kNested; + uint8_t cluster_idx = 0; }; struct MMSequentialPolicy { diff --git a/ops/matmul_static-inl.h b/ops/matmul_static-inl.h index 28b21cf..ba09e0c 100644 --- a/ops/matmul_static-inl.h +++ b/ops/matmul_static-inl.h @@ -28,8 +28,8 @@ #define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \ MMPerKey* MatMulStatic(const MatPtrT& A, const MatPtrT& B, \ const float* HWY_RESTRICT add, MatMulEnv& env, \ - MatPtrT& C) { \ - return MatMul(A, B, add, env, C); \ + MatPtrT& C, MMOptions options) { \ + return MatMul(A, B, add, env, C, options); \ } #if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \ diff --git a/ops/matmul_static.h b/ops/matmul_static.h index c06b87a..61dc505 100644 --- a/ops/matmul_static.h +++ b/ops/matmul_static.h @@ -35,7 +35,7 @@ #define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \ MMPerKey* MatMulStatic(const MatPtrT& A, const 
MatPtrT& B, \ const float* HWY_RESTRICT add, MatMulEnv& env, \ - MatPtrT& C); + MatPtrT& C, MMOptions options); // Passed to HWY_VISIT_TARGETS; declares all overloads for all targets. #define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \ diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 3a8528a..dc6f559 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -258,8 +258,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatMulSlow(A, BT, add_row, env, C_slow); // A few reps to get coverage of the various autotuned code paths. + MMOptions options; for (size_t rep = 0; rep < 16; ++rep) { - MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C); + MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options); AssertClose(A, BT, C_slow, C, env, line); if (per_key->autotune.Best()) break; } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index caf8041..0438bf7 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -63,9 +63,10 @@ namespace hn = hwy::HWY_NAMESPACE; template MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, const float* HWY_RESTRICT add, MatMulEnv& env, - MatPtrT& C) { - return CallUpcasted( - &B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); }); + MatPtrT& C, const MMOptions& options = MMOptions()) { + return CallUpcasted(&B, [&](const auto* B_t) { + return MatMulStatic(A, *B_t, add, env, C, options); + }); } HWY_INLINE double PackTokenAndProb(int32_t token, float prob) { From 27cb8e12d9dcd433d2e642ef91b8ad76c6e9bbb0 Mon Sep 17 00:00:00 2001 From: Marie White Date: Tue, 2 Sep 2025 00:02:18 -0700 Subject: [PATCH 19/65] Handle non-threading parallel policy. 
PiperOrigin-RevId: 802012517 --- ops/matmul-inl.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 95277db..2e4dcde 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1147,8 +1147,10 @@ struct MMImpl { MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( MMNestedParallelPolicy(), A, B, C_rows); break; - case ParallelismType::kNone: case ParallelismType::kSequential: + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( + MMSequentialPolicy(), A, B, C_rows); + case ParallelismType::kNone: case ParallelismType::kCluster: HWY_ABORT("Parallelism type not implemented."); break; From 373722413279b8fc69334ef2c3d1ceb6e130d0d3 Mon Sep 17 00:00:00 2001 From: Marie White Date: Tue, 2 Sep 2025 00:14:05 -0700 Subject: [PATCH 20/65] Add in-cluster parallel policy. Update policy to include cluster_idx. PiperOrigin-RevId: 802016308 --- ops/matmul-inl.h | 43 +++++++++++++++----------- ops/matmul.h | 80 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 96 insertions(+), 27 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 2e4dcde..1d5dc5d 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -728,9 +728,10 @@ class MMPerPackage { public: MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config, - size_t pkg_idx, const IndexRange& range_np) + size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np) : args_(args), pkg_idx_(pkg_idx), + cluster_idx_(cluster_idx), range_np_(range_np), mr_(config.MR()), ranges_mc_(config.RangesOfMC(A.rows)), @@ -821,7 +822,8 @@ class MMPerPackage { // Similar to `loop_nc` below, but here we hoisted `A_view`. 
MMParallelPolicyT::ForNP( args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, - pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { + pkg_idx_, cluster_idx_, + [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -869,7 +871,8 @@ class MMPerPackage { MMParallelPolicyT::ForNP( args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, - pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { + pkg_idx_, cluster_idx_, + [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -901,7 +904,7 @@ class MMPerPackage { // Sequential loop over NC/MC/KC, similar to `loop_nc` below // except for the profiler strings and `out_tag`. MMParallelPolicyT::ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, + args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -951,7 +954,7 @@ class MMPerPackage { } }; // loop_nc MMParallelPolicyT::ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, + args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -1024,7 +1027,7 @@ class MMPerPackage { const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16)); MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks, - pkg_idx_, + pkg_idx_, cluster_idx_, [&](const IndexRange& range_K, size_t worker) { do_range(all_M, range_K, worker); }); @@ -1032,7 +1035,8 @@ class MMPerPackage { } case MMParA::kM: MMParallelPolicyT::ForRangeMC( - args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) { + args_.env->ctx, all_M, pkg_idx_, cluster_idx_, + [&](size_t row_a, size_t worker) { do_range(IndexRange(row_a, row_a + 1), all_K, worker); }); break; @@ 
-1106,6 +1110,7 @@ class MMPerPackage { const MMArgs args_; // copy for locality const size_t pkg_idx_; + const size_t cluster_idx_; // 0 for sequential and nested. const IndexRange range_np_; // From MMConfig: @@ -1135,23 +1140,26 @@ struct MMImpl { template static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const MatPtrT& B, RowPtrs C_rows, const MMArgs& args, - const MMConfig& config, - ParallelismType parallelism_type) { + const MMConfig& config, MMOptions options) { PROFILER_ZONE("MM.DoMatMul"); const size_t pkg_idx = 0; HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - switch (parallelism_type) { + switch (options.parallelism_type) { case ParallelismType::kNested: - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( - MMNestedParallelPolicy(), A, B, C_rows); + HWY_DASSERT(options.cluster_idx == 0); + MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, + range_np)(MMNestedParallelPolicy(), A, B, C_rows); break; case ParallelismType::kSequential: - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( - MMSequentialPolicy(), A, B, C_rows); - case ParallelismType::kNone: + MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, + range_np)(MMSequentialPolicy(), A, B, C_rows); case ParallelismType::kCluster: + MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, + range_np)(MMClusterParallelPolicy(), A, B, C_rows); + break; + default: HWY_ABORT("Parallelism type not implemented."); break; } @@ -1210,8 +1218,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), - options.parallelism_type); + MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options); return &per_key; } @@ -1240,7 +1247,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const 
MatPtrT& B, const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type); + MMImpl::DoMatMul(A, B, C_rows, args, cfg, options); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / diff --git a/ops/matmul.h b/ops/matmul.h index 620d382..5c526de 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -81,9 +81,10 @@ struct MMSequentialPolicy { template static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - const Func& func) { + size_t cluster_idx, const Func& func) { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + + cluster_idx * ctx.pools.MaxWorkersPerCluster(); func(range_np, base_idx); } @@ -91,8 +92,10 @@ struct MMSequentialPolicy { static void ForRangesMC_NC(ThreadingContext& ctx, const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_nc, - size_t pkg_idx, const Func& func) { - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + size_t pkg_idx, size_t cluster_idx, + const Func& func) { + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + + cluster_idx * ctx.pools.MaxWorkersPerCluster(); for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) { const IndexRange range_mc = ranges_mc.Range(i); @@ -105,14 +108,68 @@ struct MMSequentialPolicy { template static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, const Func& func) { - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + size_t pkg_idx, size_t cluster_idx, const Func& func) { + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + + cluster_idx * ctx.pools.MaxWorkersPerCluster(); for (uint64_t row_a 
= range_mc.begin(); row_a < range_mc.end(); ++row_a) { func(row_a, base_idx); } } }; +struct MMClusterParallelPolicy { + template + static void ForPkg(ThreadingContext& ctx, const size_t max_packages, + const Func& func) { + func(/*pkg_idx=*/0); + } + + template + static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, + size_t cluster_idx, const Func& func) { + HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + const IndexRangePartition worker_ranges = StaticPartition( + range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); + ParallelizeOneRange(worker_ranges, cluster, + [&](const IndexRange& worker_range, size_t worker) { + func(worker_range, worker); + }); + } + + template + static void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, + size_t pkg_idx, size_t cluster_idx, + const Func& func) { + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + + // Low-batch: avoid Divide/Remainder. 
+ if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { + ParallelizeOneRange(ranges_nc, cluster, + [&](const IndexRange& range_nc, size_t thread) { + func(ranges_mc.Range(0), range_nc, thread); + }); + } else { + ParallelizeTwoRanges( + ranges_mc, ranges_nc, cluster, + [&](const IndexRange& range_mc, const IndexRange& range_nc, + size_t thread) { func(range_mc, range_nc, thread); }); + } + } + + template + static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t pkg_idx, size_t cluster_idx, const Func& func) { + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + cluster.Run(range_mc.begin(), range_mc.end(), + [&](uint64_t row_a, size_t thread) { func(row_a, thread); }); + } +}; + struct MMNestedParallelPolicy { template static void ForPkg(ThreadingContext& ctx, const size_t max_packages, @@ -132,10 +189,11 @@ struct MMNestedParallelPolicy { // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. + // `cluster_idx` is not used here as all clusters within a package are used. template static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - const Func& func) { + size_t /*cluster_idx*/, const Func& func) { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); @@ -175,11 +233,13 @@ struct MMNestedParallelPolicy { // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B // rows). Calls `func(range_mc, range_nc, worker)`. + // `cluster_idx` is not used here as all clusters within a package are used. 
template static void ForRangesMC_NC(ThreadingContext& ctx, const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_nc, - size_t pkg_idx, const Func& func) { + size_t pkg_idx, size_t /*cluster_idx*/, + const Func& func) { const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); // `all_clusters` is a pool with one worker per cluster in a package. @@ -221,9 +281,11 @@ struct MMNestedParallelPolicy { } // Calls `func(row_a, worker)` in parallel. + // `cluster_idx` is not used here as all clusters within a package are used. template static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, const Func& func) { + size_t pkg_idx, size_t /*cluster_idx*/, + const Func& func) { const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); ctx.pools.Pool(pkg_idx).Run( range_mc.begin(), range_mc.end(), From 1e3c853e80f242a09685c4f891cc35e456e34282 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 2 Sep 2025 01:39:28 -0700 Subject: [PATCH 21/65] Add ParallelFor wrapper function and one new mode Move ParallelismType from matmul.h to threading.h Replace SmallParallelFor with ParallelFor and the new mode PiperOrigin-RevId: 802038452 --- BUILD.bazel | 1 + gemma/attention.cc | 6 ++--- gemma/gemma-inl.h | 44 ++++++++++++++++++++-------------- ops/matmul-inl.h | 7 +++--- ops/matmul.h | 11 +-------- ops/ops-inl.h | 42 +++++++++++++++++++-------------- util/threading.h | 59 +++++++++++++++++++++++++++++++++++++++------- 7 files changed, 110 insertions(+), 60 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index e141e95..ce4cffb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -264,6 +264,7 @@ cc_library( ":allocator", ":basics", ":mat", + ":threading", ":threading_context", "//compression:compress", "@highway//:bit_set", diff --git a/gemma/attention.cc b/gemma/attention.cc index 13681d0..bd76329 100644 --- a/gemma/attention.cc +++ 
b/gemma/attention.cc @@ -233,9 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, { PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); - // Full parallelism is helpful, SmallParallelFor is insufficient. - ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, - ctx.pools, func); + // Full parallelism is helpful, kAcrossClusters is insufficient. + NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, + ctx.pools, func); } } diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index ed0750a..80ec0ee 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -66,31 +66,39 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1, // No C2 multiplier. template -void ActivationBatched(ActivationType activation, Mat& c1, - ThreadingContext& ctx) { +void ActivationBatched( + ActivationType activation, Mat& c1, ThreadingContext& ctx, + size_t cluster_idx = 0, + ParallelismType parallelism = ParallelismType::kAcrossClusters) { using T = typename Mat::T; - SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { - // Cast to correct type so type deduction works. - Activation(activation, c1.Row(task), static_cast(nullptr), - c1.Cols(), ctx.profiler, worker); - }); + ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + [&](uint64_t task, size_t worker) { + // Cast to correct type so type deduction works. 
+ Activation(activation, c1.Row(task), + static_cast(nullptr), c1.Cols(), + ctx.profiler, worker); + }); } template -HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat1& c1, - const Mat2* c2, ThreadingContext& ctx) { +HWY_NOINLINE void ActivationBatched( + ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, + size_t cluster_idx = 0, + ParallelismType parallelism = ParallelismType::kAcrossClusters) { HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { - SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { - Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), - ctx.profiler, worker); - }); + ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + [&](uint64_t task, size_t worker) { + Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), + ctx.profiler, worker); + }); } else { // No multiplier - SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { - Activation(activation, c1.Row(task), - static_cast(nullptr), c1.Cols(), - ctx.profiler, worker); - }); + ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + [&](uint64_t task, size_t worker) { + Activation(activation, c1.Row(task), + static_cast(nullptr), + c1.Cols(), ctx.profiler, worker); + }); } } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 1d5dc5d..53dfb05 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1155,12 +1155,13 @@ struct MMImpl { case ParallelismType::kSequential: MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, range_np)(MMSequentialPolicy(), A, B, C_rows); - case ParallelismType::kCluster: + case ParallelismType::kWithinCluster: MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, range_np)(MMClusterParallelPolicy(), A, B, C_rows); break; default: - HWY_ABORT("Parallelism type not implemented."); + HWY_ABORT("Parallelism type %s not implemented.", + static_cast(options.parallelism_type)); break; } } @@ -1189,7 +1190,7 @@ struct 
MMImpl { template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, - MatPtrT& C, MMOptions options) { + MatPtrT& C, MMOptions options = MMOptions()) { RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]); const Allocator& allocator = env.ctx.allocator; diff --git a/ops/matmul.h b/ops/matmul.h index 5c526de..11262bc 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -27,6 +27,7 @@ // IWYU pragma: begin_exports #include "util/basics.h" #include "util/mat.h" +#include "util/threading.h" #include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" @@ -56,16 +57,6 @@ static constexpr size_t kMaxMR = 4; IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, size_t N, size_t sizeof_TC, size_t nr); -enum class ParallelismType : uint8_t { - kNone, - // No parallelism. - kSequential, - // Parallelism at cluster level. - kCluster, - // Parallelism at package level. - kNested, -}; - struct MMOptions { ParallelismType parallelism_type = ParallelismType::kNested; uint8_t cluster_idx = 0; diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0438bf7..0173ee8 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -494,14 +494,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, // Simple loops unless/until batch sizes are large enough to parallelize. 
template void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, - MatPtrT& out, ThreadingContext& ctx) { + MatPtrT& out, ThreadingContext& ctx, + size_t cluster_idx = 0) { HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Cols() == activations.Cols()); HWY_DASSERT(activations.SameShape(out)); CallUpcasted(&weights, [&](const auto* weights_t) { - SmallParallelFor( - activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { + ParallelFor( + ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools, + cluster_idx, [&](uint64_t token_idx, size_t worker) { RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), out.Row(token_idx), activations.Cols(), ctx.profiler, worker); }); @@ -510,13 +512,14 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, template void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, - ThreadingContext& ctx) { + ThreadingContext& ctx, size_t cluster_idx = 0) { HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Cols() == inout.Cols()); CallUpcasted(&weights, [&](const auto* weights_t) { - SmallParallelFor( - inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { + ParallelFor( + ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx, + [&](uint64_t token_idx, size_t worker) { RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx), inout.Cols(), ctx.profiler, worker); }); @@ -542,13 +545,14 @@ void LayerNormBatched(const MatPtrT& x, const MatPtr& weight, template static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, - ThreadingContext& ctx) { + ThreadingContext& ctx, + size_t cluster_idx = 0) { HWY_DASSERT(out.SameShape(x)); - SmallParallelFor(out.Rows(), ctx.pools, - [&](uint64_t token_idx, size_t worker) { - AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), - ctx.profiler, worker); - }); + ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools, + cluster_idx, [&](uint64_t token_idx, size_t 
worker) { + AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), + ctx.profiler, worker); + }); } template @@ -776,13 +780,15 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( const float cap, MatPtrT& x, const hwy::BitSet4096<>& non_eos, - ThreadingContext& ctx) { + ThreadingContext& ctx, size_t cluster_idx = 0) { if (cap == 0.0f) return; - SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { - if (non_eos.Get(task)) { - LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker); - } - }); + ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools, + cluster_idx, [&](uint64_t task, size_t worker) { + if (non_eos.Get(task)) { + LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, + worker); + } + }); } static HWY_NOINLINE HWY_MAYBE_UNUSED size_t diff --git a/util/threading.h b/util/threading.h index 0a57ddb..ef4f1c7 100644 --- a/util/threading.h +++ b/util/threading.h @@ -326,7 +326,7 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1, // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes // over clusters of ONE package, then within each cluster. template -void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { +void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { // Even if there are multiple packages, we only use the first. const size_t pkg_idx = 0; @@ -356,14 +356,57 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { }); } -// As above, but for lightweight tasks. Uses only one pool. -template -void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { - // Even if there are multiple packages, we only use the first. - const size_t pkg_idx = 0; +// Which pool(s) to use for parallelizing: +enum class ParallelismType : uint8_t { + // None: single-threaded loop on the calling thread. 
+ kSequential, + // One thread per cluster within the first package; or one per core if there + // is only one cluster. Use for few or lightweight tasks, or to maximize + // memory bandwidth availability. + kAcrossClusters, + // All cores within the cluster identified by `cluster_idx`. Use if already + // within a `kAcrossClusters` parallel-for, or if latency is more important + // than memory bandwidth. + kWithinCluster, + // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This + // utilizes all cores, but has higher fork-join overhead (two barriers); use + // if there are many or heavy tasks. + kNested, +}; - pools.Pool(pkg_idx).Run( - 0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); }); +// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the +// number/type of workers determined by `parallelism`. `cluster_idx` is only +// used if `parallelism == kWithinCluster`. +template +void ParallelFor(ParallelismType parallelism, size_t num_tasks, + NestedPools& pools, size_t cluster_idx, const Func& func) { + if (cluster_idx != 0) { + // If already running across clusters, must not use across-cluster modes. 
+ HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters && + parallelism != ParallelismType::kNested); + } + + const size_t pkg_idx = 0; + switch (parallelism) { + case ParallelismType::kSequential: + for (size_t task = 0; task < num_tasks; ++task) { + func(task, /*worker=*/0); + } + return; + + case ParallelismType::kAcrossClusters: + return pools.Pool(pkg_idx).Run( + 0, num_tasks, + [&](uint64_t task, size_t worker) { func(task, worker); }); + + case ParallelismType::kWithinCluster: + return pools.Cluster(pkg_idx, cluster_idx) + .Run(0, num_tasks, + [&](uint64_t task, size_t worker) { func(task, worker); }); + + case ParallelismType::kNested: + return NestedParallelFor(num_tasks, pools, func); + } } } // namespace gcpp From b7b3d353db453ca6b6165799644a9e52e840852f Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 2 Sep 2025 04:28:49 -0700 Subject: [PATCH 22/65] Simplify MatMul: remove F32 special case (build time) Also move kMaxM into separate kMaxBatchSize PiperOrigin-RevId: 802086590 --- gemma/gemma_args.h | 16 +- ops/matmul-inl.h | 402 ++++++++++++++------------------------------- ops/matmul.cc | 27 ++- ops/matmul.h | 18 +- ops/matmul_test.cc | 6 +- util/basics.h | 5 +- 6 files changed, 157 insertions(+), 317 deletions(-) diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 16c9595..2a49349 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -248,17 +248,17 @@ struct InferenceArgs : public ArgsBase { runtime_config.max_generated_tokens = max_generated_tokens; runtime_config.prefill_tbatch_size = prefill_tbatch_size; runtime_config.decode_qbatch_size = decode_qbatch_size; - if (prefill_tbatch_size > MMStorage::kMaxM) { + if (prefill_tbatch_size > kMaxBatchSize) { HWY_ABORT( - "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, " - "or increase the constant in MMStorage.\n", - prefill_tbatch_size, MMStorage::kMaxM); + "prefill_tbatch_size %zu > kMaxBatchSize %zu: specify a " + "smaller value, or increase kMaxBatchSize.\n", + 
prefill_tbatch_size, kMaxBatchSize); } - if (decode_qbatch_size > MMStorage::kMaxM) { + if (decode_qbatch_size > kMaxBatchSize) { HWY_ABORT( - "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, " - "or increase the constant in MMStorage.\n", - decode_qbatch_size, MMStorage::kMaxM); + "decode_qbatch_size %zu > kMaxBatchSize %zu: specify a " + "smaller value, or increase kMaxBatchSize.\n", + decode_qbatch_size, kMaxBatchSize); } runtime_config.temperature = temperature; diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 53dfb05..a9685e2 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -96,28 +96,27 @@ struct MMSetC {}; struct MMAddC {}; // Stores horizontal sums of up to 16 vectors via transpose. -template +template class MMStoreHorizontalSumsIntoC { public: static_assert(kNR == 4); // for `StoreInterleaved4` - // Computes horizontal sums of `kRowsAC x kNR` vectors and stores into - // `C` starting at `(row_c, col_c)`. - // + // Given 16 (`kRowsAC x kNR`) full vectors of 32-bit float, returns four + // 4-wide float vectors with their horizontal sums. // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a // transposed B row vector indexed by `c`. Their elements are thus a subset // of the terms of the dot product constituting the final `C[r, c]` result. // Thus we compute the horizontal sums of each `Crc`. The elements may be // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but // this does not change their horizontal sum. 
- template , typename TC> - HWY_INLINE void operator()(DF df, // - VF C00, VF C01, VF C02, VF C03, // - VF C10, VF C11, VF C12, VF C13, // - VF C20, VF C21, VF C22, VF C23, // - VF C30, VF C31, VF C32, VF C33, // - const size_t row_c, const size_t col_c, - const MMArgs& args, RowPtrs C_rows) const { + template , class D4 = hn::Full128, + class V4 = hn::Vec> + HWY_INLINE void Reduce4x4(DF df, // + VF C00, VF C01, VF C02, VF C03, // + VF C10, VF C11, VF C12, VF C13, // + VF C20, VF C21, VF C22, VF C23, // + VF C30, VF C31, VF C32, VF C33, // + V4& sum0, V4& sum1, V4& sum2, V4& sum3) { HWY_ALIGN float buf[16 * hn::MaxLanes(df)]; HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df); // Horizontal reductions (`ReduceSum`) are rather expensive, entailing @@ -133,14 +132,13 @@ class MMStoreHorizontalSumsIntoC { // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in // the elements of one V4. We have four independent rows `r`, hence the // code is effectively unrolled, which increases throughput. - const hn::CappedTag d4; - using V4 = hn::Vec; // Store to four elements per row of `partial`. // No loop is required because vectors are at least 4*32 bits. - V4 sum0 = MaybeLoad<0>(d4, N, buf); - V4 sum1 = MaybeLoad<1>(d4, N, buf); - V4 sum2 = MaybeLoad<2>(d4, N, buf); - V4 sum3 = MaybeLoad<3>(d4, N, buf); + const D4 d4; + sum0 = MaybeLoad<0>(d4, N, buf); + sum1 = MaybeLoad<1>(d4, N, buf); + sum2 = MaybeLoad<2>(d4, N, buf); + sum3 = MaybeLoad<3>(d4, N, buf); for (size_t lane = 1; lane < N; ++lane) { sum0 = MaybeAdd<0>(d4, N, sum0, buf + kNR * lane); @@ -148,13 +146,23 @@ class MMStoreHorizontalSumsIntoC { sum2 = MaybeAdd<2>(d4, N, sum2, buf + kNR * lane); sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane); } + } + + // Scales the dot-product terms and adds bias (if present) and stores the + // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is + // `MMSetC`, the vectors are written as-is (first call, or small K). 
+ // Otherwise, they are partial sums and are accumulated into C. + template , class Tag, typename TC> + HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag, + const size_t row_c, const size_t col_c, + const MMArgs& args, RowPtrs C_rows) const { const V4 vscale = hn::Set(d4, args.scale); HWY_ALIGN static constexpr float kZero[4] = {}; const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero); - MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c); - MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c); - MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c); - MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c); + MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c); + MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c); + MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c); + MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c); } private: @@ -191,18 +199,20 @@ class MMStoreHorizontalSumsIntoC { } template , - typename TC> + class Tag, typename TC> static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale, - VF4 vadd, RowPtrs C_rows, + VF4 vadd, Tag, RowPtrs C_rows, const size_t row_c, const size_t col_c) { if constexpr (kRow < kRowsAC) { TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c; const hn::Rebind dc4; - if constexpr (kAdd) { + if constexpr (hwy::IsSame()) { vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value - } // else: add bias (only once, the first time we store to C) - + } else { + static_assert(hwy::IsSame()); + // vadd remains the bias (added once, the first time we store to C) + } const VF4 out = hn::MulAdd(sum, vscale, vadd); hn::Store(TCFromF32(dc4, out), dc4, pos); } @@ -215,9 +225,9 @@ class MMKernel { // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. 
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. - template - static HWY_INLINE void A2C0(const StridedView A_view, - const StridedViewBF& B_view, size_t mr, + template + static HWY_INLINE void A2C0(const StridedViewBF A_view, + const StridedViewBF B_view, size_t mr, const IndexRange& range_mc, const size_t row_b, size_t kc, Tag tag, const MMArgs& args, RowPtrs C_rows) { @@ -357,34 +367,18 @@ class MMKernel { } } - // For A=F32, B=BF16 without native BF16 dot product: one lane-crossing - // promotion is likely cheaper than AND+SHIFT for promoting odd/even BF. - // Caller already promoted B, so all inputs are F32. - template , HWY_IF_F32_D(DF)> - static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2, - VF b3, VF& C0, VF& C1, VF& C2, - VF& C3) { - HWY_DASSERT(!HWY_NATIVE_DOT_BF16); - C0 = hn::MulAdd(a, b0, C0); - C1 = hn::MulAdd(a, b1, C1); - C2 = hn::MulAdd(a, b2, C2); - C3 = hn::MulAdd(a, b3, C3); - } - // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`. - // `B` is BF16, `A` and `C` can be F32 or BF16. - template - static HWY_INLINE void LoopKC(const StridedView A_view, - const StridedViewBF& B_view, size_t row_ac, + // `A` and `B` are always BF16, `C` can be F32 or BF16. 
+ template + static HWY_INLINE void LoopKC(const StridedViewBF A_view, + const StridedViewBF B_view, size_t row_ac, size_t imc, size_t col_c, size_t kc, Tag tag, const MMArgs& args, RowPtrs C_rows) { const hn::ScalableTag dbf; using VBF = hn::Vec; - - HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag()); HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); HWY_DASSERT(kRowsAC <= kMaxMR); @@ -393,10 +387,10 @@ class MMKernel { // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B. static_assert(kNR == 4); - const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0); - const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr; - const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr; - const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr; + const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0); + const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr; + const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr; + const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr; const BF16* HWY_RESTRICT br0 = B_view.Row(0); const BF16* HWY_RESTRICT br1 = B_view.Row(1); const BF16* HWY_RESTRICT br2 = B_view.Row(2); @@ -416,8 +410,6 @@ class MMKernel { C33 = hn::Zero(df); size_t ikc = 0; - // The loop step is always NBF: for non-native BF16 with TA=F32, this - // entails 2x unrolling, which helps a little. const HWY_LANES_CONSTEXPR size_t kc_step = NBF; if (kc >= kc_step) { HWY_UNROLL(1) @@ -432,10 +424,6 @@ class MMKernel { const VBF b2 = hn::LoadU(dbf, br2 + ikc); const VBF b3 = hn::LoadU(dbf, br3 + ikc); - // Should only get here if `A` is BF16, otherwise `DecompressA` would - // convert to BF16 and `A_view` points to that. 
- HWY_DASSERT(IsBF16()); - { const VBF a0 = hn::Load(dbf, ar0 + ikc); ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02, @@ -457,102 +445,40 @@ class MMKernel { C33); } } else { // !HWY_NATIVE_DOT_BF16 - if constexpr (IsBF16()) { - // When both are BF16, it is better to load promote odd/even, - // because lane-crossing promotion for both might be bottlenecked on - // shuffles. - VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; - { - const VBF b0 = hn::LoadU(dbf, br0 + ikc); - const VBF b1 = hn::LoadU(dbf, br1 + ikc); - const VBF b2 = hn::LoadU(dbf, br2 + ikc); - const VBF b3 = hn::LoadU(dbf, br3 + ikc); - b0e = hn::PromoteEvenTo(df, b0); - b1e = hn::PromoteEvenTo(df, b1); - b2e = hn::PromoteEvenTo(df, b2); - b3e = hn::PromoteEvenTo(df, b3); - b0o = FastPromoteOddTo(df, b0); - b1o = FastPromoteOddTo(df, b1); - b2o = FastPromoteOddTo(df, b2); - b3o = FastPromoteOddTo(df, b3); - } - - // Two rows at a time so we have 8 separate dependency chains, - // sufficient for IPC=2 and 4-cycle latency. - { - const VBF a0 = hn::Load(dbf, ar0 + ikc); - const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0; - ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, - b3o, b3e, C00, C01, C02, C03, C10, C11, - C12, C13); - } - if constexpr (kRowsAC > 2) { - const VBF a2 = hn::Load(dbf, ar2 + ikc); - const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2; - ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, - b3o, b3e, C20, C21, C22, C23, C30, C31, - C32, C33); - } - } else { // IsF32(): promote BF to 2xF32, F32*F32. - // Full-vector loads are a bit faster on SKX than half + PromoteTo. + // When both are BF16, it is better to load promote odd/even, + // because lane-crossing promotion for both might be bottlenecked on + // shuffles. 
+ VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; + { const VBF b0 = hn::LoadU(dbf, br0 + ikc); const VBF b1 = hn::LoadU(dbf, br1 + ikc); const VBF b2 = hn::LoadU(dbf, br2 + ikc); const VBF b3 = hn::LoadU(dbf, br3 + ikc); - const VF b00 = hn::PromoteLowerTo(df, b0); - const VF b10 = hn::PromoteLowerTo(df, b1); - const VF b20 = hn::PromoteLowerTo(df, b2); - const VF b30 = hn::PromoteLowerTo(df, b3); - const VF b01 = hn::PromoteUpperTo(df, b0); - const VF b11 = hn::PromoteUpperTo(df, b1); - const VF b21 = hn::PromoteUpperTo(df, b2); - const VF b31 = hn::PromoteUpperTo(df, b3); + b0e = hn::PromoteEvenTo(df, b0); + b1e = hn::PromoteEvenTo(df, b1); + b2e = hn::PromoteEvenTo(df, b2); + b3e = hn::PromoteEvenTo(df, b3); + b0o = FastPromoteOddTo(df, b0); + b1o = FastPromoteOddTo(df, b1); + b2o = FastPromoteOddTo(df, b2); + b3o = FastPromoteOddTo(df, b3); + } - { - const VF a00 = hn::Load(df, ar0 + ikc); - ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02, - C03); - } - if constexpr (kRowsAC > 1) { - const VF a10 = hn::Load(df, ar1 + ikc); - ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12, + // Two rows at a time so we have 8 separate dependency chains, + // sufficient for IPC=2 and 4-cycle latency. + { + const VBF a0 = hn::Load(dbf, ar0 + ikc); + const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0; + ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C00, C01, C02, C03, C10, C11, C12, C13); - } - - // C00 is ready again. On SKX, this interleaved unrolling is faster - // than consuming all `b*1` at the end of the loop. 
- { - const VF a01 = hn::Load(df, ar0 + ikc + NA); - ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02, - C03); - } - if constexpr (kRowsAC > 1) { - const VF a11 = hn::Load(df, ar1 + ikc + NA); - ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12, - C13); - } - - if constexpr (kRowsAC > 2) { - const VF a20 = hn::Load(df, ar2 + ikc); - ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22, - C23); - } - if constexpr (kRowsAC > 3) { - const VF a30 = hn::Load(df, ar3 + ikc); - ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32, + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::Load(dbf, ar2 + ikc); + const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2; + ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C20, C21, C22, C23, C30, C31, C32, C33); - } - - if constexpr (kRowsAC > 2) { - const VF a21 = hn::Load(df, ar2 + ikc + NA); - ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22, - C23); - } - if constexpr (kRowsAC > 3) { - const VF a31 = hn::Load(df, ar3 + ikc + NA); - ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32, - C33); - } } } } @@ -569,10 +495,6 @@ class MMKernel { const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc); const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc); - // Should only get here if `A` is BF16, otherwise `DecompressA` would - // convert to BF16 and `A_view` points to that. - HWY_DASSERT(IsBF16()); - { const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc); ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02, @@ -594,104 +516,39 @@ class MMKernel { C33); } } else { // !HWY_NATIVE_DOT_BF16 - if constexpr (IsBF16()) { - // When both are BF16, it is better to load promote odd/even, because - // lane-crossing promotion for both might be bottlenecked on shuffles. 
- VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; - { - const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc); - const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc); - const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc); - const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc); - b0e = hn::PromoteEvenTo(df, b0); - b1e = hn::PromoteEvenTo(df, b1); - b2e = hn::PromoteEvenTo(df, b2); - b3e = hn::PromoteEvenTo(df, b3); - b0o = FastPromoteOddTo(df, b0); - b1o = FastPromoteOddTo(df, b1); - b2o = FastPromoteOddTo(df, b2); - b3o = FastPromoteOddTo(df, b3); - } - - // Two rows at a time so we have 8 separate dependency chains, - // sufficient for IPC=2 and 4-cycle latency. - { - const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc); - const VBF a1 = - kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0; - ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, - b3o, b3e, C00, C01, C02, C03, C10, C11, C12, - C13); - } - if constexpr (kRowsAC > 2) { - const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc); - const VBF a3 = - kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2; - ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, - b3o, b3e, C20, C21, C22, C23, C30, C31, C32, - C33); - } - } else { // IsF32(): promote half-B to F32, F32*F32. + // When both are BF16, it is better to load promote odd/even, because + // lane-crossing promotion for both might be bottlenecked on shuffles. 
+ VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; + { const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc); const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc); const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc); const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc); - const VF b00 = hn::PromoteLowerTo(df, b0); - const VF b10 = hn::PromoteLowerTo(df, b1); - const VF b20 = hn::PromoteLowerTo(df, b2); - const VF b30 = hn::PromoteLowerTo(df, b3); - const VF b01 = hn::PromoteUpperTo(df, b0); - const VF b11 = hn::PromoteUpperTo(df, b1); - const VF b21 = hn::PromoteUpperTo(df, b2); - const VF b31 = hn::PromoteUpperTo(df, b3); + b0e = hn::PromoteEvenTo(df, b0); + b1e = hn::PromoteEvenTo(df, b1); + b2e = hn::PromoteEvenTo(df, b2); + b3e = hn::PromoteEvenTo(df, b3); + b0o = FastPromoteOddTo(df, b0); + b1o = FastPromoteOddTo(df, b1); + b2o = FastPromoteOddTo(df, b2); + b3o = FastPromoteOddTo(df, b3); + } - const size_t remaining2 = remaining_kc <= NA ? 0 : remaining_kc - NA; - - { - const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc); - ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02, - C03); - } - if constexpr (kRowsAC > 1) { - const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc); - ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12, - C13); - } - - // C00 is ready again. On SKX, this interleaved unrolling is faster - // than consuming all `b*1` at the end of the loop. 
- { - const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2); - ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02, - C03); - } - if constexpr (kRowsAC > 1) { - const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2); - ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12, - C13); - } - - if constexpr (kRowsAC > 2) { - const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc); - ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22, - C23); - } - if constexpr (kRowsAC > 3) { - const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc); - ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32, - C33); - } - - if constexpr (kRowsAC > 2) { - const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2); - ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22, - C23); - } - if constexpr (kRowsAC > 3) { - const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2); - ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32, - C33); - } + // Two rows at a time so we have 8 separate dependency chains, + // sufficient for IPC=2 and 4-cycle latency. + { + const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc); + const VBF a1 = + kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0; + ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o, + b3e, C00, C01, C02, C03, C10, C11, C12, C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc); + const VBF a3 = + kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2; + ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o, + b3e, C20, C21, C22, C23, C30, C31, C32, C33); } } } // remaining_kc != 0 @@ -699,16 +556,12 @@ class MMKernel { // This is a substantial fraction (about 1/3) of the total time, but is // called frequently, so do not add a profiler zone. 
- if constexpr (hwy::IsSame()) { - MMStoreHorizontalSumsIntoC()( - df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args, C_rows); - } else { - static_assert(hwy::IsSame()); - MMStoreHorizontalSumsIntoC()( - df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, - C31, C32, C33, row_ac, col_c, args, C_rows); - } + MMStoreHorizontalSumsIntoC horz; + const hn::Full128 d4; + hn::Vec sum0, sum1, sum2, sum3; + horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, + C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3); + horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows); } }; @@ -717,15 +570,6 @@ class MMKernel { // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. // Its member variables avoid long argument lists in Do*(). class MMPerPackage { - // Decompression is only required for F32 A and native BF16 dot products. - // If A is already BF16, we can use a view. Padding is not required - // because `LoopKC` can handle non-vector multiples. `LoopKC` also contains - // a special case for F32 `A` and non-native BF16 dot products. - template - static constexpr bool WantDecompressA() { - return HWY_NATIVE_DOT_BF16 && IsF32(); - } - public: MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config, size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np) @@ -741,13 +585,6 @@ class MMPerPackage { inner_tasks_(config.InnerTasks()), line_bytes_(args.env->ctx.allocator.LineBytes()) {} - // The size of `A` that will actually be used, for purposes of choosing the - // autotuning candidates. Keep in sync with the `operator()` logic below. - template - static constexpr size_t ABytes() { - return WantDecompressA() ? sizeof(BF16) : sizeof(TA); - } - // B and maybe A are decompressed several call layers lower, but not all // member functions depend on TA/TB, so pass them as an argument instead of // templating the class. 
@@ -755,12 +592,16 @@ class MMPerPackage { HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy, const MatPtrT& A, const MatPtrT& B, RowPtrs C_rows) const { - if constexpr (WantDecompressA()) { + if constexpr (IsBF16()) { + // We can use a view, regardless of columns/padding, because `LoopKC` + // supports non-vector multiples. + DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows); + } else { + // Always decompress. To reduce code size/compile time, we no longer + // support a separate F32 kernel; most A are already BF16. const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); DecompressA(A, A_view); DispatchOrder(parallel_policy, A_view, B, C_rows); - } else { - DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows); } } @@ -937,7 +778,7 @@ class MMPerPackage { // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. - const auto loop_nc = [&](const StridedViewBF& B_storage_view, + const auto loop_nc = [&](const StridedViewBF B_storage_view, const IndexRange& range_mc, const IndexRange& range_kc, const IndexRange& range_nc, @@ -1080,7 +921,7 @@ class MMPerPackage { template HWY_INLINE StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, const IndexRange& range_kc, - const StridedViewBF& B_view) const { + const StridedViewBF B_view) const { const hn::ScalableTag dbf; HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); @@ -1229,7 +1070,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, if (HWY_UNLIKELY(!tuner.HasCandidates())) { // Ensure matrix dimensions match each other. HWY_ASSERT(K == B.Cols()); - HWY_ASSERT(M <= MMStorage::kMaxM); + HWY_ASSERT(M <= kMaxBatchSize); HWY_ASSERT(K <= MMStorage::kMaxK); HWY_ASSERT(N % kNR == 0); // Ensure A rows are vector-aligned. 
Neither `Stride` nor `IsPacked` are @@ -1241,9 +1082,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes())); } - tuner.SetCandidates( - MMCandidates(allocator, M, K, N, MMPerPackage::ABytes(), sizeof(TC), - kMaxMR, kNR, per_key.ranges_np, env.print_config)); + tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, + kNR, per_key.ranges_np, env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); diff --git a/ops/matmul.cc b/ops/matmul.cc index 711eac1..812fe99 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -64,21 +64,19 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, class GenerateCandidates { public: GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, - size_t sizeof_TA, size_t sizeof_TC, size_t max_mr, - size_t nr, const IndexRangePartition& ranges_np, - bool print_config) + size_t sizeof_TC, size_t max_mr, size_t nr, + const IndexRangePartition& ranges_np, bool print_config) : allocator_(allocator), M_(M), K_(K), N_(N), - sizeof_TA_(sizeof_TA), sizeof_TC_(sizeof_TC), max_mr_(max_mr), nr_(nr), // These influence kc/nc, but are also stored in `MMConfig` for // `RangesOf*`. Must be a vector multiple. The previous/next cache line // is likely still in L1, but we expect K > 1000 and might as well round - // up to the line size. Use BF16, not sizeof_TA, because B is BF16. + // up to the line size. Both A and B are BF16. kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))), nc_multiple_(allocator.StepBytes() / sizeof_TC), ranges_np_(ranges_np), @@ -176,8 +174,8 @@ class GenerateCandidates { // size. This results in an overestimate, and the loop below will propose // the next few smaller values for the autotuner to evaluate. 
const size_t bytes_ab = - allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream)); - const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16); + allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream)); + const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); kc_max = RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_); @@ -224,9 +222,9 @@ class GenerateCandidates { // packed B. We want `mc * kc` elements of A to fit in L2, alongside // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of // partial. - const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes(); + const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes(); size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc); - mc_max = HWY_MIN(mc_max, MMStorage::kMaxM); + mc_max = HWY_MIN(mc_max, kMaxBatchSize); HWY_DASSERT(mc_max != 0); mc_max = HWY_MIN(mc_max, M_); mc_max = hwy::RoundDownTo(mc_max, mr); @@ -340,7 +338,6 @@ class GenerateCandidates { const size_t M_; const size_t K_; const size_t N_; - const size_t sizeof_TA_; const size_t sizeof_TC_; const size_t max_mr_; @@ -358,12 +355,12 @@ class GenerateCandidates { // Facade to avoid exposing `GenerateCandidates` in the header. std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TA, - size_t sizeof_TC, size_t max_mr, size_t nr, + size_t K, size_t N, size_t sizeof_TC, + size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config) { - return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr, - nr, ranges_np, print_config)(); + return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr, + ranges_np, print_config)(); } // Returns the granularity of B rows for `RangesOfNP`. 
Aims to avoid remote @@ -409,7 +406,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) { char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); - row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // C + row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) { diff --git a/ops/matmul.h b/ops/matmul.h index 11262bc..70c7d20 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -330,10 +330,8 @@ using StridedViewD = StridedView; // Per-package storage for packed A. class MMStorage { public: - // Compile-time bounds on matrix dimensions to enable pre-allocating storage - // and reusing it across `MatMul` calls. The resulting allocations are 256 MiB - // per package and 512 MiB, respectively. - static constexpr size_t kMaxM = 4096; + // Compile-time bounds on matrix columns to enable pre-allocating storage + // and reusing it across `MatMul` calls. static constexpr size_t kMaxK = 64 * 1024; // Upper bound for per-worker B storage on the stack. Chosen such that one row // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. @@ -348,8 +346,10 @@ class MMStorage { MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) { Allocator& allocator = ctx.allocator; - pkg_A_[pkg_idx].reset(new MatStorageT( - "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd)); + // 0.5 GiB per package. + pkg_A_[pkg_idx].reset( + new MatStorageT("pkg_A", Extents2D(kMaxBatchSize, kMaxK), + allocator, MatPadding::kOdd)); if (allocator.ShouldBind()) { const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); @@ -367,7 +367,7 @@ class MMStorage { // faster than on-the-fly when native BF16 is available: it only happens once, // not per B tile row, and the cache footprint is smaller. 
StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { - HWY_DASSERT(extents.rows <= kMaxM); + HWY_DASSERT(extents.rows <= kMaxBatchSize); HWY_DASSERT(extents.cols <= kMaxK); return StridedViewBF(const_cast(pkg_A_[pkg_idx]->Row(0)), extents.cols, pkg_A_[pkg_idx]->Stride()); @@ -527,8 +527,8 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TA, - size_t sizeof_TC, size_t max_mr, size_t nr, + size_t K, size_t N, size_t sizeof_TC, + size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config); diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index dc6f559..665e337 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -120,9 +120,9 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); // Dot() uses double-precision summation. double tolerance = 20 * norm * eps_f32; - // If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to - // BF16, so add extra tolerance. - if (IsF32()) { + // If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the + // F32 to BF16, so add extra tolerance. + if (IsF32() || IsF32()) { tolerance += 2 * max_abs * eps_bf16; } diff --git a/util/basics.h b/util/basics.h index 30864b2..13d0362 100644 --- a/util/basics.h +++ b/util/basics.h @@ -33,7 +33,10 @@ namespace gcpp { // Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the // runtime `max_packages` does not exceed this. MatMul's outer per-package loop // is disabled if this is 1. -constexpr size_t kMaxPackages = 1; +HWY_INLINE_VAR constexpr size_t kMaxPackages = 1; + +// TODO: extend to 16k after updating non_eos. 
+HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 }; From 74ffe079c4398462434646e637a3448898b4f903 Mon Sep 17 00:00:00 2001 From: Marie White Date: Wed, 3 Sep 2025 09:35:13 -0700 Subject: [PATCH 23/65] Create separate MMStorage objects per cluster. PiperOrigin-RevId: 802588625 --- ops/matmul-inl.h | 3 ++- ops/matmul.cc | 14 +++++++++++++- ops/matmul.h | 3 ++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index a9685e2..152e6ce 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -599,7 +599,8 @@ class MMPerPackage { } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. - const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); + const StridedViewBF A_view = + args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents()); DecompressA(A, A_view); DispatchOrder(parallel_policy, A_view, B, C_rows); } diff --git a/ops/matmul.cc b/ops/matmul.cc index 812fe99..83fc036 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -402,7 +402,19 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages)); } -MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) { +MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { + // Create storage per cluster. This only applies to in-cluster parallelism. + // For nested and sequential parallelism, a single MMStorage is used. 
+ size_t num_packages = ctx.topology.NumPackages(); + size_t num_clusters = 0; + for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) { + num_clusters += ctx.topology.NumClusters(pkg_idx); + } + storage.reserve(num_clusters); + for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { + storage.push_back(MMStorage(ctx)); + } + char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); diff --git a/ops/matmul.h b/ops/matmul.h index 70c7d20..e76d37b 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -369,6 +369,7 @@ class MMStorage { StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxBatchSize); HWY_DASSERT(extents.cols <= kMaxK); + HWY_DASSERT(pkg_A_[pkg_idx] != nullptr); return StridedViewBF(const_cast(pkg_A_[pkg_idx]->Row(0)), extents.cols, pkg_A_[pkg_idx]->Stride()); } @@ -733,7 +734,7 @@ struct MatMulEnv { // Whether to print the best config immediately after autotuning finished. bool print_best = false; - MMStorage storage; + std::vector storage; MMKeys keys; std::vector per_key; From 7263ab844587e0c06e6a8fe277a96c7031eb589e Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 3 Sep 2025 21:44:39 -0700 Subject: [PATCH 24/65] MatMul simplification, threading strategy improvements remove MatMul f32 special case (smaller code), types: Add u32/u64 for use by Activations move renamed ParallelismStrategy to threading_context so can pass ctx ensure worker index is unique across clusters matmul.h: const member functions for renamed policy classes (easier to call) PiperOrigin-RevId: 802848086 --- compression/types.h | 15 +- gemma/activations.h | 36 +-- gemma/attention.cc | 18 +- gemma/gemma-inl.h | 10 +- gemma/gemma.cc | 8 +- ops/bench_matmul.cc | 10 +- ops/matmul-inl.h | 472 ++++++++++++++++++++------------------- ops/matmul.cc | 9 +- ops/matmul.h | 235 +++++++++---------- ops/ops-inl.h | 32 +-- util/basics.h | 2 + util/threading.h | 56 +---- util/threading_context.h | 89 ++++++++ 13 
files changed, 514 insertions(+), 478 deletions(-) diff --git a/compression/types.h b/compression/types.h index 667265a..661bc42 100644 --- a/compression/types.h +++ b/compression/types.h @@ -191,12 +191,13 @@ constexpr bool SupportsPointerArithmetic() { return !IsNuqStream(); } -// Tensor types for loading weights. -enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 }; +// Tensor types for loading weights. Not all of these are supported weight +// types, some are only used for `Activations`. +enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. -static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", - "sfp", "nuq", "f64"}; +static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", + "nuq", "f64", "u32", "u64"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = { @@ -206,6 +207,8 @@ static constexpr size_t kTypeBits[] = { 8 * sizeof(SfpStream), 4 /* NuqStream, actually 4.5 */, 8 * sizeof(double), + 8 * sizeof(uint32_t), + 8 * sizeof(uint64_t), }; static inline bool EnumValid(Type type) { @@ -226,6 +229,10 @@ Type TypeEnum() { return Type::kNUQ; } else if constexpr (hwy::IsSame()) { return Type::kF64; + } else if constexpr (hwy::IsSame()) { + return Type::kU32; + } else if constexpr (hwy::IsSame()) { + return Type::kU64; } else { HWY_DASSERT(false); return Type::kUnknown; diff --git a/gemma/activations.h b/gemma/activations.h index 14994d3..cd714ae 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -21,14 +21,14 @@ #include #include +#include #include -#include "gemma/configs.h" // ModelConfig -#include "ops/matmul.h" // MatMulEnv -#include "ops/ops.h" // CreateInvTimescale -#include "util/allocator.h" // Allocator -#include "util/basics.h" // BF16 -#include "util/mat.h" // MatStorageT +#include 
"gemma/configs.h" // ModelConfig +#include "ops/ops.h" // CreateInvTimescale +#include "util/basics.h" // BF16 +#include "util/mat.h" // MatStorageT +#include "util/threading_context.h" namespace gcpp { @@ -150,24 +150,28 @@ struct AttentionActivations { struct Activations { Activations(const ModelConfig& config, size_t batch_size, size_t seq_len, - const Allocator& allocator, + ThreadingContext& ctx, std::vector>& row_ptrs) : layer_config(config.layer_configs[0]), - x(MatFactory("x", batch_size, config.model_dim, allocator)), - x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)), - logits(MatFactory("logits", batch_size, config.vocab_size, allocator)), + x(MatFactory("x", batch_size, config.model_dim, ctx.allocator)), + x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)), + logits( + MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)), pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, - config.model_dim, allocator)), - C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)), - C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)), - ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)), + config.model_dim, ctx.allocator)), + C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, + ctx.allocator)), + C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, + ctx.allocator)), + ffw_out( + MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), - attention(config, layer_config, batch_size, seq_len, allocator, + attention(config, layer_config, batch_size, seq_len, ctx.allocator, row_ptrs), griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0, - allocator) { + ctx.allocator) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. 
diff --git a/gemma/attention.cc b/gemma/attention.cc index bd76329..21e5019 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -19,7 +19,6 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS -#include "util/threading_context.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -29,7 +28,7 @@ #include "gemma/gemma.h" #include "gemma/weights.h" #include "util/threading.h" -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "util/threading_context.h" #include "hwy/profiler.h" // Compiles this file for multiple architectures via "foreach_target.h", to @@ -234,8 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, { PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); // Full parallelism is helpful, kAcrossClusters is insufficient. - NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, - ctx.pools, func); + HierarchicalParallelFor( + num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools, + func); } } @@ -285,9 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Apply positional encodings for K. // Note that 2D parallelism is not worth the fork/join overhead because the // tasks are very lightweight. 
- env.ctx.pools.Pool(0).Run( - 0, kv_heads * num_interleaved, - [&](uint64_t task, size_t thread) HWY_ATTR { + ParallelFor( + ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, + /*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR { const size_t head = task % kv_heads; const size_t interleaved_idx = task / kv_heads; const size_t qi = div_qbatch.Remainder(interleaved_idx); @@ -308,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim, - env.ctx.profiler, thread); + env.ctx.profiler, worker); }); } PositionalEncodingQK(kv_f32, layer_idx, layer, activations, - env.ctx.profiler, thread, pos); + env.ctx.profiler, worker, pos); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 80ec0ee..cb7ae6a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -69,9 +69,9 @@ template void ActivationBatched( ActivationType activation, Mat& c1, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismType parallelism = ParallelismType::kAcrossClusters) { + ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { using T = typename Mat::T; - ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { // Cast to correct type so type deduction works. 
Activation(activation, c1.Row(task), @@ -84,16 +84,16 @@ template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismType parallelism = ParallelismType::kAcrossClusters) { + ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { - ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), ctx.profiler, worker); }); } else { // No multiplier - ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx, + ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), static_cast(nullptr), diff --git a/gemma/gemma.cc b/gemma/gemma.cc index fc1f238..a0949fe 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -574,7 +574,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const WeightsPtrs& weights, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) { Activations activations(config, runtime_config.prefill_tbatch_size, - kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs); + kv_cache.SeqLen(), env.ctx, env.row_ptrs); AllQueries all_queries(prompt, pos, prefix_end, hwy::Span(&kv_cache, 1)); @@ -592,7 +592,7 @@ void GenerateBatchT(const ModelConfig& config, const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); Activations activations(config, max_batch_size, - all_queries[0].kv_cache.SeqLen(), env.ctx.allocator, + all_queries[0].kv_cache.SeqLen(), env.ctx, env.row_ptrs); for (size_t start = 0; start < all_queries.NumQueries(); @@ -616,8 +616,8 @@ void GenerateImageTokensT(const ModelConfig& config, const size_t num_tokens = vit_config.max_seq_len; prefill_runtime_config.prefill_tbatch_size = 
num_tokens / (vit_config.pool_dim * vit_config.pool_dim); - Activations prefill_activations(vit_config, num_tokens, num_tokens, - env.ctx.allocator, env.row_ptrs); + Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx, + env.row_ptrs); // Weights are for the full PaliGemma model, not just the ViT part. PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, prefill_activations, env); diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 04d535e..1be2bed 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -111,8 +111,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Ensure usage conditions are set before autotuning. Both binding and // spinning may materially affect the choice of config. No harm in calling // BindB/C if there is a single package: they will be a no-op. - BindB(b_trans, sizeof(TC), env.parallel); - BindC(C, env.parallel); + BindB(env.ctx, b_trans, sizeof(TC)); + BindC(env.ctx, C); C.AllocateAndAttachRowPtrs(env.row_ptrs); Tristate use_spinning = Tristate::kDefault; @@ -160,10 +160,10 @@ void BenchAllMatMul() { ctx.pools.PinString()); MatMulEnv env(ctx); - for (size_t batch_size : {1, 4, 128, 512}) { + for (size_t batch_size : {128, 512}) { constexpr bool kAdd = false; - BenchMatMul(batch_size, 24576, 3072, kAdd, env); - BenchMatMul(batch_size, 3072, 24576, kAdd, env); + BenchMatMul(batch_size, 24576, 3072, kAdd, env); + BenchMatMul(batch_size, 3072, 24576, kAdd, env); } PROFILER_PRINT_RESULTS(); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 152e6ce..c915b14 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -565,46 +565,204 @@ class MMKernel { } }; -// Called on the main thread with the entire N range, or by each package with -// a static partition of N. This class contains several variants of the -// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. -// Its member variables avoid long argument lists in Do*(). 
-class MMPerPackage { - public: - MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config, - size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np) - : args_(args), - pkg_idx_(pkg_idx), - cluster_idx_(cluster_idx), - range_np_(range_np), - mr_(config.MR()), - ranges_mc_(config.RangesOfMC(A.rows)), - ranges_kc_(config.RangesOfKC(A.cols)), - ranges_nc_(config.RangesOfNC(range_np)), - order_(config.Order()), - inner_tasks_(config.InnerTasks()), - line_bytes_(args.env->ctx.allocator.LineBytes()) {} +// Miscellaneous stateless helper functions. +struct MMImpl { + // Returns existing entry for the given key or -1. + static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { + const hwy::Span all_keys = keys.Keys(); + // TODO: SIMD scan + for (size_t i = 0; i < all_keys.size(); ++i) { + if (all_keys[i] == key) return static_cast(i); + } + return -1; + } - // B and maybe A are decompressed several call layers lower, but not all - // member functions depend on TA/TB, so pass them as an argument instead of - // templating the class. - template - HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy, - const MatPtrT& A, const MatPtrT& B, - RowPtrs C_rows) const { + static size_t Worker(const MMArgs& args) { + return args.options.cluster_idx * + args.env->ctx.pools.MaxWorkersPerCluster(); + } + + template + static void DispatchParallelism(ParallelismStrategy parallelism, + const Func& func) { + switch (parallelism) { + case ParallelismStrategy::kHierarchical: + return func(MMParallelHierarchical()); + case ParallelismStrategy::kNone: + return func(MMParallelNone()); + case ParallelismStrategy::kWithinCluster: + return func(MMParallelWithinCluster()); + default: + HWY_UNREACHABLE; + } + } + + // Decompresses all `M x K` from `A` into padded BF16 `A_view`. 
+ static HWY_NOINLINE void DoDecompressA(const MatPtrT& A, + const StridedViewBF A_view, + MMParA par_a, const MMArgs& args) { + const IndexRange all_M(0, A.Rows()); + const IndexRange all_K(0, A.Cols()); + HWY_DASSERT(all_K.Num() == A_view.Cols()); + + const hn::ScalableTag dbf; + const size_t NBF = hn::Lanes(dbf); + + static const auto zone = args.env->ctx.profiler.AddZone("MM.DecompressA"); + + const auto do_range = + [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) + HWY_ATTR { + MMZone mm_zone; + mm_zone.MaybeEnter(worker, zone, args); + + const size_t col0 = range_K.begin(); + const size_t cols = range_K.Num(); + // Must be a vector multiple, or the last range before row + // padding, otherwise `DecompressAndZeroPad` overwrites neighbors. + HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); + for (size_t row_a : range_M) { + const PackedSpan from = + MakeSpan(A.Row(row_a) + col0, cols); + BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; + DecompressAndZeroPad(dbf, from, 0, to, cols); + // Verify that we zero-padded. + if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) { + HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); + } + } + } + }; + + switch (par_a) { + case MMParA::kNone: + do_range(all_M, all_K, MMImpl::Worker(args)); + break; + + case MMParA::kK1: + case MMParA::kK2: + case MMParA::kK4: { + const size_t inner_tasks = static_cast(par_a); + // At least one vector, otherwise DecompressAndZeroPad will add + // padding, which might overwrite neighboring tasks. Also a whole cache + // line to avoid false sharing. 
+ const size_t multiple_K = HWY_MAX(NBF, args.line_bytes / sizeof(BF16)); + + DispatchParallelism( + args.options.parallelism, [&](const auto& parallel) { + parallel.ForNP(args.env->ctx, all_K, multiple_K, inner_tasks, + args.options.cluster_idx, + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); + }); + }); + break; + } + case MMParA::kM: + DispatchParallelism( + args.options.parallelism, [&](const auto& parallel) { + parallel.ForRangeMC( + args.env->ctx, all_M, args.options.cluster_idx, + [&](size_t row_a, size_t worker) { + do_range(IndexRange(row_a, row_a + 1), all_K, worker); + }); + }); + break; + } + } + + // Autotuning wrapper for `DoDecompressA`. + static HWY_INLINE void DecompressA(const MatPtrT& A, + const StridedViewBF A_view, + const MMArgs& args) { + MMAutoTune& autotune = args.per_key->autotune_par_a[/*pkg_idx=*/0]; + + if (HWY_LIKELY(autotune.Best())) { + return DoDecompressA(A, A_view, *autotune.Best(), args); + } + + // First call: generate candidates. + if (HWY_UNLIKELY(!autotune.HasCandidates())) { + const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM; + std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4, + other}; + autotune.SetCandidates(candidates); + } + + const MMParA& par_a = autotune.NextConfig(); + const uint64_t t0 = hwy::timer::Start(); + DoDecompressA(A, A_view, par_a, args); + const uint64_t t1 = + args.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); + const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); + if (HWY_UNLIKELY(args.env->print_measurement && autotune.ShouldPrint())) { + fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), + static_cast(min_elapsed) / + hwy::platform::InvariantTicksPerSecond() * 1E6); + } + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
+ template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + + template + static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT& A, + const MMArgs& args) { if constexpr (IsBF16()) { // We can use a view, regardless of columns/padding, because `LoopKC` // supports non-vector multiples. - DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows); + return View(A, 0, 0, A.Cols()); } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. const StridedViewBF A_view = - args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents()); - DecompressA(A, A_view); - DispatchOrder(parallel_policy, A_view, B, C_rows); + args.env->storage[args.options.cluster_idx].A(/*pkg_idx=*/0, + A.Extents()); + DecompressA(A, A_view, args); + return A_view; } } +}; + +// Contains several variants of the outer M/N/K loops, and calls `A2C0` which +// loops over the inner KC and MC. Member variables avoid long argument lists. +class MMState { + public: + MMState(const Extents2D A, const MMArgs& args, const MMConfig& config) + : args_(args), + range_np_(args.per_key->ranges_np.Range(/*pkg_idx=*/0)), + mr_(config.MR()), + ranges_mc_(config.RangesOfMC(A.rows)), + ranges_kc_(config.RangesOfKC(A.cols)), + ranges_nc_(config.RangesOfNC(range_np_)), + order_(config.Order()), + inner_tasks_(config.InnerTasks()) { + HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); + } + + // Called from `MatMul` from two places: either with the next autotune config, + // or with the best config. 
+ template + HWY_NOINLINE void DispatchParallelism(const StridedViewBF A, + const MatPtrT& B, + RowPtrs C_rows) const { + /* Disabled due to unknown thread-safety issue: + static const auto zone = + args_.env->ctx.profiler.AddZone("MM.DispatchParallelism"); + PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone); + */ + + MMImpl::DispatchParallelism( + args_.options.parallelism, + [&](const auto& parallel) { DispatchOrder(parallel, A, B, C_rows); }); + } private: // Compute size of per-worker storage for `kNR` row ranges of B. Stack @@ -616,40 +774,32 @@ class MMPerPackage { // Granularity of `ForNP`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. size_t MultipleNP(size_t sizeof_TC) const { - return HWY_MAX(kNR, line_bytes_ / sizeof_TC); + return HWY_MAX(kNR, args_.line_bytes / sizeof_TC); } - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - template - static StridedView View(const MatPtrT& AB, size_t r, size_t c, - size_t cols) { - HWY_DASSERT(c < AB.Cols()); - HWY_DASSERT(cols <= AB.Cols() - c); - return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); - } - - // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`. - template - HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy, - const StridedView A, const MatPtrT& B, - RowPtrs C_rows) const { + // B is decompressed several call layers lower, but not all member functions + // depend on `TB`, so pass it as an argument instead of templating the class. 
+ template + HWY_NOINLINE void DispatchOrder(const ParallelT& parallel_policy, + const StridedViewBF A, const MatPtrT& B, + RowPtrs C_rows) const { switch (order_) { case MMOrder::kNT: - return DoNT(parallel_policy, A, B, C_rows); + return DoNT(parallel_policy, A, B, C_rows); case MMOrder::kNT_K: - return DoNT_K(parallel_policy, A, B, C_rows); + return DoNT_K(parallel_policy, A, B, C_rows); case MMOrder::kNT_MT: - return DoNT_MT(parallel_policy, A, B, C_rows); + return DoNT_MT(parallel_policy, A, B, C_rows); case MMOrder::kNT_MT_K: - return DoNT_MT_K(parallel_policy, A, B, C_rows); + return DoNT_MT_K(parallel_policy, A, B, C_rows); default: HWY_UNREACHABLE; } } // Single M and K ranges, parallel N. Fills all of C directly. - template - HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -657,14 +807,14 @@ class MMPerPackage { const IndexRange& range_M = ranges_mc_.Range(0); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); - const StridedView A_view = A.View(range_M.begin(), 0, K); + const StridedViewBF A_view = A.View(range_M.begin(), 0, K); const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); + Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); // Similar to `loop_nc` below, but here we hoisted `A_view`. - MMParallelPolicyT::ForNP( + parallel.ForNP( args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, - pkg_idx_, cluster_idx_, + args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -683,8 +833,8 @@ class MMPerPackage { } // Single M range, parallel N, sequential K. Sets C, then accumulates. 
- template - HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT_K(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -697,11 +847,11 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedView A_view = + const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc); const StridedViewBF B_storage_view( B_storage, kc, - Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_)); + Stride(MatPadding::kOdd, kc, sizeof(BF16), args_.line_bytes)); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -711,9 +861,9 @@ class MMPerPackage { } }; - MMParallelPolicyT::ForNP( + parallel.ForNP( args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, - pkg_idx_, cluster_idx_, + args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); @@ -733,26 +883,26 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, single K. // Fills `mc x nc` sections of C directly, in parallel. - template - HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); + Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, similar to `loop_nc` below // except for the profiler strings and `out_tag`. 
- MMParallelPolicyT::ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_, + parallel.ForRangesMC_NC( + args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); - const StridedView A_view = A.View(range_mc.begin(), 0, K); + const StridedViewBF A_view = A.View(range_mc.begin(), 0, K); HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS const StridedViewBF B_storage_view(B_storage, K, B_stride); @@ -768,14 +918,14 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, sequential K. // Fills `mc x nc` sections of `partial`, then `C`, in parallel. - template - HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView A, + template + HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); const size_t B_stride = - Stride(MatPadding::kOdd, kc_max, sizeof(BF16), line_bytes_); + Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. 
@@ -785,7 +935,7 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedView A_view = + const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); @@ -795,8 +945,8 @@ class MMPerPackage { C_rows); } }; // loop_nc - MMParallelPolicyT::ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_, + parallel.ForRangesMC_NC( + args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -816,106 +966,6 @@ class MMPerPackage { }); } - // Decompresses all `M x K` from `A` into padded BF16 `A_view`. - template - HWY_NOINLINE void DoDecompressA(const MatPtrT& A, - const StridedViewBF A_view, - MMParA par_a) const { - const IndexRange all_M(0, A.Rows()); - const IndexRange all_K(0, A.Cols()); - HWY_DASSERT(all_K.Num() == A_view.Cols()); - - const hn::ScalableTag dbf; - const size_t NBF = hn::Lanes(dbf); - - static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA"); - - const auto do_range = [&](const IndexRange& range_M, - const IndexRange& range_K, - size_t worker) HWY_ATTR { - MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, args_); - - const size_t col0 = range_K.begin(); - const size_t cols = range_K.Num(); - // Must be a vector multiple, or the last range before row padding, - // otherwise `DecompressAndZeroPad` overwrites neighbors. - HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); - for (size_t row_a : range_M) { - const PackedSpan from = - MakeSpan(A.Row(row_a) + col0, cols); - BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; - DecompressAndZeroPad(dbf, from, 0, to, cols); - // Verify that we zero-padded. 
- if constexpr (HWY_IS_DEBUG_BUILD) { - for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) { - HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); - } - } - } - }; - - switch (par_a) { - case MMParA::kNone: - do_range(all_M, all_K, /*worker=*/0); - break; - case MMParA::kK1: - case MMParA::kK2: - case MMParA::kK4: { - const size_t inner_tasks = static_cast(par_a); - // At least one vector, otherwise DecompressAndZeroPad will add - // padding, which might overwrite neighboring tasks. Also a whole cache - // line to avoid false sharing. - const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16)); - - MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks, - pkg_idx_, cluster_idx_, - [&](const IndexRange& range_K, size_t worker) { - do_range(all_M, range_K, worker); - }); - break; - } - case MMParA::kM: - MMParallelPolicyT::ForRangeMC( - args_.env->ctx, all_M, pkg_idx_, cluster_idx_, - [&](size_t row_a, size_t worker) { - do_range(IndexRange(row_a, row_a + 1), all_K, worker); - }); - break; - } - } - - // Autotuning wrapper for `DoDecompressA`. - template - HWY_INLINE void DecompressA(const MatPtrT& A, - const StridedViewBF A_view) const { - MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; - - if (HWY_LIKELY(autotune.Best())) { - return DoDecompressA(A, A_view, *autotune.Best()); - } - - // First call: generate candidates. - if (HWY_UNLIKELY(!autotune.HasCandidates())) { - const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM; - std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4, - other}; - autotune.SetCandidates(candidates); - } - - const MMParA& par_a = autotune.NextConfig(); - const uint64_t t0 = hwy::timer::Start(); - DoDecompressA(A, A_view, par_a); - const uint64_t t1 = - args_.env->have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); - const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); - if (HWY_UNLIKELY(args_.env->print_measurement && autotune.ShouldPrint())) { - fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), - static_cast(min_elapsed) / - hwy::platform::InvariantTicksPerSecond() * 1E6); - } - } - // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` // thanks to its large table lookups, and less so on other targets. @@ -928,7 +978,7 @@ class MMPerPackage { // Neither A nor B require padding because `LoopKC` handles remainders. if constexpr (hwy::IsSame()) { - return View(B, row_b, range_kc.begin(), range_kc.Num()); + return MMImpl::View(B, row_b, range_kc.begin(), range_kc.Num()); } const PackedSpan B_span = B.PaddedSpan(); @@ -951,8 +1001,6 @@ class MMPerPackage { } const MMArgs args_; // copy for locality - const size_t pkg_idx_; - const size_t cluster_idx_; // 0 for sequential and nested. const IndexRange range_np_; // From MMConfig: @@ -962,52 +1010,7 @@ class MMPerPackage { const IndexRangePartition ranges_nc_; const MMOrder order_; const size_t inner_tasks_; - const size_t line_bytes_; -}; // MMPerPackage - -// Stateless, wraps member functions. -struct MMImpl { - // Returns existing entry for the given key or -1. - static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { - const hwy::Span all_keys = keys.Keys(); - // TODO: SIMD scan - for (size_t i = 0; i < all_keys.size(); ++i) { - if (all_keys[i] == key) return static_cast(i); - } - return -1; - } - - // Called from `MatMul` from two places: either with the next autotune config, - // or with the best config. 
- template - static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const MatPtrT& B, - RowPtrs C_rows, const MMArgs& args, - const MMConfig& config, MMOptions options) { - PROFILER_ZONE("MM.DoMatMul"); - const size_t pkg_idx = 0; - HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); - const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - - switch (options.parallelism_type) { - case ParallelismType::kNested: - HWY_DASSERT(options.cluster_idx == 0); - MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, - range_np)(MMNestedParallelPolicy(), A, B, C_rows); - break; - case ParallelismType::kSequential: - MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, - range_np)(MMSequentialPolicy(), A, B, C_rows); - case ParallelismType::kWithinCluster: - MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx, - range_np)(MMClusterParallelPolicy(), A, B, C_rows); - break; - default: - HWY_ABORT("Parallelism type %s not implemented.", - static_cast(options.parallelism_type)); - break; - } - } -}; +}; // MMState // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. // @@ -1033,17 +1036,19 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C, MMOptions options = MMOptions()) { - RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]); + HWY_DASSERT(options.cluster_idx < env.row_ptrs.size()); + RowPtrs C_rows = + GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]); const Allocator& allocator = env.ctx.allocator; const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); - intptr_t index = MMImpl::IndexOfKey(key, env.keys); + intptr_t index = MMImpl::IndexOfKey(key, env.keys[options.cluster_idx]); // First time we see this shape/key. 
if (HWY_UNLIKELY(index < 0)) { - env.keys.Append(key, allocator); + env.keys[options.cluster_idx].Append(key, allocator); size_t max_packages = kMaxPackages; // For low-batch, multiple sockets only help if binding is enabled. @@ -1052,16 +1057,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, } // invalidates `MMAutoTune::Best()` - index = env.per_key.size(); - env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR)); + std::vector& stored_keys = env.per_key[options.cluster_idx]; + index = stored_keys.size(); + stored_keys.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR)); } - MMPerKey& per_key = env.per_key[index]; + MMPerKey& per_key = env.per_key[options.cluster_idx][index]; MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), - add); + add, options); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options); + const MMState state(A.Extents(), args, *tuner.Best()); + const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); + state.DispatchParallelism(A_view, B, C_rows); return &per_key; } @@ -1089,7 +1097,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, C_rows, args, cfg, options); + MMState state(A.Extents(), args, cfg); + const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); + state.DispatchParallelism(A_view, B, C_rows); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / diff --git a/ops/matmul.cc b/ops/matmul.cc index 83fc036..75b37a2 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -405,20 +405,15 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { // Create storage per cluster. 
This only applies to in-cluster parallelism. // For nested and sequential parallelism, a single MMStorage is used. - size_t num_packages = ctx.topology.NumPackages(); - size_t num_clusters = 0; - for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) { - num_clusters += ctx.topology.NumClusters(pkg_idx); - } + const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); storage.reserve(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { storage.push_back(MMStorage(ctx)); + row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); - - row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) { diff --git a/ops/matmul.h b/ops/matmul.h index e76d37b..16cb51c 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -58,147 +58,127 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, size_t N, size_t sizeof_TC, size_t nr); struct MMOptions { - ParallelismType parallelism_type = ParallelismType::kNested; - uint8_t cluster_idx = 0; + uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. + ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; }; -struct MMSequentialPolicy { - template - static void ForPkg(ThreadingContext& ctx, const size_t max_packages, - const Func& func) { - func(/*pkg_idx=*/0); - } +// Policy classes for parallelism, implementing some of `ParallelismStrategy`. 
+struct MMParallelNone { template - static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - size_t cluster_idx, const Func& func) { + void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + - cluster_idx * ctx.pools.MaxWorkersPerCluster(); - func(range_np, base_idx); + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + func(range_np, worker); } template - static void ForRangesMC_NC(ThreadingContext& ctx, - const IndexRangePartition& ranges_mc, - const IndexRangePartition& ranges_nc, - size_t pkg_idx, size_t cluster_idx, - const Func& func) { - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + - cluster_idx * ctx.pools.MaxWorkersPerCluster(); + void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, size_t cluster_idx, + const Func& func) const { + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) { const IndexRange range_mc = ranges_mc.Range(i); for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) { const IndexRange range_nc = ranges_nc.Range(j); - func(range_mc, range_nc, base_idx); + func(range_mc, range_nc, worker); } } } template - static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, size_t cluster_idx, const Func& func) { - const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() + - cluster_idx * ctx.pools.MaxWorkersPerCluster(); + void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t cluster_idx, const Func& func) const { + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); for (uint64_t row_a = range_mc.begin(); 
row_a < range_mc.end(); ++row_a) { - func(row_a, base_idx); + func(row_a, worker); } } }; -struct MMClusterParallelPolicy { +struct MMParallelWithinCluster { template - static void ForPkg(ThreadingContext& ctx, const size_t max_packages, - const Func& func) { - func(/*pkg_idx=*/0); - } - - template - static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - size_t cluster_idx, const Func& func) { + void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const IndexRangePartition worker_ranges = StaticPartition( range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); ParallelizeOneRange(worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { - func(worker_range, worker); + func(worker_range, base + worker); }); } template - static void ForRangesMC_NC(ThreadingContext& ctx, - const IndexRangePartition& ranges_mc, - const IndexRangePartition& ranges_nc, - size_t pkg_idx, size_t cluster_idx, - const Func& func) { + void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, size_t cluster_idx, + const Func& func) const { + const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); // Low-batch: avoid Divide/Remainder. 
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { ParallelizeOneRange(ranges_nc, cluster, - [&](const IndexRange& range_nc, size_t thread) { - func(ranges_mc.Range(0), range_nc, thread); + [&](const IndexRange& range_nc, size_t worker) { + func(ranges_mc.Range(0), range_nc, base + worker); }); } else { ParallelizeTwoRanges( ranges_mc, ranges_nc, cluster, [&](const IndexRange& range_mc, const IndexRange& range_nc, - size_t thread) { func(range_mc, range_nc, thread); }); + size_t worker) { func(range_mc, range_nc, base + worker); }); } } template - static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, size_t cluster_idx, const Func& func) { + void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t cluster_idx, const Func& func) const { + const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); - cluster.Run(range_mc.begin(), range_mc.end(), - [&](uint64_t row_a, size_t thread) { func(row_a, thread); }); + const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + + cluster.Run( + range_mc.begin(), range_mc.end(), + [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); }); } }; -struct MMNestedParallelPolicy { - template - static void ForPkg(ThreadingContext& ctx, const size_t max_packages, - const Func& func) { - if constexpr (kMaxPackages > 1) { - ctx.pools.AllPackages().Run( - 0, HWY_MIN(max_packages, ctx.pools.NumPackages()), - [&](uint64_t task, size_t pkg_idx) { - HWY_DASSERT(task == pkg_idx); - (void)task; - func(pkg_idx); - }); - } else { - func(/*pkg_idx=*/0); - } - } - +struct MMParallelHierarchical { // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. - // `cluster_idx` is not used here as all clusters within a package are used. 
template - static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, - size_t /*cluster_idx*/, const Func& func) { + void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, + HWY_MAYBE_UNUSED size_t caller_cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + HWY_DASSERT(caller_cluster_idx == 0); // Single cluster: parallel-for over static partition of `range_np`. + const size_t pkg_idx = 0; hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); const size_t num_clusters = all_clusters.NumWorkers(); if (num_clusters == 1) { - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0); + const size_t cluster_idx = 0; + hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const IndexRangePartition worker_ranges = StaticPartition( range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); return ParallelizeOneRange( worker_ranges, cluster, - [&](const IndexRange& worker_range, size_t thread) { - func(worker_range, pkg_base + thread); + [&](const IndexRange& worker_range, size_t worker) { + func(worker_range, worker); }); } @@ -210,28 +190,29 @@ struct MMNestedParallelPolicy { [&](const IndexRange& nx_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t cluster_base = - pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + cluster_idx * ctx.pools.MaxWorkersPerCluster(); // Parallel-for over sub-ranges of `cluster_range` within the cluster. 
const IndexRangePartition worker_ranges = StaticPartition( nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); ParallelizeOneRange( worker_ranges, cluster, - [&](const IndexRange& worker_range, size_t thread) { - func(worker_range, cluster_base + thread); + [&](const IndexRange& worker_range, size_t worker) { + func(worker_range, cluster_base + worker); }); }); } // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B // rows). Calls `func(range_mc, range_nc, worker)`. - // `cluster_idx` is not used here as all clusters within a package are used. template - static void ForRangesMC_NC(ThreadingContext& ctx, - const IndexRangePartition& ranges_mc, - const IndexRangePartition& ranges_nc, - size_t pkg_idx, size_t /*cluster_idx*/, - const Func& func) { - const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, + HWY_MAYBE_UNUSED size_t caller_cluster_idx, + const Func& func) const { + const size_t pkg_idx = 0; + HWY_DASSERT(caller_cluster_idx == 0); + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); // `all_clusters` is a pool with one worker per cluster in a package. const size_t num_clusters = all_clusters.NumWorkers(); @@ -243,16 +224,14 @@ struct MMNestedParallelPolicy { // Low-batch: avoid Divide/Remainder. 
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { return ParallelizeOneRange( - ranges_nc, cluster, [&](const IndexRange& range_nc, size_t thread) { - func(ranges_mc.Range(0), range_nc, pkg_base + thread); + ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) { + func(ranges_mc.Range(0), range_nc, worker); }); } else { return ParallelizeTwoRanges( ranges_mc, ranges_nc, cluster, [&](const IndexRange& range_mc, const IndexRange& range_nc, - size_t thread) { - func(range_mc, range_nc, pkg_base + thread); - }); + size_t worker) { func(range_mc, range_nc, worker); }); } } @@ -262,25 +241,23 @@ struct MMNestedParallelPolicy { ranges_nc, all_clusters, [&](const IndexRange range_nc, size_t cluster_idx) { const size_t cluster_base = - pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + cluster_idx * ctx.pools.MaxWorkersPerCluster(); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); ParallelizeOneRange(ranges_mc, cluster, - [&](const IndexRange& range_mc, size_t thread) { - func(range_mc, range_nc, cluster_base + thread); + [&](const IndexRange& range_mc, size_t worker) { + func(range_mc, range_nc, cluster_base + worker); }); }); } // Calls `func(row_a, worker)` in parallel. - // `cluster_idx` is not used here as all clusters within a package are used. 
template - static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, - size_t pkg_idx, size_t /*cluster_idx*/, - const Func& func) { - const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); - ctx.pools.Pool(pkg_idx).Run( - range_mc.begin(), range_mc.end(), - [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); }); + void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t caller_cluster_idx, const Func& func) const { + HierarchicalParallelFor(range_mc.Num(), ctx.pools, + [&](size_t task, size_t worker) { + func(range_mc.begin() + task, worker); + }); } }; @@ -340,27 +317,22 @@ class MMStorage { // Internally threaded; must not be called concurrently with the same // `ThreadingContext` (used via `parallel`). MMStorage(ThreadingContext& ctx) { - // Per-package allocation so each can decompress A into its own copy. - // Must be padded, see `DoDecompressA`. - // Default to nested parallel policy. - MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) { - Allocator& allocator = ctx.allocator; + Allocator& allocator = ctx.allocator; + const size_t pkg_idx = 0; - // 0.5 GiB per package. - pkg_A_[pkg_idx].reset( - new MatStorageT("pkg_A", Extents2D(kMaxBatchSize, kMaxK), - allocator, MatPadding::kOdd)); + // 0.5 GiB per package. Must be padded, see `DoDecompressA`. 
+ pkg_A_[pkg_idx].reset(new MatStorageT( + "pkg_A", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd)); - if (allocator.ShouldBind()) { - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * - pkg_A_[pkg_idx]->ElementBytes(); - bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); - if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { - HWY_WARN("Failed to bind memory for package %zu", pkg_idx); - } + if (allocator.ShouldBind()) { + const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); + size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * + pkg_A_[pkg_idx]->ElementBytes(); + bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); + if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { + HWY_WARN("Failed to bind memory for package %zu", pkg_idx); } - }); + } } // Returns per-package matrix view. Converting A=F32 to BF16 up-front is @@ -735,16 +707,18 @@ struct MatMulEnv { bool print_best = false; std::vector storage; - MMKeys keys; - std::vector per_key; + MMKeys keys[kMaxClusters]; + std::vector per_key[kMaxClusters]; // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV // writes to differing KV positions per query / output row. - // The first entry is sufficient for any C argument, but also potentially - // overwritten by each MatMul. Subsequent entries are precomputed for tensors - // and not overwritten. Per-tensor allocations make it likelier that asan - // detects bugs such as use after free, overrun, and dangling references. + // The first `num_clusters` entries are sufficient for any C argument, and + // must be indexed by `options.cluster_idx`. Note that they are potentially + // overwritten by each `MatMul`. Subsequent entries are for specific tensors + // and only written once by their allocator. 
A per-tensor allocation makes it + // likelier that asan detects bugs such as use after free, overrun, and + // dangling references. std::vector> row_ptrs; }; @@ -752,14 +726,21 @@ struct MatMulEnv { // Reduces register pressure compared to individual values/references. struct MMArgs { MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, - const float* HWY_RESTRICT add) - : env(&env), per_key(&per_key), scale(scale), add(add) {} + const float* HWY_RESTRICT add, MMOptions options) + : env(&env), + per_key(&per_key), + scale(scale), + add(add), + options(options), + line_bytes(env.ctx.allocator.LineBytes()) {} MatMulEnv* env; MMPerKey* per_key; double scale; const float* HWY_RESTRICT add; + MMOptions options; + size_t line_bytes; }; // Wrapper over hwy::Zone that is only enabled when autotuning finished. diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0173ee8..19a39aa 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -501,12 +501,12 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, HWY_DASSERT(activations.SameShape(out)); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor( - ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools, - cluster_idx, [&](uint64_t token_idx, size_t worker) { - RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), - out.Row(token_idx), activations.Cols(), ctx.profiler, worker); - }); + ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, + cluster_idx, [&](uint64_t token_idx, size_t worker) { + RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), + out.Row(token_idx), activations.Cols(), ctx.profiler, + worker); + }); }); } @@ -517,12 +517,12 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, HWY_DASSERT(weights.Cols() == inout.Cols()); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor( - ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx, - [&](uint64_t token_idx, size_t worker) { - 
RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx), - inout.Cols(), ctx.profiler, worker); - }); + ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, + [&](uint64_t token_idx, size_t worker) { + RMSNormInplace(weights_t->PackedScale1(), + inout.Row(token_idx), inout.Cols(), + ctx.profiler, worker); + }); }); } @@ -548,8 +548,8 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, ThreadingContext& ctx, size_t cluster_idx = 0) { HWY_DASSERT(out.SameShape(x)); - ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools, - cluster_idx, [&](uint64_t token_idx, size_t worker) { + ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, + [&](uint64_t token_idx, size_t worker) { AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx.profiler, worker); }); @@ -782,8 +782,8 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( const float cap, MatPtrT& x, const hwy::BitSet4096<>& non_eos, ThreadingContext& ctx, size_t cluster_idx = 0) { if (cap == 0.0f) return; - ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools, - cluster_idx, [&](uint64_t task, size_t worker) { + ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, + [&](uint64_t task, size_t worker) { if (non_eos.Get(task)) { LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker); diff --git a/util/basics.h b/util/basics.h index 13d0362..7cdc17c 100644 --- a/util/basics.h +++ b/util/basics.h @@ -35,6 +35,8 @@ namespace gcpp { // is disabled if this is 1. HWY_INLINE_VAR constexpr size_t kMaxPackages = 1; +HWY_INLINE_VAR constexpr size_t kMaxClusters = 128; // TODO: shrink + // TODO: extend to 16k after updating non_eos. 
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; diff --git a/util/threading.h b/util/threading.h index ef4f1c7..5dde114 100644 --- a/util/threading.h +++ b/util/threading.h @@ -326,7 +326,8 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1, // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes // over clusters of ONE package, then within each cluster. template -void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { +void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools, + const Func& func) { // Even if there are multiple packages, we only use the first. const size_t pkg_idx = 0; @@ -356,59 +357,6 @@ void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { }); } -// Which pool(s) to use for parallelizing: -enum class ParallelismType : uint8_t { - // None: single-threaded loop on the calling thread. - kSequential, - // One thread per cluster within the first package; or one per core if there - // is only one cluster. Use for few or lightweight tasks, or to maximize - // memory bandwidth availability. - kAcrossClusters, - // All cores within the cluster identified by `cluster_idx`. Use if already - // within a `kAcrossClusters` parallel-for, or if latency is more important - // than memory bandwidth. - kWithinCluster, - // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This - // utilizes all cores, but has higher fork-join overhead (two barriers); use - // if there are many or heavy tasks. - kNested, -}; - -// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the -// number/type of workers determined by `parallelism`. `cluster_idx` is only -// used if `parallelism == kWithinCluster`. -template -void ParallelFor(ParallelismType parallelism, size_t num_tasks, - NestedPools& pools, size_t cluster_idx, const Func& func) { - if (cluster_idx != 0) { - // If already running across clusters, must not use across-cluster modes. 
- HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters && - parallelism != ParallelismType::kNested); - } - - const size_t pkg_idx = 0; - switch (parallelism) { - case ParallelismType::kSequential: - for (size_t task = 0; task < num_tasks; ++task) { - func(task, /*worker=*/0); - } - return; - - case ParallelismType::kAcrossClusters: - return pools.Pool(pkg_idx).Run( - 0, num_tasks, - [&](uint64_t task, size_t worker) { func(task, worker); }); - - case ParallelismType::kWithinCluster: - return pools.Cluster(pkg_idx, cluster_idx) - .Run(0, num_tasks, - [&](uint64_t task, size_t worker) { func(task, worker); }); - - case ParallelismType::kNested: - return NestedParallelFor(num_tasks, pools, func); - } -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ diff --git a/util/threading_context.h b/util/threading_context.h index d4fdc17..847ce81 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -116,6 +116,95 @@ struct ThreadingContext { NestedPools pools; }; +// Describes the strategy for distributing parallel work across cores. +enum class ParallelismStrategy : uint8_t { + // Execute using a single-threaded loop on the calling thread. The `worker` + // index passed to the user's `Func` is unique across clusters. + kNone, + // One thread per cluster within the first package. The `worker` index passed + // to the user's `Func` is a `cluster_idx <= NumClusters()`. Some CPUs may + // only have a single cluster, hence `Func` should also contain a nested + // `ParallelFor` with `kWithinCluster`. + kAcrossClusters, + // All cores within the cluster identified by `cluster_idx`. The `worker` + // index passed to the user's `Func` is unique across clusters. Choose this + // strategy if already within a `ParallelFor` call with `kAcrossClusters`, + // or latency is more important than memory bandwidth. + kWithinCluster, + // Equivalent to `kAcrossClusters` if there are multiple clusters, otherwise + // `kWithinCluster`. 
Use for few or lightweight tasks (this only uses a + // single pool and barrier), or to maximize memory bandwidth availability. + kFlat, + // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This + // utilizes all cores, but has higher fork-join overhead (two barriers); use + // if there are many or heavy tasks. + kHierarchical, +}; + +// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the +// number/type of workers determined by `parallelism`. `cluster_idx` is for +// `parallelism == kWithinCluster`, and should be 0 if unknown. +template +void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, + ThreadingContext& ctx, size_t cluster_idx, const Func& func) { + HWY_DASSERT(ctx.topology.NumPackages() == 1); + const size_t pkg_idx = 0; + + HWY_DASSERT(cluster_idx < ctx.topology.NumClusters(pkg_idx)); + if (cluster_idx != 0) { + // If already running across clusters, only use within-cluster modes. + HWY_DASSERT(parallelism == ParallelismStrategy::kNone || + parallelism == ParallelismStrategy::kWithinCluster); + } + + switch (parallelism) { + case ParallelismStrategy::kNone: { + const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + for (size_t task = 0; task < num_tasks; ++task) { + func(task, worker); + } + return; + } + + case ParallelismStrategy::kAcrossClusters: + return ctx.pools.AllClusters(pkg_idx).Run( + 0, num_tasks, + [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); + + case ParallelismStrategy::kWithinCluster: { + // Ensure the worker argument is unique across clusters, because it is + // used for TLS indexing for example in profiler.h. 
+ const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + return ctx.pools.Cluster(pkg_idx, cluster_idx) + .Run(0, num_tasks, [&](uint64_t task, size_t worker) { + func(task, base + worker); + }); + } + + case ParallelismStrategy::kFlat: { + // Check for single cluster; if not, we must compute `cluster_base` for + // consistent and non-overlapping worker indices. + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); + const size_t num_clusters = all_clusters.NumWorkers(); + if (num_clusters == 1) { + return ctx.pools.Cluster(pkg_idx, cluster_idx) + .Run(0, num_tasks, + [&](uint64_t task, size_t worker) { func(task, worker); }); + } + + return ctx.pools.AllClusters(pkg_idx).Run( + 0, num_tasks, [&](uint64_t task, size_t cluster_idx) { + const size_t worker = + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + func(task, worker); + }); + } + + case ParallelismStrategy::kHierarchical: + return HierarchicalParallelFor(num_tasks, ctx.pools, func); + } +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_CONTEXT_H_ From 4be479972758b9064db24d1b629c23c5fc227f97 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 4 Sep 2025 03:32:35 -0700 Subject: [PATCH 25/65] Remove kMaxPackages and per-package-related code matmul: remove kMaxClusters, dynamic allocation PiperOrigin-RevId: 802950348 --- gemma/activations.h | 8 +-- ops/dot_test.cc | 1 - ops/matmul-inl.h | 116 +++++++++++++++++++--------------- ops/matmul.cc | 121 ++++++++++------------------------- ops/matmul.h | 133 ++++++++++++++++----------------------- ops/matmul_test.cc | 41 +++++------- util/allocator.cc | 5 +- util/allocator.h | 2 +- util/basics.h | 7 --- util/mat.h | 1 + util/threading_context.h | 12 ++-- 11 files changed, 177 insertions(+), 270 deletions(-) diff --git a/gemma/activations.h b/gemma/activations.h index cd714ae..63b3153 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -21,7 +21,6 @@ #include #include -#include #include #include 
"gemma/configs.h" // ModelConfig @@ -62,11 +61,12 @@ struct AttentionActivations { // Returns the scale value to use for the query in the attention computation. // Also called by ops_test. static inline float ChooseQueryScale(const ModelConfig& config) { + const LayerConfig& layer_config = config.layer_configs[0]; if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) - return 1.0f / sqrtf(static_cast(config.model_dim / - config.layer_configs[0].heads)); + return 1.0f / + sqrtf(static_cast(config.model_dim / layer_config.heads)); // QueryScaleType::SqrtKeySize - return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); + return 1.0f / sqrtf(static_cast(layer_config.qkv_dim)); } AttentionActivations( diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 2c0ae3a..8afb220 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1101,7 +1101,6 @@ void TestAllDot() { // Limit workers because we only support `kMaxWorkers`. ThreadingArgs threading_args; - threading_args.max_packages = 1; threading_args.max_clusters = 1; threading_args.max_lps = kMaxWorkers - 1; ThreadingContext ctx(threading_args); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index c915b14..8b9c011 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -570,10 +570,29 @@ struct MMImpl { // Returns existing entry for the given key or -1. 
static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { const hwy::Span all_keys = keys.Keys(); - // TODO: SIMD scan - for (size_t i = 0; i < all_keys.size(); ++i) { - if (all_keys[i] == key) return static_cast(i); + + const hn::ScalableTag d; + using V = hn::Vec; + const V broadcasted = Set(d, key); + const size_t N = hn::Lanes(d); + + size_t i = 0; + if (all_keys.size() >= N) { + for (; i <= all_keys.size() - N; i += N) { + const intptr_t pos = hn::FindFirstTrue( + d, hn::Eq(broadcasted, hn::LoadU(d, &all_keys[i]))); + if (pos >= 0) return static_cast(i) + pos; + } } + + const size_t remaining = all_keys.size() - i; + if (HWY_LIKELY(remaining > 0)) { + HWY_DASSERT(remaining < N); + const V v = hn::LoadN(d, &all_keys[i], remaining); + const intptr_t pos = hn::FindFirstTrue(d, hn::Eq(broadcasted, v)); + if (pos >= 0) return static_cast(i) + pos; + } + return -1; } @@ -582,6 +601,15 @@ struct MMImpl { args.env->ctx.pools.MaxWorkersPerCluster(); } + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
+ template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + template static void DispatchParallelism(ParallelismStrategy parallelism, const Func& func) { @@ -651,11 +679,11 @@ struct MMImpl { DispatchParallelism( args.options.parallelism, [&](const auto& parallel) { - parallel.ForNP(args.env->ctx, all_K, multiple_K, inner_tasks, - args.options.cluster_idx, - [&](const IndexRange& range_K, size_t worker) { - do_range(all_M, range_K, worker); - }); + parallel.ForN(args.env->ctx, all_K, multiple_K, inner_tasks, + args.options.cluster_idx, + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); + }); }); break; } @@ -676,7 +704,7 @@ struct MMImpl { static HWY_INLINE void DecompressA(const MatPtrT& A, const StridedViewBF A_view, const MMArgs& args) { - MMAutoTune& autotune = args.per_key->autotune_par_a[/*pkg_idx=*/0]; + MMAutoTune& autotune = args.per_key->autotune_par_a; if (HWY_LIKELY(autotune.Best())) { return DoDecompressA(A, A_view, *autotune.Best(), args); @@ -703,15 +731,6 @@ struct MMImpl { } } - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - template - static StridedView View(const MatPtrT& AB, size_t r, size_t c, - size_t cols) { - HWY_DASSERT(c < AB.Cols()); - HWY_DASSERT(cols <= AB.Cols() - c); - return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); - } - template static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT& A, const MMArgs& args) { @@ -723,8 +742,7 @@ struct MMImpl { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. 
const StridedViewBF A_view = - args.env->storage[args.options.cluster_idx].A(/*pkg_idx=*/0, - A.Extents()); + args.env->storage[args.options.cluster_idx].A(A.Extents()); DecompressA(A, A_view, args); return A_view; } @@ -735,17 +753,16 @@ struct MMImpl { // loops over the inner KC and MC. Member variables avoid long argument lists. class MMState { public: - MMState(const Extents2D A, const MMArgs& args, const MMConfig& config) + MMState(const Extents2D A, const size_t B_rows, const MMArgs& args, + const MMConfig& config) : args_(args), - range_np_(args.per_key->ranges_np.Range(/*pkg_idx=*/0)), + range_n_(0, B_rows), mr_(config.MR()), ranges_mc_(config.RangesOfMC(A.rows)), ranges_kc_(config.RangesOfKC(A.cols)), - ranges_nc_(config.RangesOfNC(range_np_)), + ranges_nc_(config.RangesOfNC(B_rows)), order_(config.Order()), - inner_tasks_(config.InnerTasks()) { - HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); - } + inner_tasks_(config.InnerTasks()) {} // Called from `MatMul` from two places: either with the next autotune config, // or with the best config. @@ -768,12 +785,12 @@ class MMState { // Compute size of per-worker storage for `kNR` row ranges of B. Stack // allocation avoids passing a worker index. static constexpr size_t B_stride_max_ = - MMStorage::kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16); + kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16); static constexpr size_t B_storage_max_ = kNR * B_stride_max_; - // Granularity of `ForNP`. B rows produce C columns, so we + // Granularity of `ForN`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. - size_t MultipleNP(size_t sizeof_TC) const { + size_t MultipleN(size_t sizeof_TC) const { return HWY_MAX(kNR, args_.line_bytes / sizeof_TC); } @@ -812,8 +829,8 @@ class MMState { Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); // Similar to `loop_nc` below, but here we hoisted `A_view`. 
- parallel.ForNP( - args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, + parallel.ForN( + args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_, args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -861,8 +878,8 @@ class MMState { } }; - parallel.ForNP( - args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, + parallel.ForN( + args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_, args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -881,7 +898,7 @@ class MMState { }); } - // Parallel loops over mc/nc blocks of M/range_np, single K. + // Parallel loops over mc/nc blocks of M/range_n, single K. // Fills `mc x nc` sections of C directly, in parallel. template HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A, @@ -923,7 +940,7 @@ class MMState { const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); const size_t kc_max = ranges_kc_.TaskSize(); - HWY_DASSERT(kc_max <= MMStorage::kMaxKC); + HWY_DASSERT(kc_max <= kMaxKC); const size_t B_stride = Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, for when the M/N loops are @@ -1002,7 +1019,7 @@ class MMState { const MMArgs args_; // copy for locality - const IndexRange range_np_; + const IndexRange range_n_; // From MMConfig: const size_t mr_; const IndexRangePartition ranges_mc_; @@ -1036,38 +1053,33 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C, MMOptions options = MMOptions()) { + const Allocator& allocator = env.ctx.allocator; HWY_DASSERT(options.cluster_idx < env.row_ptrs.size()); + MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx]; RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]); - const Allocator& allocator = 
env.ctx.allocator; const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); - intptr_t index = MMImpl::IndexOfKey(key, env.keys[options.cluster_idx]); + intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys); // First time we see this shape/key. if (HWY_UNLIKELY(index < 0)) { - env.keys[options.cluster_idx].Append(key, allocator); - - size_t max_packages = kMaxPackages; - // For low-batch, multiple sockets only help if binding is enabled. - if (!allocator.ShouldBind() && M <= 4) { - max_packages = 1; - } + per_cluster.keys.Append(key, allocator); // invalidates `MMAutoTune::Best()` - std::vector& stored_keys = env.per_key[options.cluster_idx]; - index = stored_keys.size(); - stored_keys.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR)); + std::vector& per_keys = per_cluster.per_key; + index = per_keys.size(); + per_keys.push_back(MMPerKey()); } - MMPerKey& per_key = env.per_key[options.cluster_idx][index]; + MMPerKey& per_key = per_cluster.per_key[index]; MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add, options); if (HWY_LIKELY(tuner.Best())) { - const MMState state(A.Extents(), args, *tuner.Best()); + const MMState state(A.Extents(), B.Rows(), args, *tuner.Best()); const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); state.DispatchParallelism(A_view, B, C_rows); return &per_key; @@ -1092,12 +1104,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, } tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, - kNR, per_key.ranges_np, env.print_config)); + kNR, env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMState state(A.Extents(), args, cfg); + MMState state(A.Extents(), B.Rows(), args, cfg); const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); state.DispatchParallelism(A_view, B, C_rows); const 
uint64_t t1 = diff --git a/ops/matmul.cc b/ops/matmul.cc index 75b37a2..35887a5 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -21,7 +21,6 @@ #include #include -#include #include #include "util/allocator.h" @@ -65,7 +64,7 @@ class GenerateCandidates { public: GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, size_t sizeof_TC, size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, bool print_config) + bool print_config) : allocator_(allocator), M_(M), K_(K), @@ -79,7 +78,6 @@ class GenerateCandidates { // up to the line size. Both A and B are BF16. kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))), nc_multiple_(allocator.StepBytes() / sizeof_TC), - ranges_np_(ranges_np), print_config_(print_config) {} std::vector operator()() const { @@ -177,8 +175,7 @@ class GenerateCandidates { allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream)); const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); - kc_max = - RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_); + kc_max = RoundDownWithFloor(HWY_MIN(kc_max, kMaxKC), kc_multiple_); kc_max = HWY_MIN(kc_max, K_); SizeVec all_kc(1, kc_max); @@ -258,32 +255,30 @@ class GenerateCandidates { // The number of (possibly L3 resident) B rows per `NT_MT` task. SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { - const size_t np_max = ranges_np_.TaskSize(); - size_t nc_max = np_max; + size_t nc_max = N_; // Only if there will be reuse of B: choose the largest `nc_max` (C cols) // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. // Otherwise, leave it unbounded. 
if (M_ > mr) { const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); - nc_max = - HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), np_max); + nc_max = HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), N_); } HWY_DASSERT(nc_max != 0); nc_max = RoundDownWithFloor(nc_max, nc_multiple_); // If there are going to be multiple ranges, anything more than half would // be imbalanced and suboptimal. - if (nc_max < np_max && nc_max >= np_max / 2) { - nc_max = RoundDownWithFloor(np_max / 2, nc_multiple_); + if (nc_max < N_ && nc_max >= N_ / 2) { + nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_); } // Non-block calls ForNP, which ignores `range_nc` and uses `range_np`. - if (!IsBlock(order)) return SizeVec(1, np_max); + if (!IsBlock(order)) return SizeVec(1, N_); SizeVec all_nc(1, nc_max); // Avoid proposing nc > N. - if (np_max > nc_multiple_) { + if (N_ > nc_multiple_) { // Large L3, but its behavior and characteristics varies across platforms, // hence autotune a wider range of nc than the other dimensions. size_t reps = 10; @@ -292,8 +287,7 @@ class GenerateCandidates { size_t prev = nc_max; for (size_t rep = 0; rep < reps; ++rep) { - const size_t div = - PrevDivisor(nc_multiple_, prev, np_max, nc_multiple_); + const size_t div = PrevDivisor(nc_multiple_, prev, N_, nc_multiple_); prev = div ? 
div : RoundDownWithFloor(prev / 2, nc_multiple_); all_nc.push_back(prev); if (prev == nc_multiple_) break; @@ -346,8 +340,6 @@ class GenerateCandidates { const size_t kc_multiple_; const size_t nc_multiple_; - IndexRangePartition ranges_np_; - const bool print_config_; }; @@ -357,58 +349,19 @@ class GenerateCandidates { std::vector MMCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, size_t sizeof_TC, size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, bool print_config) { return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr, - ranges_np, print_config)(); -} - -// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote -// memory accesses or false sharing, unless there are insufficient per-package -// rows for that. -static size_t NPMultiple(const Allocator& allocator, size_t N, - size_t sizeof_TC, size_t nr, size_t num_packages) { - size_t np_multiple = allocator.BasePageBytes() / sizeof_TC; - // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For - // `N` < 4096, this can cause significant load imbalance. If split unevenly, - // choose a smaller multiple. - if (N % (np_multiple * num_packages)) { - const size_t min_multiple = allocator.LineBytes() / sizeof_TC; - np_multiple = - PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple); - if (HWY_UNLIKELY(np_multiple == 0)) { - np_multiple = min_multiple; - } - // This happens in tests with small N, hence do not assert. 
- if (N % (np_multiple * num_packages) && N >= 128) { - static std::atomic_flag warned = ATOMIC_FLAG_INIT; - if (!warned.test_and_set()) { - HWY_WARN( - "NPMultiple: N=%zu still not divisible by np_multiple=%zu * " - "num_packages=%zu\n", - N, np_multiple, num_packages); - } - np_multiple = nr; - } - } - return np_multiple; -} - -IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, - size_t N, size_t sizeof_TC, size_t nr) { - const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages()); - return StaticPartition( - IndexRange(0, N), num_packages, - NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages)); + print_config)(); } MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { // Create storage per cluster. This only applies to in-cluster parallelism. // For nested and sequential parallelism, a single MMStorage is used. const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); + per_cluster.resize(num_clusters); storage.reserve(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { - storage.push_back(MMStorage(ctx)); + storage.push_back(MMStorage(ctx.allocator)); row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } @@ -423,20 +376,15 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) { PROFILER_ZONE("Startup.BindB"); - const IndexRangePartition ranges_np = - MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR); - for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { - const IndexRange& rows_b = ranges_np.Range(pkg_idx); - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - uintptr_t begin = reinterpret_cast(B.RowBytes(rows_b.begin())); - uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes(); - // B row padding is less than the page size, so only bind the subset that - // is page-aligned. 
- begin = hwy::RoundUpTo(begin, allocator.BasePageBytes()); - end = hwy::RoundDownTo(end, allocator.BasePageBytes()); - if (HWY_LIKELY(begin != end)) { - allocator.BindMemory(reinterpret_cast(begin), end - begin, node); - } + const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node(); + uintptr_t begin = reinterpret_cast(B.RowBytes(0)); + uintptr_t end = begin + B.Rows() * B.Stride() * B.ElementBytes(); + // B row padding is less than the page size, so only bind the subset that + // is page-aligned. + begin = hwy::RoundUpTo(begin, allocator.BasePageBytes()); + end = hwy::RoundDownTo(end, allocator.BasePageBytes()); + if (HWY_LIKELY(begin != end)) { + allocator.BindMemory(reinterpret_cast(begin), end - begin, node); } } @@ -447,25 +395,20 @@ void BindC(ThreadingContext& ctx, MatPtr& C) { PROFILER_ZONE("Startup.BindC"); - const IndexRangePartition ranges_np = - MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR); - bool ok = true; - for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { - const IndexRange& cols_c = ranges_np.Range(pkg_idx); - // `BindMemory` requires page alignment. These are in bytes. - const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(), - allocator.BasePageBytes()); - const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), - allocator.BasePageBytes()); + const IndexRange cols_c(0, C.Cols()); + // `BindMemory` requires page alignment. These are in bytes. 
+ const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(), + allocator.BasePageBytes()); + const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), + allocator.BasePageBytes()); - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - for (size_t im = 0; im < C.Rows(); ++im) { - ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node); - } + const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node(); + bool ok = true; + for (size_t im = 0; im < C.Rows(); ++im) { + ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node); } if (HWY_UNLIKELY(!ok)) { - HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", C.Rows(), C.Cols(), - ranges_np.NumTasks()); + HWY_WARN("Failed to bind C (%zux%zu).", C.Rows(), C.Cols()); } } diff --git a/ops/matmul.h b/ops/matmul.h index 16cb51c..8c7d724 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -45,17 +45,18 @@ namespace gcpp { // This and `mr` are limited by the number of registers, which is generally // 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in // `MMStoreHorizontalSumsIntoC`. We ensure `C.Cols() % kNR == 0`. -constexpr size_t kNR = 4; +HWY_INLINE_VAR constexpr size_t kNR = 4; // Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because // we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile. // In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions // that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`, // or less on ISAs with fewer registers, or for the last few rows of A. -static constexpr size_t kMaxMR = 4; +HWY_INLINE_VAR constexpr size_t kMaxMR = 4; -IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, - size_t N, size_t sizeof_TC, size_t nr); +// Upper bound for per-worker B storage on the stack. Chosen such that one row +// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. 
+HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; struct MMOptions { uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. @@ -66,12 +67,12 @@ struct MMOptions { struct MMParallelNone { template - void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, - const Func& func) const { + void ForN(ThreadingContext& ctx, const IndexRange& range_n, + size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); - func(range_np, worker); + func(range_n, worker); } template @@ -102,9 +103,8 @@ struct MMParallelNone { struct MMParallelWithinCluster { template - void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, - const Func& func) const { + void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, + size_t inner_tasks, size_t cluster_idx, const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); const size_t pkg_idx = 0; @@ -112,7 +112,7 @@ struct MMParallelWithinCluster { const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); const IndexRangePartition worker_ranges = StaticPartition( - range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); + range_n, cluster.NumWorkers() * inner_tasks, n_multiple); ParallelizeOneRange(worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, base + worker); @@ -156,17 +156,16 @@ struct MMParallelWithinCluster { }; struct MMParallelHierarchical { - // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is + // Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. 
template - void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, - HWY_MAYBE_UNUSED size_t caller_cluster_idx, - const Func& func) const { + void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, + size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); HWY_DASSERT(caller_cluster_idx == 0); - // Single cluster: parallel-for over static partition of `range_np`. + // Single cluster: parallel-for over static partition of `range_n`. const size_t pkg_idx = 0; hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); const size_t num_clusters = all_clusters.NumWorkers(); @@ -174,7 +173,7 @@ struct MMParallelHierarchical { const size_t cluster_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const IndexRangePartition worker_ranges = StaticPartition( - range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); + range_n, cluster.NumWorkers() * inner_tasks, n_multiple); return ParallelizeOneRange( worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { @@ -182,18 +181,18 @@ struct MMParallelHierarchical { }); } - // Assign each cluster a sub-range of `range_np` (typically hundreds). - const IndexRangePartition nx_ranges = - StaticPartition(range_np, num_clusters, nx_multiple); + // Assign each cluster a sub-range of `range_n` (typically hundreds). + const IndexRangePartition n_ranges = + StaticPartition(range_n, num_clusters, n_multiple); ParallelizeOneRange( - nx_ranges, all_clusters, - [&](const IndexRange& nx_range, const size_t cluster_idx) { + n_ranges, all_clusters, + [&](const IndexRange& n_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t cluster_base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); // Parallel-for over sub-ranges of `cluster_range` within the cluster. 
const IndexRangePartition worker_ranges = StaticPartition( - nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); + n_range, cluster.NumWorkers() * inner_tasks, n_multiple); ParallelizeOneRange( worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { @@ -304,50 +303,29 @@ class StridedView { using StridedViewBF = StridedView; using StridedViewD = StridedView; -// Per-package storage for packed A. class MMStorage { public: // Compile-time bounds on matrix columns to enable pre-allocating storage // and reusing it across `MatMul` calls. static constexpr size_t kMaxK = 64 * 1024; - // Upper bound for per-worker B storage on the stack. Chosen such that one row - // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. - static constexpr size_t kMaxKC = 8 * 1024; - // Internally threaded; must not be called concurrently with the same - // `ThreadingContext` (used via `parallel`). - MMStorage(ThreadingContext& ctx) { - Allocator& allocator = ctx.allocator; - const size_t pkg_idx = 0; + MMStorage(const Allocator& allocator) + // 0.5 GiB. Must be padded, see `DoDecompressA`. + : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator, + MatPadding::kOdd) {} - // 0.5 GiB per package. Must be padded, see `DoDecompressA`. - pkg_A_[pkg_idx].reset(new MatStorageT( - "pkg_A", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd)); - - if (allocator.ShouldBind()) { - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * - pkg_A_[pkg_idx]->ElementBytes(); - bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); - if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { - HWY_WARN("Failed to bind memory for package %zu", pkg_idx); - } - } - } - - // Returns per-package matrix view. 
Converting A=F32 to BF16 up-front is - // faster than on-the-fly when native BF16 is available: it only happens once, - // not per B tile row, and the cache footprint is smaller. - StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { + // Returns matrix view. Converting A=F32 to BF16 up-front is faster than + // on-the-fly when native BF16 is available: it only happens once, not per B + // tile row, and the cache footprint is smaller. + StridedViewBF A(const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxBatchSize); HWY_DASSERT(extents.cols <= kMaxK); - HWY_DASSERT(pkg_A_[pkg_idx] != nullptr); - return StridedViewBF(const_cast(pkg_A_[pkg_idx]->Row(0)), - extents.cols, pkg_A_[pkg_idx]->Stride()); + return StridedViewBF(const_cast(A_.Row(0)), extents.cols, + A_.Stride()); } private: - std::unique_ptr> pkg_A_[kMaxPackages]; + MatStorageT A_; }; //------------------------------------------------------------------------------ @@ -433,7 +411,7 @@ class MMConfig { MMConfig() = default; // for std::vector // `mr` is the number of A rows per call to `MMKernel::LoopKC`. // `MMOrder` is how to parallelize the outer loops. - // `inner_tasks` chooses the within-cluster task granularity in `ForNP`. + // `inner_tasks` chooses the within-cluster task granularity in `ForN`. 
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc, size_t kc_multiple, size_t nc_multiple, MMOrder order, int inner_tasks) @@ -470,8 +448,8 @@ class MMConfig { IndexRangePartition RangesOfKC(size_t K) const { return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_); } - IndexRangePartition RangesOfNC(IndexRange range_np) const { - return MaxSizePartition(range_np, nc_, nc_multiple_); + IndexRangePartition RangesOfNC(size_t N) const { + return MaxSizePartition(IndexRange(0, N), nc_, nc_multiple_); } MMOrder Order() const { return order_; } @@ -501,9 +479,7 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing std::vector MMCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, - bool print_config); + size_t max_mr, size_t nr, bool print_config); // State machine for choosing the best `TConfig`, which is `MMConfig` for the // main MatMul autotuner. @@ -609,7 +585,7 @@ class MMAutoTune { // `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range, // but choosing the same config for a larger M can result in multiple MC ranges. // Thus M less than this must have unique keys/configs. -static constexpr size_t kMaxTilesM = 8; +HWY_INLINE_VAR constexpr size_t kMaxTilesM = 8; // Map of previously seen dimensions to index via linear search. class MMKeys { @@ -636,8 +612,8 @@ class MMKeys { return key; } - // We leave the search to callers so they can use dynamic-dispatched SIMD, - // which is not possible in this header. + // We leave the search to callers so they can use per-target SIMD, which is + // not possible in this header. hwy::Span Keys() const { return hwy::Span(keys_.get(), num_unique_); } @@ -674,26 +650,17 @@ class MMKeys { // Per-MatMul-shape state. 
struct MMPerKey { - MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N, - size_t sizeof_TC, size_t nr) - : ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) { - HWY_DASSERT(ranges_np.NumTasks() <= max_packages); - } - - // Only profile if enabled and the main autotuner finished (the par_a - // autotuner is per-package and we want to avoid synchronization). + // Only profile if enabled and the main autotuner finished. `autotune_par_a` + // might not be active if inputs are all BF16. bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); } - const IndexRangePartition ranges_np; MMAutoTune autotune; - MMAutoTune autotune_par_a[kMaxPackages]; + MMAutoTune autotune_par_a; }; // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive // `MatMulEnv`. struct MatMulEnv { - // Internally threaded; must not be called concurrently with the same - // `ThreadingContext`. explicit MatMulEnv(ThreadingContext& ctx); ThreadingContext& ctx; @@ -707,8 +674,13 @@ struct MatMulEnv { bool print_best = false; std::vector storage; - MMKeys keys[kMaxClusters]; - std::vector per_key[kMaxClusters]; + + struct PerCluster { + MMKeys keys; + std::vector per_key; + HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT]; // prevent false sharing + }; + std::vector per_cluster; // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. 
// Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV @@ -739,6 +711,7 @@ struct MMArgs { double scale; const float* HWY_RESTRICT add; + MMOptions options; size_t line_bytes; }; diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 665e337..373f8aa 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -275,37 +275,28 @@ void TestTiny() { if (first_target == 0) first_target = HWY_TARGET; if (HWY_TARGET != first_target) return; - for (size_t max_packages : {1, 2}) { - ThreadingArgs threading_args; - threading_args.bind = Tristate::kTrue; - threading_args.max_packages = max_packages; - ThreadingContext ctx(threading_args); - MatMulEnv env(ctx); - NestedPools& pools = env.ctx.pools; + ThreadingArgs threading_args; + threading_args.bind = Tristate::kTrue; + ThreadingContext ctx(threading_args); + MatMulEnv env(ctx); + NestedPools& pools = env.ctx.pools; - if constexpr (GEMMA_DISABLE_TOPOLOGY || kMaxPackages == 1) { - if (max_packages == 2) break; // we only have one package - } else { - // If less than the limit, we have already tested all num_packages. 
- if (env.ctx.topology.FullTopology().packages.size() < max_packages) break; - } - fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages, - env.ctx.topology.TopologyString(), pools.PinString()); + fprintf(stderr, "TestTiny: %s %s\n", env.ctx.topology.TopologyString(), + pools.PinString()); - pools.MaybeStartSpinning(threading_args.spin); + pools.MaybeStartSpinning(threading_args.spin); - for (size_t M = 1; M <= 12; ++M) { - for (size_t K = 1; K <= 64; K *= 2) { - for (size_t N = 4; N <= 64; N += max_packages * 4) { - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - } + for (size_t M = 1; M <= 12; ++M) { + for (size_t K = 1; K <= 64; K *= 2) { + for (size_t N = 4; N <= 64; N += 4) { + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); } } - pools.MaybeStopSpinning(threading_args.spin); } + pools.MaybeStopSpinning(threading_args.spin); } void TestAllMatMul() { diff --git a/util/allocator.cc b/util/allocator.cc index df2575e..f8bfdd5 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -160,11 +160,10 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { // - supported by the OS (currently Linux only), // - the page size is known and 'reasonably small', preferably less than // a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB. - // - we successfully detected topology and there are multiple nodes; - // - there are multiple packages, because we shard by package_idx. + // - we successfully detected topology and there are multiple nodes. 
if constexpr (GEMMA_BIND) { if ((base_page_bytes_ != 0 && base_page_bytes_ <= 16 * 1024) && - topology.NumNodes() > 1 && topology.NumPackages() > 1) { + topology.NumNodes() > 1) { if (enable_bind) { // Ensure pages meet the alignment requirements of `AllocBytes`. HWY_ASSERT(base_page_bytes_ >= quantum_bytes_); diff --git a/util/allocator.h b/util/allocator.h index bf904c5..42e261c 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -149,7 +149,7 @@ class Allocator { } // Returns whether `BindMemory` can/should be called, i.e. we have page-level - // control over memory placement and multiple packages and NUMA nodes. + // control over memory placement and multiple NUMA nodes. bool ShouldBind() const { return should_bind_; } // Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is diff --git a/util/basics.h b/util/basics.h index 7cdc17c..c8858e5 100644 --- a/util/basics.h +++ b/util/basics.h @@ -30,13 +30,6 @@ namespace gcpp { -// Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the -// runtime `max_packages` does not exceed this. MatMul's outer per-package loop -// is disabled if this is 1. -HWY_INLINE_VAR constexpr size_t kMaxPackages = 1; - -HWY_INLINE_VAR constexpr size_t kMaxClusters = 128; // TODO: shrink - // TODO: extend to 16k after updating non_eos. HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; diff --git a/util/mat.h b/util/mat.h index b0de72d..9d838e2 100644 --- a/util/mat.h +++ b/util/mat.h @@ -455,6 +455,7 @@ class MatOwner { template class MatStorageT : public MatPtrT { public: + MatStorageT() = default; // for std::vector in Activations. 
MatStorageT(const char* name, Extents2D extents, const Allocator& allocator, MatPadding padding) : MatPtrT(name, extents) { diff --git a/util/threading_context.h b/util/threading_context.h index 847ce81..6bd6936 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -25,7 +25,7 @@ // IWYU pragma: begin_exports #include "util/allocator.h" #include "util/args.h" -#include "util/basics.h" // Tristate, kMaxPackages +#include "util/basics.h" // Tristate #include "util/threading.h" #include "util/topology.h" #include "hwy/profiler.h" @@ -41,7 +41,7 @@ class ThreadingArgs : public ArgsBase { // For BoundedTopology: size_t skip_packages; - size_t max_packages; + size_t max_packages = 1; size_t skip_clusters; size_t max_clusters; size_t skip_lps; @@ -58,13 +58,9 @@ class ThreadingArgs : public ArgsBase { void ForEach(const Visitor& visitor) { // These can be used to partition CPU packages/sockets and their // clusters/CCXs across several program instances. The default is to use - // all available resources on one package. Note that `kMaxPackages` is an - // upper bound on `max_packages`. + // all available resources on the first package. visitor(skip_packages, "skip_packages", size_t{0}, "Index of the first socket to use; default 0 = unlimited.", 2); - visitor(max_packages, "max_packages", size_t{1}, - "Max sockets to use; default = 1, 0 = unlimited.", 2); - HWY_ASSERT(max_packages <= kMaxPackages); visitor(skip_clusters, "skip_clusters", size_t{0}, "Index of the first CCX to use; default 0 = unlimited.", 2); visitor(max_clusters, "max_clusters", size_t{0}, @@ -105,7 +101,7 @@ struct ThreadingContext { hwy::Profiler& profiler; // Detects topology, subject to limits imposed by user-specified `args`. - // For example, if `args.max_packages` is 1, then `topology.NumPackages()` + // For example, if `args.max_clusters` is 1, then `topology.NumClusters()` // will be 1 regardless of the actual system topology. 
BoundedTopology topology; From afd82376a5c11b70bd852ceb2c4b41ca39d24518 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 4 Sep 2025 05:58:08 -0700 Subject: [PATCH 26/65] Add AES-CTR RNG for parallel sampling (not yet used) PiperOrigin-RevId: 802991142 --- BUILD.bazel | 15 ++++++ CMakeLists.txt | 2 + util/basics.cc | 75 ++++++++++++++++++++++++++++++ util/basics.h | 36 +++++++++++++++ util/basics_test.cc | 108 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 236 insertions(+) create mode 100644 util/basics.cc create mode 100644 util/basics_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index ce4cffb..62f2f5c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -29,9 +29,24 @@ exports_files([ cc_library( name = "basics", + srcs = ["util/basics.cc"], hdrs = ["util/basics.h"], deps = [ "@highway//:hwy", + "@highway//:timer", + "@highway//hwy/contrib/sort:vqsort", + ], +) + +cc_test( + name = "basics_test", + srcs = ["util/basics_test.cc"], + deps = [ + ":basics", + "@googletest//:gtest_main", # buildcleaner: keep + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:timer", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8309840..d3a66fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,6 +120,7 @@ set(SOURCES paligemma/image.h util/allocator.cc util/allocator.h + util/basics.cc util/basics.h util/mat.cc util/mat.h @@ -227,6 +228,7 @@ set(GEMMA_TEST_FILES ops/ops_test.cc paligemma/image_test.cc paligemma/paligemma_test.cc + util/basics_test.cc util/threading_test.cc ) diff --git a/util/basics.cc b/util/basics.cc new file mode 100644 index 0000000..4261510 --- /dev/null +++ b/util/basics.cc @@ -0,0 +1,75 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "util/basics.h" + +#include +#include + +#include "hwy/contrib/sort/vqsort.h" +#include "hwy/highway.h" +#include "hwy/timer.h" + +namespace gcpp { + +RNG::RNG(bool deterministic) { + // Pi-based nothing up my sleeve numbers from Randen. + key_[0] = 0x243F6A8885A308D3ull; + key_[1] = 0x13198A2E03707344ull; + + if (!deterministic) { // want random seed + if (!hwy::Fill16BytesSecure(key_)) { + HWY_WARN("Failed to fill RNG key with secure random bits"); + // Entropy not available. The test requires that we inject some + // differences relative to the deterministic seeds. + key_[0] ^= reinterpret_cast(this); + key_[1] ^= hwy::timer::Start(); + } + } + + // Simple key schedule: swap and add constant (also from Randen). 
+ for (size_t i = 0; i < kRounds; ++i) { + key_[2 + 2 * i + 0] = key_[2 * i + 1] + 0xA4093822299F31D0ull; + key_[2 + 2 * i + 1] = key_[2 * i + 0] + 0x082EFA98EC4E6C89ull; + } +} + +namespace hn = hwy::HWY_NAMESPACE; +using D = hn::Full128; // 128 bits for AES +using V = hn::Vec; + +static V Load(const uint64_t* ptr) { + return hn::Load(D(), reinterpret_cast(ptr)); +} + +RNG::result_type RNG::operator()() { + V state = Load(counter_); + counter_[0]++; + state = hn::Xor(state, Load(key_)); // initial whitening + + static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t)); + state = hn::AESRound(state, Load(key_ + 2)); + state = hn::AESRound(state, Load(key_ + 4)); + state = hn::AESRound(state, Load(key_ + 6)); + state = hn::AESRound(state, Load(key_ + 8)); + // Final round: fine to use another AESRound, including MixColumns. + state = hn::AESRound(state, Load(key_ + 10)); + + // Return lower 64 bits of the u8 vector. + const hn::Repartition d64; + return hn::GetLane(hn::BitCast(d64, state)); +} + +} // namespace gcpp diff --git a/util/basics.h b/util/basics.h index c8858e5..2429c72 100644 --- a/util/basics.h +++ b/util/basics.h @@ -119,6 +119,42 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end, size_t max_size) { return IndexRange(begin, HWY_MIN(begin + max_size, end)); } + +// Non-cryptographic 64-bit pseudo-random number generator. Supports random or +// deterministic seeding. Conforms to C++ `UniformRandomBitGenerator`. +// +// Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This +// is useful for parallel sampling. Each thread can generate the stream for a +// particular task, without caring about prior/subsequent generations. +class alignas(16) RNG { + // "Large-scale randomness study of security margins for 100+ cryptographic + // functions": at least four. + // "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant. 
+ static constexpr size_t kRounds = 5; + + public: + explicit RNG(bool deterministic); + + void SetStream(uint64_t stream) { + counter_[1] = stream; + counter_[0] = 0; + } + + using result_type = uint64_t; + static constexpr result_type min() { return 0; } + static constexpr result_type max() { return ~result_type{0}; } + + // About 100M/s on 3 GHz Skylake. Throughput could be increased 4x via + // unrolling by the AES latency (4-7 cycles). `std::discrete_distribution` + // makes individual calls to the generator, which would require buffering, + // which is not worth the complexity. + result_type operator()(); + + private: + uint64_t counter_[2] = {}; + uint64_t key_[2 * (1 + kRounds)]; +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_ diff --git a/util/basics_test.cc b/util/basics_test.cc new file mode 100644 index 0000000..169d051 --- /dev/null +++ b/util/basics_test.cc @@ -0,0 +1,108 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "util/basics.h" + +#include +#include + +#include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" +#include "hwy/timer.h" + +namespace gcpp { +namespace { + +TEST(BasicsTest, IsDeterministic) { + RNG rng1(/*deterministic=*/true); + RNG rng2(/*deterministic=*/true); + // Remember for later testing after resetting the stream. + const uint64_t r0 = rng1(); + const uint64_t r1 = rng1(); + // Not consecutive values. 
This could actually happen due to the extra XOR, + // but given the deterministic seeding here, we know it will not. + HWY_ASSERT(r0 != r1); + // Let rng2 catch up. + HWY_ASSERT(r0 == rng2()); + HWY_ASSERT(r1 == rng2()); + + for (size_t i = 0; i < 1000; ++i) { + HWY_ASSERT(rng1() == rng2()); + } + + // Reset counter, ensure it matches the default-constructed RNG. + rng1.SetStream(0); + HWY_ASSERT(r0 == rng1()); + HWY_ASSERT(r1 == rng1()); +} + +TEST(BasicsTest, IsSeeded) { + RNG rng1(/*deterministic=*/true); + RNG rng2(/*deterministic=*/false); + // It would be very unlucky to have even one 64-bit value match, and two are + // extremely unlikely. + const uint64_t a0 = rng1(); + const uint64_t a1 = rng1(); + const uint64_t b0 = rng2(); + const uint64_t b1 = rng2(); + HWY_ASSERT(a0 != b0 || a1 != b1); +} + +// If not close to 50% 1-bits, the RNG is quite broken. +TEST(BasicsTest, BitDistribution) { + RNG rng(/*deterministic=*/true); + constexpr size_t kU64 = 2 * 1000 * 1000; + const hwy::Timestamp t0; + uint64_t one_bits = 0; + for (size_t i = 0; i < kU64; ++i) { + one_bits += hwy::PopCount(rng()); + } + const uint64_t total_bits = kU64 * 64; + const double one_ratio = static_cast(one_bits) / total_bits; + const double elapsed = hwy::SecondsSince(t0); + fprintf(stderr, "1-bit ratio %.5f, %.1f M/s\n", one_ratio, + kU64 / elapsed * 1E-6); + HWY_ASSERT(0.4999 <= one_ratio && one_ratio <= 0.5001); +} + +TEST(BasicsTest, ChiSquared) { + RNG rng(/*deterministic=*/true); + constexpr size_t kU64 = 1 * 1000 * 1000; + + // Test each byte separately. 
+ for (size_t shift = 0; shift < 64; shift += 8) { + size_t counts[256] = {}; + for (size_t i = 0; i < kU64; ++i) { + const size_t byte = (rng() >> shift) & 0xFF; + counts[byte]++; + } + + double chi_squared = 0.0; + const double expected = static_cast(kU64) / 256.0; + for (size_t i = 0; i < 256; ++i) { + const double diff = static_cast(counts[i]) - expected; + chi_squared += diff * diff / expected; + } + // Should be within ~0.5% and 99.5% percentiles. See + // https://www.medcalc.org/manual/chi-square-table.php + if (chi_squared < 196.0 || chi_squared > 311.0) { + HWY_ABORT("Chi-squared byte %zu: %.5f \n", shift / 8, chi_squared); + } + } +} + +} // namespace +} // namespace gcpp +HWY_TEST_MAIN(); From 5d1693e806c04d078dbc7425e4542a15340067a6 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 4 Sep 2025 10:30:42 -0700 Subject: [PATCH 27/65] Internal change PiperOrigin-RevId: 803083229 --- BUILD.bazel | 4 ++-- gemma/configs.h | 2 +- gemma/gemma.cc | 15 ++++++--------- gemma/weights.cc | 17 +++++++++-------- python/configs.cc | 3 +-- 5 files changed, 19 insertions(+), 22 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 62f2f5c..cbfb342 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -539,14 +539,14 @@ cc_library( ":weights", "//compression:compress", "//compression:types", - "//io:blob_store", "//io", + "//io:blob_store", "//paligemma:image", "@highway//:hwy", - "@highway//hwy/contrib/sort:vqsort", "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", + "@highway//hwy/contrib/sort:vqsort", ], ) diff --git a/gemma/configs.h b/gemma/configs.h index a1cd902..0c93e30 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -421,7 +421,7 @@ struct ModelConfig : public IFields { } size_t KVCacheCols() const { - size_t num_layers = layer_configs.size(); + const size_t num_layers = layer_configs.size(); return num_layers * layer_configs[0].CacheLayerSize(); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a0949fe..a7e73ca 100644 
--- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -556,16 +556,13 @@ static void GenerateT(const ModelConfig& config, const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx); - { - timing_info.generate_start = hwy::platform::Now(); - for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { - Transformer(config, runtime_config, weights, activations, qbatch, env); - SampleAndStream(config, runtime_config, weights, sample_token, - activations, qbatch, /*update_pos=*/true, env, non_eos, - timing_info); - } - timing_info.NotifyGenerateDone(); + timing_info.generate_start = hwy::platform::Now(); + for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { + Transformer(config, runtime_config, weights, activations, qbatch, env); + SampleAndStream(config, runtime_config, weights, sample_token, activations, + qbatch, /*update_pos=*/true, env, non_eos, timing_info); } + timing_info.NotifyGenerateDone(); } void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, diff --git a/gemma/weights.cc b/gemma/weights.cc index ca1cebc..3d1d43e 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -226,15 +226,16 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) { // ideally already happen in the importer. Called by `ReadFromBlobs`. 
void WeightsPtrs::Fixup(std::vector& mat_owners, ThreadingContext& ctx) { - // TODO: use 1D parallel-for helper function - hwy::ThreadPool& pool = ctx.pools.Pool(); - pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - GetLayer(layer)->Fixup(mat_owners, ctx.allocator); - }); + const size_t cluster_idx = 0; + ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, + [&](uint64_t layer, size_t /*worker*/) { + GetLayer(layer)->Fixup(mat_owners, ctx.allocator); + }); - pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - VitLayer(layer)->Fixup(mat_owners, ctx.allocator); - }); + ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, + [&](uint64_t layer, size_t /*worker*/) { + VitLayer(layer)->Fixup(mat_owners, ctx.allocator); + }); } std::vector WeightsPtrs::AddTensorDataToWriter( diff --git a/python/configs.cc b/python/configs.cc index 36cd314..f8121bf 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -147,8 +147,7 @@ PYBIND11_MODULE(configs, py_module) { .def_readwrite("image_size", &VitConfig::image_size) .def_readwrite("layer_configs", &VitConfig::layer_configs); - class_(py_module, "InternalModelConfig") - .def(init<>()); + class_(py_module, "InternalModelConfig").def(init<>()); class_(py_module, "ModelConfig") .def(init<>()) From 56186193c1f6d3dfa13bf642350a690166d1b282 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 4 Sep 2025 23:48:37 -0700 Subject: [PATCH 28/65] Replace mt19937 with new generator to enable parallel sampling Split it into immutable AesCtrEngine and RngStream Also add RowSpan and Logits span PiperOrigin-RevId: 803336423 --- compression/nuq_test.cc | 9 +- evals/benchmark_helper.cc | 15 -- evals/benchmark_helper.h | 5 - evals/cross_entropy.cc | 25 ++-- evals/gemma_test.cc | 1 - evals/run_mmlu.cc | 1 - examples/hello_world/run.cc | 10 +- examples/simplified_gemma/gemma.hpp | 9 +- gemma/api_server.cc | 203 ++++++++++++++-------------- 
gemma/attention.cc | 5 +- gemma/bindings/context.cc | 7 +- gemma/bindings/context.h | 8 -- gemma/gemma.cc | 52 +++---- gemma/gemma.h | 1 + gemma/gemma_args.h | 12 +- gemma/run.cc | 11 +- gemma/vit.cc | 5 +- ops/dot_test.cc | 33 ++--- ops/ops-inl.h | 150 ++++++++++---------- ops/ops_test.cc | 35 ++--- paligemma/paligemma_helper.cc | 48 +++---- python/gemma_py.cc | 18 +-- util/basics.cc | 10 +- util/basics.h | 53 +++++--- util/basics_test.cc | 43 ++++-- util/mat.h | 7 + 26 files changed, 382 insertions(+), 394 deletions(-) diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index df300f4..7ddee9e 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -24,7 +24,6 @@ #include // std::shuffle #include -#include #include "compression/distortion.h" #include "util/test_util.h" @@ -104,8 +103,8 @@ struct TestPlateaus { HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f); } - std::random_device rd; // NOLINT - std::mt19937 rng(rd()); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); std::shuffle(in.get(), in.get() + kGroupSize, rng); NuqStream::ClusterBuf buf; @@ -151,8 +150,8 @@ struct TestRamp { HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f); } - std::random_device rd; // NOLINT - std::mt19937 rng(rd()); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); std::shuffle(in.get(), in.get() + kGroupSize, rng); NuqStream::ClusterBuf buf; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 3b999b4..55e99cf 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -20,7 +20,6 @@ #include #include -#include #include #include @@ -37,17 +36,6 @@ namespace gcpp { -void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { - if (inference.deterministic) { - // Nothing up my sleeve number, at least some upper bits set. - gen.seed(0x12345678); - } else { - // Depending on the library implementation, this may still be deterministic. 
- std::random_device rd; // NOLINT - gen.seed(rd()); - } -} - GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference) : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) { @@ -60,12 +48,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, ctx_); } - InitGenerator(inference, gen_); - runtime_config_ = { .max_generated_tokens = inference.max_generated_tokens, .temperature = inference.temperature, - .gen = &gen_, .verbosity = inference.verbosity, }; inference.CopyTo(runtime_config_); diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 8f4d96f..261daa4 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -18,7 +18,6 @@ #include -#include #include #include @@ -32,8 +31,6 @@ namespace gcpp { -void InitGenerator(const InferenceArgs& inference, std::mt19937& gen); - // Return type for query model calls. struct QueryResult { std::string response; @@ -107,7 +104,6 @@ class GemmaEnv { int Verbosity() const { return runtime_config_.verbosity; } RuntimeConfig& MutableConfig() { return runtime_config_; } - std::mt19937& MutableGen() { return gen_; } KVCache& MutableKVCache() { return kv_caches_[0]; } MatMulEnv& MutableEnv() { return env_; } @@ -115,7 +111,6 @@ class GemmaEnv { ThreadingContext ctx_; MatMulEnv env_; Gemma gemma_; - std::mt19937 gen_; // Random number generator. std::vector kv_caches_; // Same number as query batch. 
RuntimeConfig runtime_config_; }; diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index c150041..b7abb10 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -56,11 +56,10 @@ static std::string TokenString(const GemmaTokenizer& tokenizer, int token) { return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'"; } -void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len, - size_t k) { - std::vector> sorted(len); - for (size_t i = 0; i < len; ++i) { - sorted[i] = std::make_pair(dist[i], static_cast(i)); +void LogTopK(const GemmaTokenizer& tokenizer, Logits logits, size_t k) { + std::vector> sorted(logits.size()); + for (size_t i = 0; i < logits.size(); ++i) { + sorted[i] = std::make_pair(logits[i], static_cast(i)); } std::sort(sorted.begin(), sorted.end(), [](const std::pair& a, const std::pair& b) { @@ -84,9 +83,8 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size, - hwy::Profiler& p) { - Softmax(logits, vocab_size, p, hwy::Profiler::Thread()); +void CallSoftmax(Logits logits, hwy::Profiler& p) { + Softmax(logits, p, hwy::Profiler::Thread()); } } // namespace HWY_NAMESPACE @@ -107,19 +105,19 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s) size_t pos = 1; - const SampleFunc sample_token = [&](float* probs, - size_t vocab_size) -> TokenAndProb { + const SampleFunc sample_token = [&](size_t qi, + Logits logits) -> TokenAndProb { // input is logits, not yet probabilities - HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler); + HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler); // We are called for each token, but pos starts at 1. Clamping // max_generated_tokens to prompt.size() should prevent overrun. 
HWY_ASSERT(pos < prompt.size()); const int token = prompt[pos]; - const float prob = probs[token]; + const float prob = logits[token]; cross_entropy -= std::max(std::log(prob), -64.0f); if (verbosity >= 4) { - LogTopK(gemma.Tokenizer(), probs, vocab_size, 10); + LogTopK(gemma.Tokenizer(), logits, 10); } if (verbosity >= 3) { printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token, @@ -139,7 +137,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, RuntimeConfig runtime = { .max_generated_tokens = max_generated_tokens - 1, .temperature = 0.0f, - .gen = nullptr, .verbosity = verbosity, .stream_token = stream_token, .sample_func = sample_token, diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 77efbae..26313c1 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -115,7 +115,6 @@ TEST_F(GemmaTest, Multiturn) { RuntimeConfig runtime_config{ .max_generated_tokens = 64, .temperature = 0.0f, - .gen = &s_env->MutableGen(), .verbosity = 2, .batch_stream_token = stream_token, }; diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc index b6537fe..04a6e00 100644 --- a/evals/run_mmlu.cc +++ b/evals/run_mmlu.cc @@ -126,7 +126,6 @@ void Run(GemmaEnv& env, JsonArgs& json) { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = 30, .temperature = 0.0f, - .gen = &env.MutableGen(), .verbosity = env.Verbosity(), .stream_token = stream_token, }; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 193903f..f67324d 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -17,8 +17,8 @@ #include #include +#include #include -#include #include #include #include @@ -44,7 +44,7 @@ int main(int argc, char** argv) { for (int arg = 0; arg < argc; ++arg) { // Find a --reject flag and consume everything after it. 
if (strcmp(argv[arg], "--reject") == 0) { - while (++arg < argc) reject_tokens.insert(atoi(argv[arg])); + while (++arg < argc) reject_tokens.insert(atoi(argv[arg])); // NOLINT } } @@ -55,11 +55,6 @@ int main(int argc, char** argv) { gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator); size_t generated = 0; - // Initialize random number generator - std::mt19937 gen; - std::random_device rd; // NOLINT - gen.seed(rd()); - // Tokenize instructions. std::string prompt = "Write a greeting to the world."; const std::vector tokens = @@ -84,7 +79,6 @@ int main(int argc, char** argv) { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = 1024, .temperature = 1.0, - .gen = &gen, .verbosity = 0, .stream_token = stream_token, .accept_token = diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index 7800233..e5bb1d8 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -38,11 +37,7 @@ class SimplifiedGemma { : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_), - kv_cache_(gemma_.Config(), inference, ctx_.allocator) { - // Initialize random number generator - std::random_device rd; - gen_.seed(rd()); - } + kv_cache_(gemma_.Config(), inference, ctx_.allocator) {} SimplifiedGemma(int argc, char** argv) : SimplifiedGemma(gcpp::LoaderArgs(argc, argv), @@ -76,7 +71,6 @@ class SimplifiedGemma { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = max_generated_tokens, .temperature = temperature, - .gen = &gen_, .verbosity = 0, .stream_token = stream_token, .accept_token = @@ -93,6 +87,5 @@ class SimplifiedGemma { gcpp::MatMulEnv env_; gcpp::Gemma gemma_; gcpp::KVCache kv_cache_; - std::mt19937 gen_; std::string validation_error_; }; diff --git a/gemma/api_server.cc b/gemma/api_server.cc index 70b3115..ea5377d 100644 --- a/gemma/api_server.cc +++ b/gemma/api_server.cc @@ -60,18 +60,18 @@ 
struct ServerState { std::unique_ptr gemma; MatMulEnv* env; ThreadingContext* ctx; - + // Session-based KV cache storage struct Session { std::unique_ptr kv_cache; size_t abs_pos = 0; std::chrono::steady_clock::time_point last_access; }; - + std::unordered_map sessions; std::mutex sessions_mutex; std::mutex inference_mutex; - + // Cleanup old sessions after 30 minutes of inactivity void CleanupOldSessions() { std::lock_guard lock(sessions_mutex); @@ -84,7 +84,7 @@ struct ServerState { } } } - + // Get or create session with KV cache Session& GetOrCreateSession(const std::string& session_id) { std::lock_guard lock(sessions_mutex); @@ -101,24 +101,25 @@ struct ServerState { std::string GenerateSessionId() { static std::atomic counter{0}; std::stringstream ss; - ss << "session_" << std::hex << std::chrono::steady_clock::now().time_since_epoch().count() - << "_" << counter.fetch_add(1); + ss << "session_" << std::hex + << std::chrono::steady_clock::now().time_since_epoch().count() << "_" + << counter.fetch_add(1); return ss.str(); } // Wraps messages with start_of_turn markers - handles both with and without roles std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string prompt; - + for (const auto& content : contents) { if (content.contains("parts")) { // Check if role is specified (public API format) or not (local format) std::string role = content.value("role", ""); - + for (const auto& part : content["parts"]) { if (part.contains("text")) { std::string text = part["text"]; - + if (role == "user") { prompt += "user\n" + text + "\nmodel\n"; } else if (role == "model") { @@ -131,24 +132,23 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) { } } } - + return prompt; } // Parse generation config -RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) { +RuntimeConfig ParseGenerationConfig(const json& request) { RuntimeConfig config; - config.gen = &gen; config.verbosity = 0; - + // Set defaults matching public API 
config.temperature = 1.0f; config.top_k = 1; config.max_generated_tokens = 8192; - + if (request.contains("generationConfig")) { auto& gen_config = request["generationConfig"]; - + if (gen_config.contains("temperature")) { config.temperature = gen_config["temperature"].get(); } @@ -159,7 +159,7 @@ RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) { config.max_generated_tokens = gen_config["maxOutputTokens"].get(); } } - + return config; } @@ -175,12 +175,12 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) }}}, {"promptFeedback", {{"safetyRatings", json::array()}}} }; - + // Only add finishReason for non-streaming chunks if (!is_streaming_chunk) { response["candidates"][0]["finishReason"] = "STOP"; } - + return response; } @@ -188,11 +188,11 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { try { json request = json::parse(req.body); - + // Get or create session std::string session_id = request.value("sessionId", GenerateSessionId()); auto& session = state.GetOrCreateSession(session_id); - + // Extract prompt from API format std::string prompt; if (request.contains("contents")) { @@ -202,32 +202,29 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); return; } - + // Lock for inference std::lock_guard lock(state.inference_mutex); - + // Set up runtime config - std::mt19937 gen; - RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); - + RuntimeConfig runtime_config = ParseGenerationConfig(request); + // Collect full response std::string full_response; runtime_config.stream_token = [&full_response](int token, float) { // Skip EOS token return true; }; - + // Tokenize prompt - std::vector tokens = 
WrapAndTokenize(state.gemma->Tokenizer(), - state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, - session.abs_pos, - prompt); - + std::vector tokens = WrapAndTokenize( + state.gemma->Tokenizer(), state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, session.abs_pos, prompt); + // Run inference with KV cache TimingInfo timing_info = {.verbosity = 0}; size_t prefix_end = 0; - + // Temporarily redirect output to capture response std::stringstream output; runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { @@ -236,25 +233,25 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques session.abs_pos++; return true; } - + session.abs_pos++; - + // Check for EOS if (state.gemma->Config().IsEOS(token)) { return true; } - + // Decode token std::string token_text; state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); output << token_text; - + return true; }; - - state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, - *session.kv_cache, *state.env, timing_info); - + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, + *session.kv_cache, *state.env, timing_info); + // Create response json response = CreateAPIResponse(output.str(), false); response["usageMetadata"] = { @@ -262,17 +259,22 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques {"candidatesTokenCount", session.abs_pos - tokens.size()}, {"totalTokenCount", session.abs_pos} }; - + res.set_content(response.dump(), "application/json"); - + } catch (const json::exception& e) { res.status = 400; - res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), - "application/json"); + res.set_content( + json{{"error", + {{"message", std::string("JSON parsing error: ") + e.what()}}}} + .dump(), + "application/json"); } catch (const std::exception& e) { res.status = 500; - res.set_content(json{{"error", {{"message", 
std::string("Server error: ") + e.what()}}}}.dump(), - "application/json"); + res.set_content( + json{{"error", {{"message", std::string("Server error: ") + e.what()}}}} + .dump(), + "application/json"); } } @@ -280,11 +282,11 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { try { json request = json::parse(req.body); - + // Get or create session std::string session_id = request.value("sessionId", GenerateSessionId()); auto& session = state.GetOrCreateSession(session_id); - + // Extract prompt from API format std::string prompt; if (request.contains("contents")) { @@ -294,13 +296,13 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); return; } - + // Set up SSE headers res.set_header("Content-Type", "text/event-stream"); res.set_header("Cache-Control", "no-cache"); res.set_header("Connection", "keep-alive"); res.set_header("X-Session-Id", session_id); - + // Set up chunked content provider for SSE res.set_chunked_content_provider( "text/event-stream", @@ -309,18 +311,15 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& // Lock for inference std::lock_guard lock(state.inference_mutex); auto& session = state.GetOrCreateSession(session_id); - + // Set up runtime config - std::mt19937 gen; - RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); - + RuntimeConfig runtime_config = ParseGenerationConfig(request); + // Tokenize prompt - std::vector tokens = WrapAndTokenize(state.gemma->Tokenizer(), - state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, - session.abs_pos, - prompt); - + std::vector tokens = WrapAndTokenize( + state.gemma->Tokenizer(), state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, session.abs_pos, prompt); + 
// Stream token callback std::string accumulated_text; auto stream_token = [&](int token, float) { @@ -329,37 +328,38 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& session.abs_pos++; return true; } - + session.abs_pos++; - + // Check for EOS if (state.gemma->Config().IsEOS(token)) { return true; } - + // Decode token std::string token_text; state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); accumulated_text += token_text; - + // Send SSE event using unified formatter json event = CreateAPIResponse(token_text, true); - + std::string sse_data = "data: " + event.dump() + "\n\n"; sink.write(sse_data.data(), sse_data.size()); - + return true; }; - + runtime_config.stream_token = stream_token; - + // Run inference with KV cache TimingInfo timing_info = {.verbosity = 0}; size_t prefix_end = 0; - - state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, - *session.kv_cache, *state.env, timing_info); - + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, + prefix_end, *session.kv_cache, *state.env, + timing_info); + // Send final event using unified formatter json final_event = CreateAPIResponse("", false); final_event["usageMetadata"] = { @@ -367,18 +367,18 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& {"candidatesTokenCount", session.abs_pos - tokens.size()}, {"totalTokenCount", session.abs_pos} }; - + std::string final_sse = "data: " + final_event.dump() + "\n\n"; sink.write(final_sse.data(), final_sse.size()); - + // Send done event sink.write("data: [DONE]\n\n", 15); - + // Ensure all data is sent sink.done(); - + return false; // End streaming - + } catch (const std::exception& e) { json error_event = {{"error", {{"message", e.what()}}}}; std::string error_sse = "data: " + error_event.dump() + "\n\n"; @@ -387,11 +387,14 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& } } ); - + } catch (const json::exception& e) { 
res.status = 400; - res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), - "application/json"); + res.set_content( + json{{"error", + {{"message", std::string("JSON parsing error: ") + e.what()}}}} + .dump(), + "application/json"); } } @@ -410,7 +413,7 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const {"topK", 1} }}} }; - + res.set_content(response.dump(), "application/json"); } @@ -419,40 +422,40 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const // server_running = false; // } -void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, +void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference) { std::cerr << "Loading model..." << std::endl; - + // Initialize model ThreadingContext ctx(threading); MatMulEnv env(ctx); - + ServerState state; state.gemma = std::make_unique(loader, inference, ctx); state.env = &env; state.ctx = &ctx; - + httplib::Server server; - + // Set up routes server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); }); - + // API endpoints server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { HandleListModels(state, inference, req, res); }); - + std::string model_endpoint = "/v1beta/models/" + inference.model; server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { HandleGenerateContentNonStreaming(state, req, res); }); - + server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { HandleGenerateContentStreaming(state, req, res); }); - + // Periodic cleanup of old sessions std::thread cleanup_thread([&state]() { while (server_running) { @@ -460,18 +463,18 @@ void 
RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, state.CleanupOldSessions(); } }); - + std::cerr << "Starting API server on port " << inference.port << std::endl; std::cerr << "Model loaded successfully" << std::endl; std::cerr << "Endpoints:" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; std::cerr << " GET /v1beta/models" << std::endl; - + if (!server.listen("0.0.0.0", inference.port)) { std::cerr << "Failed to start server on port " << inference.port << std::endl; } - + cleanup_thread.join(); } @@ -479,11 +482,11 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, int main(int argc, char** argv) { gcpp::InternalInit(); - + gcpp::LoaderArgs loader(argc, argv); gcpp::ThreadingArgs threading(argc, argv); gcpp::InferenceArgs inference(argc, argv); - + if (gcpp::HasHelp(argc, argv)) { std::cerr << "\n\nAPI server for gemma.cpp\n"; std::cout << "========================\n\n"; @@ -501,14 +504,14 @@ int main(int argc, char** argv) { std::cerr << "\n"; return 0; } - + // Arguments are now handled by InferenceArgs - + // // Set up signal handler // signal(SIGINT, gcpp::HandleShutdown); // signal(SIGTERM, gcpp::HandleShutdown); - + gcpp::RunServer(loader, threading, inference); - + return 0; } diff --git a/gemma/attention.cc b/gemma/attention.cc index 21e5019..8afd561 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -155,8 +155,9 @@ void SingleDotSoftmaxWeightedSum( // SoftMax with optional SoftCap yields "probabilities" in att. 
const size_t att_len = HWY_MIN(last_pos + 1, seq_len); - MaybeLogitsSoftCap(att_cap, att, att_len, p, worker); - Softmax(att, att_len, p, worker, /*temperature=*/1.0f); + const Logits logits(att, att_len); + MaybeLogitsSoftCap(att_cap, logits, p, worker); + Softmax(logits, p, worker, /*temperature=*/1.0f); WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p, worker); diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index a6ebe30..5741d70 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -23,7 +23,6 @@ #include #include -#include "evals/benchmark_helper.h" // InitGenerator #include "gemma/gemma.h" #include "gemma/gemma_args.h" #include "gemma/tokenizer.h" // WrapAndTokenize @@ -135,8 +134,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string, std::stringstream ss; result_buffer.clear(); - InitGenerator(inference_args, gen); - // Ensure we have an active conversation if (!active_conversation || !active_conversation->kv_cache) { LogDebug("Generate called with null active_conversation or kv_cache"); @@ -174,8 +171,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // set up runtime config TimingInfo timing_info = {}; - RuntimeConfig runtime_config = {.gen = &gen, - .stream_token = stream_token, + RuntimeConfig runtime_config = {.stream_token = stream_token, .use_spinning = threading_args.spin}; inference_args.CopyTo(runtime_config); size_t prefix_end = 0; @@ -256,7 +252,6 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // If not multiturn, or Paligemma (which handles turns differently), // reset the *active* conversation's position. 
active_conversation->abs_pos = 0; - InitGenerator(inference_args, gen); } else { // Multi-turn Gemma: Rewind position in the active conversation // The last token was either EOS, then it should be ignored because it is diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h index 859a644..00648fc 100644 --- a/gemma/bindings/context.h +++ b/gemma/bindings/context.h @@ -17,7 +17,6 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_ #include // For std::shared_ptr, std::make_shared -#include #include #include #include @@ -107,10 +106,6 @@ class GemmaContext { // Set deterministic flag void SetDeterministic(bool value) { inference_args.deterministic = value; - // Reset the random number generator for deterministic generation - if (value) { - gen.seed(0x87654321); - } LogDebug("Setting deterministic flag to configured value"); } @@ -289,9 +284,6 @@ class GemmaContext { // Model itself (don't move this, needs to be below the args above) Gemma model; - // Random generator (remains global for the context) - std::mt19937 gen; - // Static members for logging static GemmaLogCallback s_log_callback; static void* s_log_user_data; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a7e73ca..0177c92 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -440,8 +440,7 @@ static void SampleAndStream( // TODO: parallelize non_eos.Foreach([&](size_t qi) { - float* HWY_RESTRICT logits = activations.logits.Row(qi); - const TokenAndProb tp = sample_token(logits, config.vocab_size); + const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi)); // We streamed all prefill tokens, but pos is still one behind because we // started generation at pos = prompt.size() - 1. 
We want the pos argument @@ -453,7 +452,8 @@ static void SampleAndStream( } static HWY_INLINE SampleFunc -ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) { +ChooseSampleFunc(const RuntimeConfig& runtime_config, + const AesCtrEngine& engine, ThreadingContext& ctx) { // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; @@ -462,27 +462,28 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) { // Fast path for top-1 with no accept_token. if (runtime_config.top_k == 1 && !runtime_config.accept_token) { - return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { + return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb { PROFILER_ZONE3(ctx.profiler, worker, zone); - return Top1OfSoftmax(logits, vocab_size); + return Top1OfSoftmax(logits); }; } // General case: Softmax with top-k sampling. - return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { + return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample general"); + RngStream gen(engine, qi); return FusedSoftmaxAndSampleTopK( - logits, runtime_config.top_k, vocab_size, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token, ctx.profiler, - worker); + logits, runtime_config.top_k, gen, runtime_config.temperature, + runtime_config.accept_token, ctx.profiler, worker); }; } // Decode: generates one continuation token for each query in `qbatch`. static void GenerateT(const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, Activations& activations, - QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { + const AesCtrEngine& engine, const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, MatMulEnv& env, + TimingInfo& timing_info) { // Griffin assumes that the recurrent block cache is zero-initialized. 
for (size_t qi = 0; qi < qbatch.Size(); ++qi) { if (qbatch.MutablePos(qi) == 0) { @@ -554,7 +555,8 @@ static void GenerateT(const ModelConfig& config, max_gen_steps = seq_len - max_prompt_size; } - const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx); + const SampleFunc sample_token = + ChooseSampleFunc(runtime_config, engine, env.ctx); timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { @@ -568,15 +570,16 @@ static void GenerateT(const ModelConfig& config, void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, KVCache& kv_cache, - MatMulEnv& env, TimingInfo& timing_info) { + const AesCtrEngine& engine, const WeightsPtrs& weights, + KVCache& kv_cache, MatMulEnv& env, + TimingInfo& timing_info) { Activations activations(config, runtime_config.prefill_tbatch_size, kv_cache.SeqLen(), env.ctx, env.row_ptrs); AllQueries all_queries(prompt, pos, prefix_end, hwy::Span(&kv_cache, 1)); QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries); - GenerateT(config, runtime_config, weights, activations, qbatch, env, + GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, timing_info); } @@ -584,8 +587,9 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, // queries, and calls `GenerateT` on each batch. 
void GenerateBatchT(const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, AllQueries& all_queries, - MatMulEnv& env, TimingInfo& timing_info) { + const AesCtrEngine& engine, const WeightsPtrs& weights, + AllQueries& all_queries, MatMulEnv& env, + TimingInfo& timing_info) { const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); Activations activations(config, max_batch_size, @@ -596,7 +600,7 @@ void GenerateBatchT(const ModelConfig& config, start += runtime_config.decode_qbatch_size) { QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); // Generate a batch of one token for each of `qbatch.Size()` queries. - GenerateT(config, runtime_config, weights, activations, qbatch, env, + GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, timing_info); } } @@ -637,7 +641,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, model_(reader_, loader.tokenizer, loader.wrapping), weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), - inference_(inference) { + inference_(inference), + aes_ctr_engine_(inference.deterministic) { // Negligible CPU time in the ctor body (except ReadFromBlobs). 
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, ctx); @@ -661,9 +666,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config, TimingInfo& timing_info) const { env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); - HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end, - model_.Config(), runtime_config, - weights_, kv_cache, env, timing_info); + HWY_DYNAMIC_DISPATCH(GenerateSingleT)( + prompt, pos, prefix_end, model_.Config(), runtime_config, aes_ctr_engine_, + weights_, kv_cache, env, timing_info); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } @@ -674,7 +679,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config, - weights_, all_queries, env, timing_info); + aes_ctr_engine_, weights_, all_queries, + env, timing_info); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } diff --git a/gemma/gemma.h b/gemma/gemma.h index 0f9aae2..2f06ab8 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -278,6 +278,7 @@ class Gemma { WeightsPtrs::Mode weight_read_mode_; GemmaChatTemplate chat_template_; InferenceArgs inference_; + AesCtrEngine aes_ctr_engine_; }; } // namespace gcpp diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 2a49349..59e3a6c 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -22,7 +22,6 @@ #include #include -#include #include #include "io/io.h" // Path @@ -90,10 +89,10 @@ using BatchStreamFunc = std::function; // If not empty, AcceptFunc is called with token. It should return false for // tokens you don't want to generate and true for tokens you want to generate. using AcceptFunc = std::function; -// If not empty, SampleFunc is called with the logits for the next token, which -// it may modify/overwrite, and its return value is the next generated token -// together with its probability. 
-using SampleFunc = std::function; +// If not empty, SampleFunc is called with the query_idx and logits for the +// next token, which it may modify/overwrite. It returns the next generated +// token together with its probability. +using SampleFunc = std::function; // If not empty, LayersOutputFunc is called for layer outputs, specified with: // - index of query within containing batch (if any); zero otherwise. // - position in the tokens sequence @@ -136,8 +135,7 @@ struct RuntimeConfig { // Sampling-related parameters. float temperature; // Temperature for sampling. - size_t top_k = 1; // Top-k for sampling. - std::mt19937* gen; // Random number generator used for sampling. + size_t top_k = 1; // Top-k for sampling. int verbosity; // Controls verbosity of printed messages. diff --git a/gemma/run.cc b/gemma/run.cc index 3915bf8..7e2059f 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -18,7 +18,6 @@ #include #include -#include #include #include #include @@ -98,9 +97,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, size_t prompt_size = 0; const ModelConfig& config = gemma.Config(); - std::mt19937 gen; - InitGenerator(inference, gen); - const bool have_image = !inference.image_file.path.empty(); Image image; const size_t pool_dim = config.vit_config.pool_dim; @@ -117,8 +113,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, HWY_ASSERT(image.ReadPPM(inference.image_file.path)); const size_t image_size = config.vit_config.image_size; image.Resize(image_size, image_size); - RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = inference.verbosity, + RuntimeConfig runtime_config = {.verbosity = inference.verbosity, .use_spinning = threading.spin}; double image_tokens_start = hwy::platform::Now(); gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image, @@ -188,8 +183,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // Set up runtime config. 
TimingInfo timing_info = {.verbosity = inference.verbosity}; - RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = inference.verbosity, + RuntimeConfig runtime_config = {.verbosity = inference.verbosity, .batch_stream_token = batch_stream_token, .use_spinning = threading.spin}; inference.CopyTo(runtime_config); @@ -239,7 +233,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // Prepare for the next turn. Works only for PaliGemma. if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) { abs_pos = 0; // Start a new turn at position 0. - InitGenerator(inference, gen); } else { // The last token was either EOS, then it should be ignored because it is // never part of the dialog, see Table 5 in the Gemma-2 paper: diff --git a/gemma/vit.cc b/gemma/vit.cc index 96d6d7f..1910091 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -110,8 +110,7 @@ class VitAttention { CallMatMul(Q, K, nullptr, env_, C); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { - float* HWY_RESTRICT c = C.Row(task); - Softmax(c, C.Cols(), env_.ctx.profiler, worker); + Softmax(C.RowSpan(task), env_.ctx.profiler, worker); }); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { @@ -154,7 +153,7 @@ class VitAttention { head_att[i] = Dot(q, k, qkv_dim); // score = q.k } // SoftMax yields "probabilities" in head_att. - Softmax(head_att, seq_len, env_.ctx.profiler, worker); + Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker); // Compute weighted sum of v into att_out. float* HWY_RESTRICT att_out = activations_.attention.att_out.Row(token) + head * qkv_dim; diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 8afb220..3cb565c 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -812,7 +812,7 @@ class DotStats { // Forward relative error, lower is better. 
void CheckRel() const { - ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 4E-3); + ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 7E-3); ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 1.23f); // Compensated and Double are very accurate. @@ -822,22 +822,22 @@ class DotStats { ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f); // Naive and OnlyTwoProd are considerably higher, but not huge. - ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2); + ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 3.5E-1); ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(), - 0.072); + 7.5E-2); // Kahan (FastTwoSum) is decent: - ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3); + ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 1E-2); ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f); // TwoProducts and TwoSums are a bit better. ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(), - 3E-3); - ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f); + 1.1E-2); + ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 1.0f); ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(), - 2.6E-3); + 1.1E-2); - ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2); + ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 5.2E-2); // Extremely high error on aarch64. ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f); } @@ -857,7 +857,7 @@ class DotStats { ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f); // But TwoProducts/TwoSums help a bit. - ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 0.19f); + ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 1.0f); ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_rels[kAddTwoSum].Max(), 0.34f); // Extremely high error on aarch64. @@ -893,7 +893,7 @@ class DotStats { }; // Returns normalized value in [-1, 1). 
-float RandomFloat(std::mt19937& rng) { +float RandomFloat(RngStream& rng) { const uint32_t exp = hwy::BitCastScalar(1.0f); const uint32_t mantissa_mask = hwy::MantissaMask(); const uint32_t representation = exp | (rng() & mantissa_mask); @@ -908,7 +908,7 @@ float RandomFloat(std::mt19937& rng) { // error from the Dot algorithms, not the compression. template void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw, - std::mt19937& rng, + RngStream& rng, const PackedSpan& packed, CompressWorkingSet& work) { std::uniform_int_distribution e_dist(0, 6); @@ -934,7 +934,7 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw, // Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf. template double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v, - std::mt19937& rng) { + RngStream& rng) { PROFILER_FUNC; const size_t half = HWY_MAX(1, num / 2); // generate at least one random HWY_DASSERT(half != 0); @@ -1002,8 +1002,8 @@ struct TestShortDotsT { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); CompressWorkingSet work; - std::mt19937 rng; - rng.seed(12345); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); hwy::Stats s_l1[kVariants]; @@ -1108,9 +1108,10 @@ void TestAllDot() { { // ensure no profiler zones are active const hn::ScalableTag df; - std::mt19937 rngs[kMaxWorkers]; + AesCtrEngine engine(/*deterministic=*/true); + RngStream rngs[kMaxWorkers]; for (size_t i = 0; i < kMaxWorkers; ++i) { - rngs[i].seed(12345 + 65537 * i); + rngs[i] = RngStream(engine, i); } constexpr size_t kReps = hn::AdjustedReps(40); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 19a39aa..18ee40f 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -29,7 +29,7 @@ #include "ops/matmul.h" #include "util/allocator.h" -#include "util/basics.h" // TokenAndProb +#include "util/basics.h" // TokenAndProb, RngStream #include "util/mat.h" #include "util/threading_context.h" 
#include "hwy/base.h" @@ -614,12 +614,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( } // See below for a specialized version for top-1 sampling. -static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, - hwy::Profiler& p, const size_t worker, +static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, + const size_t worker, float temperature = 1.0f) { static const auto zone = p.AddZone("Ops.Softmax"); PROFILER_ZONE3(p, worker, zone); - HWY_DASSERT(size != 0); + HWY_DASSERT(logits.size() != 0); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; @@ -629,24 +629,25 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const V vmin = hn::Set(d, hwy::LowestValue()); V vmax = vmin; V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly - hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR { - *pmax = hn::Max(*pmax, value); - }); + hn::Foreach(d, logits.data(), logits.size(), vmin, + [pmax](const auto d, const V value) + HWY_ATTR { *pmax = hn::Max(*pmax, value); }); vmax = hn::MaxOfLanes(d, vmax); // Subtract max (avoid precision loss for large exponents) and exponentiate. - hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR { - if constexpr (HWY_TARGET & HWY_ALL_SVE) { - // Temporary workaround for buggy SVE codegen: avoid inlined Exp(). - return hn::CallExp(d, hn::Sub(value, *pmax)); - } else { - return hn::Exp(d, hn::Sub(value, *pmax)); - } - }); + hn::Transform(d, logits.data(), logits.size(), + [pmax](const auto d, const V value) HWY_ATTR { + if constexpr (HWY_TARGET & HWY_ALL_SVE) { + // Workaround for buggy SVE codegen: avoid inlined Exp(). 
+ return hn::CallExp(d, hn::Sub(value, *pmax)); + } else { + return hn::Exp(d, hn::Sub(value, *pmax)); + } + }); if (temperature != 1.0f) { const float temperature_inv = 1.0f / temperature; - hn::Transform(d, x, size, + hn::Transform(d, logits.data(), logits.size(), [temperature_inv](const auto d, const V value) HWY_ATTR { return hn::Mul(value, hn::Set(d, temperature_inv)); }); @@ -656,10 +657,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, // not make a huge difference. It halves the standard deviation of the sum of // the normalized probabilities from 1E-7 to 5E-8, but actually also changes // the generated text after a few hundred tokens. - const float sum_exp = Sum(d, x, size); + const float sum_exp = Sum(d, logits.data(), logits.size()); // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; - MulByConst(mul, x, size, p, worker); + MulByConst(mul, logits.data(), logits.size(), p, worker); } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / @@ -669,8 +670,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, // which already knows the max value which top-1 sampling would again seek. // Returns the argmax and x[argmax]. 
-static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, - const size_t num) { +static HWY_INLINE TokenAndProb ArgmaxAndMax(Logits logits) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; using V = hn::Vec; @@ -680,16 +680,16 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, using TI = hn::TFromD; using VI = hn::Vec; const size_t N = hn::Lanes(d); - HWY_ASSERT(num % (2 * N) == 0); + HWY_ASSERT(logits.size() % (2 * N) == 0); V max0 = hn::Set(d, hwy::LowestValue()); V max1 = max0; VI argmax0 = hn::Zero(di); VI argmax1 = argmax0; - for (size_t i = 0; i < num; i += 2 * N) { - const V v0 = hn::LoadU(d, x + i); - const V v1 = hn::LoadU(d, x + i + N); + for (size_t i = 0; i < logits.size(); i += 2 * N) { + const V v0 = hn::LoadU(d, &logits[i]); + const V v1 = hn::LoadU(d, &logits[i + N]); const VI vi0 = hn::Iota(di, static_cast(i)); const VI vi1 = hn::Iota(di, static_cast(i + N)); const M gt0 = hn::Gt(v0, max0); @@ -714,43 +714,43 @@ static HWY_INLINE TokenAndProb ArgmaxAndMax(const float* HWY_RESTRICT x, return TokenAndProb{.token = argmax, .prob = hn::GetLane(max)}; } -// Returns argmax of softmax and its probability. This overwrites `x`, but not -// with normalized probabilities. Only equivalent to `Softmax` + `sample_func` -// if `kTopK` == 1. This is worthwhile because `num` is typically `kVocabSize` -// == 256K, and this avoids writing and then scanning again for the max. -// However, this is not enough to make parallelization worthwhile. -static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, - const size_t num) { +// Returns argmax of softmax and its probability. This overwrites `logits`, but +// not with normalized probabilities. Only equivalent to `Softmax` + +// `sample_func` if `kTopK` == 1. This is worthwhile because `logits.size()` is +// typically `kVocabSize == 256K`, and this avoids writing and then scanning +// again for the max. 
+static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) { namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag d; using V = hn::Vec; - const TokenAndProb argmax = ArgmaxAndMax(x, num); + const TokenAndProb argmax = ArgmaxAndMax(logits); // Subtract max (avoid precision loss for large exponents) and exponentiate. const V max = hn::Set(d, argmax.prob); const V* pmax = &max; - hn::Transform(d, x, num, [pmax](const auto d, const V value) HWY_ATTR { - if constexpr (HWY_TARGET & HWY_ALL_SVE) { - // Temporary workaround for buggy SVE codegen: avoid inlined Exp(). - return hn::CallExp(d, hn::Sub(value, *pmax)); - } else { - return hn::Exp(d, hn::Sub(value, *pmax)); - } - }); + hn::Transform(d, logits.data(), logits.size(), + [pmax](const auto d, const V value) HWY_ATTR { + if constexpr (HWY_TARGET & HWY_ALL_SVE) { + // Temporary workaround for buggy SVE codegen: avoid inlined + // Exp(). + return hn::CallExp(d, hn::Sub(value, *pmax)); + } else { + return hn::Exp(d, hn::Sub(value, *pmax)); + } + }); // Normalize to a single probability. The exact sum seems like it should not // make a huge difference. It halves the standard deviation of the sum of the // normalized probabilities from 1E-7 to 5E-8, but actually also changes the // generated text after a few hundred tokens. 
- const float sum_exp = Sum(d, x, num); - const float prob = x[argmax.token] / sum_exp; + const float sum_exp = Sum(d, logits.data(), logits.size()); + const float prob = logits[argmax.token] / sum_exp; return TokenAndProb{.token = argmax.token, .prob = prob}; } -static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - const size_t size, hwy::Profiler& p, - const size_t worker) { +static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits, + hwy::Profiler& p, const size_t worker) { static const auto zone = p.AddZone("Ops.LogitsSoftCap"); PROFILER_ZONE3(p, worker, zone); @@ -763,18 +763,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, const VF* HWY_RESTRICT pcap = &vcap; const VF* HWY_RESTRICT pinv_cap = &vinv_cap; - DecompressAndCompressInplace( - DF(), x, size, [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF { - return hn::Mul(*pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap))); - }); + DecompressAndCompressInplace(DF(), logits.data(), logits.size(), + [pcap, pinv_cap](DF d, VF v) HWY_ATTR -> VF { + return hn::Mul( + *pcap, hn::Tanh(d, hn::Mul(v, *pinv_cap))); + }); } // Calls LogitsSoftCap if cap != 0.0f. 
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( - const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, - const size_t worker) { + const float cap, Logits logits, hwy::Profiler& p, const size_t worker) { if (cap != 0.0f) { - LogitsSoftCap(cap, x, size, p, worker); + LogitsSoftCap(cap, logits, p, worker); } } @@ -785,20 +785,18 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, [&](uint64_t task, size_t worker) { if (non_eos.Get(task)) { - LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, - worker); + LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker); } }); } -static HWY_NOINLINE HWY_MAYBE_UNUSED size_t -SampleArgmax(const float* probabilities, size_t vocab_size) { +static HWY_NOINLINE HWY_MAYBE_UNUSED size_t SampleArgmax(Logits logits) { size_t max_index = 0; - float max_prob = probabilities[0]; - for (size_t i = 1; i < vocab_size; ++i) { - if (probabilities[i] > max_prob) { + float max_prob = logits[0]; + for (size_t i = 1; i < logits.size(); ++i) { + if (logits[i] > max_prob) { max_index = i; - max_prob = probabilities[i]; + max_prob = logits[i]; } } return max_index; @@ -828,16 +826,15 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution create_distribution( template HWY_NOINLINE HWY_MAYBE_UNUSED std::vector TopK( - const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k, - TAcceptToken& accept_token) { + Logits logits, size_t k, TAcceptToken& accept_token) { HWY_ASSERT(k != 0); - HWY_ASSERT(k <= vocab_size); + HWY_ASSERT(k <= logits.size()); std::vector packed_token_probs; - for (int32_t i = 0; i < static_cast(vocab_size); ++i) { - if (accept_token && !accept_token(i, probabilities[i])) { + for (int32_t i = 0; i < static_cast(logits.size()); ++i) { + if (accept_token && !accept_token(i, logits[i])) { continue; } - packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i])); + 
packed_token_probs.push_back(PackTokenAndProb(i, logits[i])); } hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k, @@ -853,11 +850,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector TopK( } template -HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( - const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size, - std::mt19937& gen, float temperature, TAcceptToken& accept_token) { - std::vector token_probs = - TopK(probabilities, vocab_size, k, accept_token); +HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k, + RngStream& gen, float temperature, + TAcceptToken& accept_token) { + std::vector token_probs = TopK(logits, k, accept_token); std::vector topk_indices(k); std::vector topk_probs(k); for (size_t i = 0; i < k; ++i) { @@ -869,14 +865,12 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( template HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( - const float* HWY_RESTRICT logits, size_t k, size_t vocab_size, - std::mt19937& gen, float temperature, TAcceptToken& accept_token, - hwy::Profiler& p, size_t worker) { + Logits logits, size_t k, RngStream& gen, float temperature, + TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) { // Softmax and sample top-K is equivalent to taking the top-K logits and // sampling from the softmax of the top-K logits. The latter is faster as it // avoids computing the softmax of all logits. 
- std::vector token_logits = - TopK(logits, vocab_size, k, accept_token); + std::vector token_logits = TopK(logits, k, accept_token); std::vector topk_indices(k); std::vector topk_logits(k); for (size_t i = 0; i < token_logits.size(); ++i) { @@ -884,8 +878,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( topk_logits[i] = token_logits[i].prob; } - size_t mask = token_logits.size(); - Softmax(topk_logits.data(), mask, p, worker, temperature); + const size_t mask = token_logits.size(); + Softmax(Logits(topk_logits.data(), mask), p, worker, temperature); auto distribution = std::discrete_distribution( std::begin(topk_logits), std::begin(topk_logits) + mask); int topk_sampled_index = distribution(gen); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 7e63482..213fdd0 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -57,6 +57,12 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +static RngStream MakeRng() { + static AesCtrEngine engine(/*deterministic=*/true); + static uint64_t stream = 0; + return RngStream(engine, ++stream); +} + template struct ForeachCountAndMisalign { template @@ -304,7 +310,7 @@ class TestSoftmax { } SimpleSoftmax(e, count); - Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0); + Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0); T sum = 0.0f; for (size_t i = 0; i < count; ++i) { @@ -438,10 +444,9 @@ void TestRopeAndMulBy() { const size_t dim_qkv = config.layer_configs[0].qkv_dim; MatStorageT x("x", dim_qkv, ctx.allocator); - std::mt19937 gen; - gen.seed(0x12345678); + RngStream rng = MakeRng(); std::normal_distribution r{0.0, 5.0}; - auto random_float = [&r, &gen] { return r(gen); }; + auto random_float = [&r, &rng] { return r(rng); }; for (size_t i = 0; i < dim_qkv; ++i) { x.Row(0)[i] = random_float(); @@ -704,38 +709,34 @@ void TestSampleTopK() { hwy::Profiler& p = hwy::Profiler::Get(); const size_t worker = 0; const size_t kSize = 52; - std::vector logits(kSize); + std::vector 
logits_vec(kSize); + Logits logits(logits_vec.data(), kSize); // Create a vector going from -100 to -100+51=49 and take Softmax. std::iota(logits.begin(), logits.end(), -100.0f); - Softmax(logits.data(), kSize, p, worker); - std::mt19937 gen; - gen.seed(0x12345678); + Softmax(logits, p, worker); + RngStream rng = MakeRng(); float temperature = 1.0f; // SampleTopK<1> should return the argmax. std::function accept_token; - int sample = - SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token); + int sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token); EXPECT_EQ(sample, 51); // Last is largest. // Only accept even tokens, expect the last (largest) even index. accept_token = [](int i, float) { return i % 2 == 0; }; - sample = - SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token); + sample = SampleTopK(logits, /*k=*/1, rng, temperature, accept_token); EXPECT_EQ(sample, 50); // Last even index. // Reset the logits to a positive, increasing sequence and take Softmax. std::iota(logits.begin(), logits.end(), 1.0f); - Softmax(logits.data(), kSize, p, worker); + Softmax(logits, p, worker); // Sample from the top 3, expect one of the top 3 even indices. for (int i = 0; i < 100; ++i) { - sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature, - accept_token); + sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token); EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46); } // Now set the temperature to 0.0f, which should always return the argmax, // even for k=3. 
temperature = 0.0f; for (int i = 0; i < 100; ++i) { - sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature, - accept_token); + sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token); EXPECT_EQ(sample, 50); } } diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc index 2c798b9..449ee00 100644 --- a/paligemma/paligemma_helper.cc +++ b/paligemma/paligemma_helper.cc @@ -27,42 +27,38 @@ void PaliGemmaHelper::InitVit(const std::string& path) { HWY_ASSERT(image.ReadPPM(path)); const size_t image_size = config.vit_config.image_size; image.Resize(image_size, image_size); - RuntimeConfig runtime_config = {.gen = &env_->MutableGen(), - .verbosity = 0}; + RuntimeConfig runtime_config = {.verbosity = 0}; gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(), image, *image_tokens_, env_->MutableEnv()); } std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const { const Gemma& model = *(env_->GetGemma()); - env_->MutableGen().seed(0x12345678); - std::string response; - auto stream_token = [&](int token, float) { - std::string token_text; - HWY_ASSERT( - model.Tokenizer().Decode(std::vector{token}, &token_text)); - response += token_text; - return true; - }; + std::string response; + auto stream_token = [&](int token, float) { + std::string token_text; + HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); + response += token_text; + return true; + }; - std::string mutable_prompt = prompt_text; - std::vector tokens = env_->WrapAndTokenize(mutable_prompt); - tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); + std::string mutable_prompt = prompt_text; + std::vector tokens = env_->WrapAndTokenize(mutable_prompt); + tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); - RuntimeConfig runtime_config = {.max_generated_tokens = 512, - // PrefixLM sees/attends to all tokens. 
- .prefill_tbatch_size = tokens.size(), - .gen = &env_->MutableGen(), - .verbosity = 0, - .stream_token = stream_token, - .image_tokens = image_tokens_.get()}; + RuntimeConfig runtime_config = {.max_generated_tokens = 512, + // PrefixLM sees/attends to all tokens. + .prefill_tbatch_size = tokens.size(), + .verbosity = 0, + .stream_token = stream_token, + .image_tokens = image_tokens_.get()}; - const size_t prefix_end = tokens.size(); - TimingInfo timing_info = {.verbosity = 0}; - model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end, - env_->MutableKVCache(), env_->MutableEnv(), timing_info); - return response; + const size_t prefix_end = tokens.size(); + TimingInfo timing_info = {.verbosity = 0}; + model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end, + env_->MutableKVCache(), env_->MutableEnv(), timing_info); + return response; } } // namespace gcpp diff --git a/python/gemma_py.cc b/python/gemma_py.cc index 9af07b3..2e39f68 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -53,9 +53,8 @@ class GemmaModel { // Generates a single example, given a prompt and a callback to stream the // generated tokens. void GenerateEx(std::string prompt, gcpp::StreamFunc stream, - size_t max_generated_tokens, float temperature, float seed, - gcpp::AcceptFunc accept, bool skip_prompt) { - env_.MutableGen().seed(seed); + size_t max_generated_tokens, float temperature, + float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) { std::vector prompt_tokens = env_.WrapAndTokenize(prompt); gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; @@ -77,7 +76,7 @@ class GemmaModel { // Generates a single example, given a prompt, and returns the result. 
std::string Generate(std::string prompt, size_t max_generated_tokens, - float temperature, float seed, + float temperature, float /*seed*/, const std::vector& accept, const std::vector& end) { std::set end_token_set{}; @@ -124,7 +123,6 @@ class GemmaModel { } }; - env_.MutableGen().seed(seed); gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; config.temperature = temperature; @@ -144,14 +142,13 @@ class GemmaModel { // results. std::vector GenerateBatch(const std::vector& inputs, size_t max_generated_tokens, - float temperature, float seed, + float temperature, float /*seed*/, size_t top_k) { gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; config.temperature = temperature; config.top_k = top_k; config.verbosity = 0; - env_.MutableGen().seed(seed); std::vector outputs = env_.BatchQueryModel(inputs); std::vector result; @@ -187,8 +184,7 @@ class GemmaModel { "image_tokens", gcpp::Extents2D(config.vit_config.seq_len, config.model_dim), env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd)); - gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(), - .verbosity = 0}; + gcpp::RuntimeConfig runtime_config = {.verbosity = 0}; gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(), c_image, *image_tokens_, env_.MutableEnv()); } @@ -197,10 +193,9 @@ class GemmaModel { // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string. 
std::pair> GenerateWithImage( std::string prompt, size_t max_generated_tokens, float temperature, - float seed, gcpp::AcceptFunc accept, std::vector prompt_tokens) { + float /*seed*/, gcpp::AcceptFunc accept, std::vector prompt_tokens) { if (!image_tokens_) throw std::invalid_argument("No image set."); const gcpp::Gemma& model = *env_.GetGemma(); - env_.MutableGen().seed(seed); gcpp::RuntimeConfig& config = env_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; config.temperature = temperature; @@ -273,6 +268,7 @@ PYBIND11_MODULE(gemma, mod) { }), py::arg("tokenizer_path"), py::arg("weights_path"), py::arg("max_threads") = 0) + // seed arguments are ignored. .def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"), py::arg("stream"), py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9, py::arg("seed") = 123456789, diff --git a/util/basics.cc b/util/basics.cc index 4261510..d9fbc27 100644 --- a/util/basics.cc +++ b/util/basics.cc @@ -24,7 +24,7 @@ namespace gcpp { -RNG::RNG(bool deterministic) { +AesCtrEngine::AesCtrEngine(bool deterministic) { // Pi-based nothing up my sleeve numbers from Randen. key_[0] = 0x243F6A8885A308D3ull; key_[1] = 0x13198A2E03707344ull; @@ -54,9 +54,10 @@ static V Load(const uint64_t* ptr) { return hn::Load(D(), reinterpret_cast(ptr)); } -RNG::result_type RNG::operator()() { - V state = Load(counter_); - counter_[0]++; +uint64_t AesCtrEngine::operator()(uint64_t stream, uint64_t counter) const { + const hn::Repartition d64; + + V state = hn::BitCast(D(), hn::Dup128VecFromValues(d64, counter, stream)); state = hn::Xor(state, Load(key_)); // initial whitening static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t)); @@ -68,7 +69,6 @@ RNG::result_type RNG::operator()() { state = hn::AESRound(state, Load(key_ + 10)); // Return lower 64 bits of the u8 vector. 
- const hn::Repartition d64; return hn::GetLane(hn::BitCast(d64, state)); } diff --git a/util/basics.h b/util/basics.h index 2429c72..7b1c7d3 100644 --- a/util/basics.h +++ b/util/basics.h @@ -20,7 +20,7 @@ #include #include -#include "hwy/aligned_allocator.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // IWYU pragma: end_exports @@ -120,39 +120,60 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end, return IndexRange(begin, HWY_MIN(begin + max_size, end)); } +using Logits = hwy::Span; // size() is vocab_size. + // Non-cryptographic 64-bit pseudo-random number generator. Supports random or -// deterministic seeding. Conforms to C++ `UniformRandomBitGenerator`. +// deterministic seeding. // // Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This // is useful for parallel sampling. Each thread can generate the stream for a // particular task, without caring about prior/subsequent generations. -class alignas(16) RNG { +class alignas(16) AesCtrEngine { // "Large-scale randomness study of security margins for 100+ cryptographic // functions": at least four. // "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant. static constexpr size_t kRounds = 5; public: - explicit RNG(bool deterministic); + // If `deterministic` is true, uses a fixed seed; otherwise, attempts to + // grab entropy from the OS. + explicit AesCtrEngine(bool deterministic); - void SetStream(uint64_t stream) { - counter_[1] = stream; - counter_[0] = 0; - } + // Pure and thread safe; typically called via `RngStream`, which increments + // `counter`. Throughput is about 100M/s on 3 GHz Skylake. It could be + // increased 4x via unrolling by the AES latency (4-7 cycles), but because + // users generally call once at a time, this requires buffering, which is not + // worth the complexity in this application. 
+ uint64_t operator()(uint64_t stream, uint64_t counter) const; + + private: + uint64_t key_[2 * (1 + kRounds)]; +}; + +// Flyweight per-thread adapter that maintains the counter. Conforms to C++ +// `UniformRandomBitGenerator`. +class RngStream { + public: + RngStream() = default; // Allow C arrays with subsequent initialization. + + // Binds to an engine, which holds the seed and must outlive this object. + // Sets the stream; any other `RngStream` with the same `counter_rng` and + // `stream` will return the same sequence. This is typically the task ID, so + // that threads can independently generate values for each task. + RngStream(const AesCtrEngine& counter_rng, uint64_t stream) + : engine_(&counter_rng), stream_(stream), counter_(0) {} using result_type = uint64_t; static constexpr result_type min() { return 0; } static constexpr result_type max() { return ~result_type{0}; } - - // About 100M/s on 3 GHz Skylake. Throughput could be increased 4x via - // unrolling by the AES latency (4-7 cycles). `std::discrete_distribution` - // makes individual calls to the generator, which would require buffering, - // which is not worth the complexity. - result_type operator()(); + result_type operator()() { return (*engine_)(stream_, counter_++); } private: - uint64_t counter_[2] = {}; - uint64_t key_[2 * (1 + kRounds)]; + const AesCtrEngine* engine_ = nullptr; + uint64_t stream_ = 0; // immutable after ctor + uint64_t counter_ = 0; + // Prevent false sharing if used by multiple threads. 
+ HWY_MAYBE_UNUSED uint8_t padding_[HWY_ALIGNMENT - 16 - sizeof(engine_)]; }; } // namespace gcpp diff --git a/util/basics_test.cc b/util/basics_test.cc index 169d051..a1d805b 100644 --- a/util/basics_test.cc +++ b/util/basics_test.cc @@ -25,9 +25,11 @@ namespace gcpp { namespace { -TEST(BasicsTest, IsDeterministic) { - RNG rng1(/*deterministic=*/true); - RNG rng2(/*deterministic=*/true); +TEST(BasicsTest, EngineIsDeterministic) { + const AesCtrEngine engine1(/*deterministic=*/true); + const AesCtrEngine engine2(/*deterministic=*/true); + RngStream rng1(engine1, 0); + RngStream rng2(engine2, 0); // Remember for later testing after resetting the stream. const uint64_t r0 = rng1(); const uint64_t r1 = rng1(); @@ -42,15 +44,17 @@ TEST(BasicsTest, IsDeterministic) { HWY_ASSERT(rng1() == rng2()); } - // Reset counter, ensure it matches the default-constructed RNG. - rng1.SetStream(0); + // Reset counter, ensure it matches the prior sequence. + rng1 = RngStream(engine1, 0); HWY_ASSERT(r0 == rng1()); HWY_ASSERT(r1 == rng1()); } -TEST(BasicsTest, IsSeeded) { - RNG rng1(/*deterministic=*/true); - RNG rng2(/*deterministic=*/false); +TEST(BasicsTest, EngineIsSeeded) { + AesCtrEngine engine1(/*deterministic=*/true); + AesCtrEngine engine2(/*deterministic=*/false); + RngStream rng1(engine1, 0); + RngStream rng2(engine2, 0); // It would be very unlucky to have even one 64-bit value match, and two are // extremely unlikely. const uint64_t a0 = rng1(); @@ -60,9 +64,27 @@ TEST(BasicsTest, IsSeeded) { HWY_ASSERT(a0 != b0 || a1 != b1); } +TEST(BasicsTest, StreamsDiffer) { + AesCtrEngine engine(/*deterministic=*/true); + // Compare random streams for more coverage than just the first N streams. + RngStream rng_for_stream(engine, 0); + for (size_t i = 0; i < 1000; ++i) { + RngStream rng1(engine, rng_for_stream()); + RngStream rng2(engine, rng_for_stream()); + // It would be very unlucky to have even one 64-bit value match, and two are + // extremely unlikely. 
+ const uint64_t a0 = rng1(); + const uint64_t a1 = rng1(); + const uint64_t b0 = rng2(); + const uint64_t b1 = rng2(); + HWY_ASSERT(a0 != b0 || a1 != b1); + } +} + // If not close to 50% 1-bits, the RNG is quite broken. TEST(BasicsTest, BitDistribution) { - RNG rng(/*deterministic=*/true); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); constexpr size_t kU64 = 2 * 1000 * 1000; const hwy::Timestamp t0; uint64_t one_bits = 0; @@ -78,7 +100,8 @@ TEST(BasicsTest, BitDistribution) { } TEST(BasicsTest, ChiSquared) { - RNG rng(/*deterministic=*/true); + AesCtrEngine engine(/*deterministic=*/true); + RngStream rng(engine, 0); constexpr size_t kU64 = 1 * 1000 * 1000; // Test each byte separately. diff --git a/util/mat.h b/util/mat.h index 9d838e2..c084e81 100644 --- a/util/mat.h +++ b/util/mat.h @@ -301,6 +301,13 @@ class MatPtrT : public MatPtr { return HWY_RCAST_ALIGNED(const T*, RowBytes(row)); } + hwy::Span RowSpan(size_t row) { + return hwy::Span(Row(row), Cols()); + } + hwy::Span RowSpan(size_t row) const { + return hwy::Span(Row(row), Cols()); + } + PackedSpan PaddedSpan() const { return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride()); } From 2b4c16e2438298f5d7a5ac3d59844fe6e7504981 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 5 Sep 2025 02:34:54 -0700 Subject: [PATCH 29/65] Remove Griffin support Also add IsObsolete helper PiperOrigin-RevId: 803376921 --- BUILD.bazel | 2 - CMakeLists.txt | 2 - README.md | 21 +- compression/python/compression_test.py | 4 +- compression/python/pytree/PYTREE_README.md | 8 - .../pytree/build_model_file_for_cpp_binary.py | 275 ---------- compression/python/pytree/cpp_load_log.txt | 380 ------------- .../python/pytree/ml_model_transforms.py | 371 ------------- .../python/pytree/ml_model_transforms_test.py | 92 ---- .../python/pytree/pytree_transforms.py | 508 ------------------ .../python/pytree/pytree_transforms_test.py | 168 ------ compression/python/pytree/requirements.txt | 4 - 
evals/gemma_test.cc | 3 - gemma/activations.h | 34 +- gemma/attention.cc | 6 +- gemma/configs.cc | 83 --- gemma/configs.h | 43 +- gemma/gemma.cc | 16 - gemma/griffin.cc | 192 ------- gemma/griffin.h | 47 -- gemma/kv_cache.cc | 38 +- gemma/kv_cache.h | 11 +- gemma/model_store.cc | 5 +- gemma/tensor_info.cc | 120 ----- gemma/tensor_info.h | 2 - gemma/weights.cc | 2 +- gemma/weights.h | 47 +- python/configs.cc | 9 - 28 files changed, 38 insertions(+), 2455 deletions(-) delete mode 100644 compression/python/pytree/PYTREE_README.md delete mode 100644 compression/python/pytree/build_model_file_for_cpp_binary.py delete mode 100644 compression/python/pytree/cpp_load_log.txt delete mode 100644 compression/python/pytree/ml_model_transforms.py delete mode 100644 compression/python/pytree/ml_model_transforms_test.py delete mode 100644 compression/python/pytree/pytree_transforms.py delete mode 100644 compression/python/pytree/pytree_transforms_test.py delete mode 100644 compression/python/pytree/requirements.txt delete mode 100644 gemma/griffin.cc delete mode 100644 gemma/griffin.h diff --git a/BUILD.bazel b/BUILD.bazel index cbfb342..52c2df3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -507,14 +507,12 @@ cc_library( srcs = [ "gemma/attention.cc", "gemma/gemma.cc", - "gemma/griffin.cc", "gemma/vit.cc", ], hdrs = [ "gemma/activations.h", "gemma/attention.h", "gemma/gemma.h", - "gemma/griffin.h", "gemma/vit.h", ], exec_properties = { diff --git a/CMakeLists.txt b/CMakeLists.txt index d3a66fd..4bc0e80 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,8 +83,6 @@ set(SOURCES gemma/gemma-inl.h gemma/gemma.cc gemma/gemma.h - gemma/griffin.cc - gemma/griffin.h gemma/kv_cache.cc gemma/kv_cache.h gemma/model_store.cc diff --git a/README.md b/README.md index 067051d..2963bf6 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ Guidelines](https://opensource.google.com/conduct/). - LLM - - CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2. 
+ - CPU-only inference for: Gemma 2-3, PaliGemma 2. - Sampling with TopK and temperature. - Backward pass (VJP) and Adam optimizer for Gemma research. @@ -222,23 +222,6 @@ Example invocation for the following configuration: --tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs ``` -### RecurrentGemma - -This repository includes a version of Gemma based on Griffin -([paper](https://arxiv.org/abs/2402.19427), -[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture -includes both recurrent layers and local attention, thus it is more efficient -for longer sequences and has a smaller memory footprint than standard Gemma. We -here provide a C++ implementation of this model based on the paper. - -To use the recurrent version of Gemma included in this repository, build the -gemma binary as noted above in Step 3. Download the compressed weights and -tokenizer from the RecurrentGemma -[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in -Step 1, and run the binary as follows: - -`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs` - ### PaliGemma Vision-Language Model This repository includes a version of the PaliGemma 2 VLM @@ -535,7 +518,7 @@ gemma.cpp was started in fall 2023 by Griffin support was implemented in April 2024 thanks to contributions by Andrey Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas -Fischbacher and Zoltan Szabadka. +Fischbacher and Zoltan Szabadka. It was removed in 2025-09. Gemma-2 support was implemented in June/July 2024 with the help of several people. 
diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index e8244ed..957f0ec 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -91,7 +91,7 @@ class CompressionTest(absltest.TestCase): ) config = configs.ModelConfig( - configs.Model.GEMMA_TINY, + configs.Model.GEMMA2_2B, configs.Type.kSFP, configs.PromptWrapping.GEMMA_IT, ) @@ -101,7 +101,7 @@ class CompressionTest(absltest.TestCase): print("Ignore next two warnings; test does not enable model deduction.") reader = compression.SbsReader(temp_file.full_path) - self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY) + self.assertEqual(reader.config.model, configs.Model.GEMMA2_2B) self.assertEqual(reader.config.weight, configs.Type.kSFP) mat = reader.find_mat("tensor0") diff --git a/compression/python/pytree/PYTREE_README.md b/compression/python/pytree/PYTREE_README.md deleted file mode 100644 index 4a04079..0000000 --- a/compression/python/pytree/PYTREE_README.md +++ /dev/null @@ -1,8 +0,0 @@ - -# General Remarks about the "PyTree" Abstraction - -The pytree wrangling code in this project does not use any of the existing -"pytree" modules. The deeper reason here is that our approach is based on an -analysis of the notion that emphasizes deeper underlying principles. This is -being discussed internally at the time of this writing. - diff --git a/compression/python/pytree/build_model_file_for_cpp_binary.py b/compression/python/pytree/build_model_file_for_cpp_binary.py deleted file mode 100644 index d039639..0000000 --- a/compression/python/pytree/build_model_file_for_cpp_binary.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Ad-hoc glue code for building the griffin model-file for the C++ binary. - -Usage: - -python3 -m venv $HOME/clients/griffin-venv - -. 
$HOME/clients/griffin-venv/bin/activate - -python3 -m pip install -r requirements.txt - -time python3 build_model_file_for_cpp_binary.py \ - $HOME/GRIFFIN/model_data \ - cpp_load_log.txt /tmp/G2B.data - -real 3m5.821s -user 2m9.205s -sys 2m46.720s - -./compress_weights --weights /tmp/G2B.data --model gr2b-it \ - --compressed_weights /tmp/G2B.compressed -./gemma --tokenizer tokenizer.spm --weights /tmp/G2B.compressed \ - --model gr2b-it - -Weights for the recurrent-gemma model that can be converted with this script -can be found at: - - https://www.kaggle.com/models/google/recurrentgemma/flax/2b-it -""" - -import pprint -import re -import sys - -from typing import Any, Mapping - -import numpy - -import orbax.checkpoint - -import ml_model_transforms -import pytree_transforms - - -def _fn_identity(x): return x - - -def _fn_transpose(x): return x.T - - -def _fn_transpose_all_heads(x): return x.transpose(0, 2, 1) - - -def _fn_scaled_softplus(a): - return -8 * numpy.logaddexp(a, 0) - - -def _fn_attention_moveaxis(a): - return a.reshape(10, 256, 2560).transpose(0, 2, 1) - - -def _aspec(pieces=(), transforms=()): - """Short-hand array-save-specification. - - Args: - pieces: Sequence of key-sequences identifying an array. - transforms: Sequence of transformations, indexed in - parallel to `pieces`, to apply to data arrays prior to saving. - Will be padded with identity-transformations to the length of `pieces`. - - Returns: - Specification as for use in _LAYETR_NAME_MAPPING. - """ - # `zip` trims to shortest sequence, so this amounts to using - # default-transforms. - # tuple() since we need a Sequence here, not a stateful-iterator zip_object. 
- return tuple(zip(pieces, list(transforms) + [_fn_identity] * len(pieces))) - - -_LAYER_NAME_MAPPING = pytree_transforms.deep_freeze({ - # Recurrent Layer - 'griffin_linear_x_w': _aspec( - [('recurrent_block', 'linear_x', 'kernel')], - [_fn_transpose]), - 'griffin_linear_x_biases': _aspec( - [('recurrent_block', 'linear_x', 'bias')]), - 'griffin_linear_y_w': _aspec( - [('recurrent_block', 'linear_y', 'kernel')], - [_fn_transpose]), - 'griffin_linear_y_biases': _aspec( - [('recurrent_block', 'linear_y', 'bias')]), - 'griffin_linear_out_w': _aspec( - [('recurrent_block', 'linear_out', 'kernel')], - [_fn_transpose]), - 'griffin_linear_out_biases': _aspec( - [('recurrent_block' ,'linear_out', 'bias')]), - 'griffin_conv_w': _aspec( - [('recurrent_block', 'conv_1d', 'w')]), - 'griffin_conv_biases': _aspec( - [('recurrent_block', 'conv_1d', 'b')]), - 'griffin_gate_w': _aspec( - [('recurrent_block', 'rg_lru', 'input_gate', 'w'), - ('recurrent_block', 'rg_lru', 'a_gate', 'w')], - [_fn_transpose_all_heads, _fn_transpose_all_heads]), - 'griffin_gate_biases': _aspec( - [('recurrent_block', 'rg_lru', 'input_gate', 'b'), - ('recurrent_block', 'rg_lru', 'a_gate', 'b')]), - 'griffin_a': _aspec( - [('recurrent_block', 'rg_lru', 'a_param')], - [_fn_scaled_softplus]), - # Attention Layer - 'qkv_einsum_w': _aspec( - [('attention_block', 'proj_q', 'kernel'), - ('attention_block', 'proj_k', 'kernel'), - ('attention_block', 'proj_v', 'kernel'), - ], - [_fn_transpose, _fn_transpose, _fn_transpose]), - 'attn_vec_einsum_w': _aspec( - [('attention_block', 'proj_final', 'kernel')], - [_fn_attention_moveaxis]), - 'attention_output_biases': _aspec( - [('attention_block', 'proj_final', 'bias')]), - # Common - 'pre_attention_norm_scale': _aspec( - [('temporal_pre_norm', 'scale')]), - 'pre_ffw_norm_scale': _aspec( - [('channel_pre_norm', 'scale')]), - 'gating_einsum_w': _aspec( - [('mlp_block', 'ffw_up', 'w')], - [_fn_transpose_all_heads]), - 'ffw_gating_biases': _aspec( - [('mlp_block', 
'ffw_up', 'b')]), - 'linear_w': _aspec( - [('mlp_block', 'ffw_down', 'kernel')], - [_fn_transpose]), - 'ffw_output_biases': _aspec( - [('mlp_block', 'ffw_down', 'bias')]), - # Other - 'embedder_input_embedding': _aspec( - [('embedder', 'input_embedding')]), - 'final_norm_scale': _aspec( - [('final_norm', 'scale')]), -}) - - -def process_param_line(line : str) -> tuple[None | str, int, str]: - """Processes a "loading parameters" log-line from the griffin binary.""" - # This is slightly more permissive than strictly needed, to also handle - # some earlier form of the output. - matched = re.match( - r'(?a)Loading Parameters:? \(' - r'(?:layer=(?P\d+), )?' - r'size (?P\d+)\):? ' - r'(?P\S+)', - line) - if not matched: - return None - layer = matched['layer'] - wanted_size = int(matched['size']) - cpp_tag = matched['tag'] - return matched['layer'], int(matched['size']), matched['tag'] - - -def collect_pytree_keys(param_lines): - """Collects all the pytree keys and transforms for model-serialization.""" - pytree_keys = [] - array_transforms = [] - unsatisfied = [] - for maybe_spec in map(process_param_line, param_lines): - if not maybe_spec: continue # Skip non-parameter lines. - layer, wanted_size, cpp_tag = maybe_spec - pytree_key_tails_and_transforms = _LAYER_NAME_MAPPING.get(cpp_tag, ()) - if not pytree_key_tails_and_transforms: - unsatisfied.append((layer, cpp_tag)) - else: - for key_tail, array_transform in pytree_key_tails_and_transforms: - pytree_keys.append( - key_tail if layer is None - else (f'blocks.{layer}',) + key_tail) - array_transforms.append(array_transform) - return pytree_keys, array_transforms, unsatisfied - - -class UnsatisfiedArrayLoadsError(ValueError): - """Some array-loads could not be satisfied.""" - - -def flatten_model_for_cpp_binary(tree, - cpp_expectations_logfile_path : str, - out_path : str, - unsatisfied_ok : bool = False - ): - """Produces a model-parameters file readable by the C++ binary. 
- - Args: - tree: The pytree with model-parameters. - cpp_expectations_logfile_path: - Path to a logfile produced by the C++ binary that shows - the expected array-order. - out_path: Path to the model-weights file to be written. - unsatisfied_ok: If true, we ignore the presence of unsatisfied - array-loads and write a model-parameters file that skips these pieces. - This will lead to an unusable model-parameters file which however - still might be useful for other analysis. - - Returns: - Tuple `(unknown_keys, missing_keys)`, where `unknown_keys` - is a sequence of `(layer_or_None, name)` descriptions of the keys - in the C++ log that could not be satisfied, and `missing_keys` - is a sequence of linearized pytree key-sequences for keys - not found in the checkpoint. - - Raises: - UnsatisfiedArrayLoadsError: If some of the expected arrays - could not be included in the output and `unsatisfied_ok` - is false. - """ - with open(cpp_expectations_logfile_path, 'rt') as h_log: - pytree_keys, array_transforms, unknown_keys = collect_pytree_keys( - list(h_log)) - rank_by_pytree_key = {k: n for n, k in enumerate(pytree_keys)} - array_transform_by_pytree_key = dict(zip(pytree_keys, array_transforms)) - # - model_contents = ml_model_transforms.model_contents(tree) - missing_keys = set(pytree_keys) - model_contents.keys() - if (unknown_keys or missing_keys) and not unsatisfied_ok: - raise ValueError( - f'Unsatisfied loads: unknown_keys: {unknown_keys!r}, ' - f'missing keys: {sorted(missing_keys)!r}') - ml_model_transforms.model_save( - tree, - filepath_stem=out_path, - data_suffix='', - manifest_suffix=None, - array_transform_by_pytree_key=array_transform_by_pytree_key, - key=rank_by_pytree_key.get, - report=lambda line: print(line, file=sys.stderr), - byte_align=1) - return tuple(unknown_keys), tuple(sorted(missing_keys)) - - -def main(args): - """Creates the model-file. - - Args: - sys.argv[] parameters from command line sans the leading one. 
- - Returns: - The pytree with all the de-serialized variables, such as for convenient - `python3 -i` inspection. - """ - try: - model_dir, cpp_load_log, out_path = args - except Exception: - sys.exit(f'Usage: {__file__} [model_dir] [cpp_load_log] [output_filename]') - pattern = ("recurrent", "recurrent", "attention") - orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() - variables = orbax_checkpointer.restore(model_dir) - if sorted(variables) == ['params']: - print('Warning: Using `variables["params"]` as tree-root.', file=sys.stderr) - variables_to_use = variables['params'] - else: - variables_to_use = variables - unknown, missing = flatten_model_for_cpp_binary(variables_to_use, - cpp_load_log, - out_path, - unsatisfied_ok=True) - print('Model file saved.\n' - f'# unknown:\n{pprint.pformat(unknown)}\n' - f'# missing:\n{pprint.pformat(missing)}') - return variables - - -if __name__ == '__main__': - # Return value assignment is for `python3 -i ...` inspection. - pytree = main(sys.argv[1:]) diff --git a/compression/python/pytree/cpp_load_log.txt b/compression/python/pytree/cpp_load_log.txt deleted file mode 100644 index cc33394..0000000 --- a/compression/python/pytree/cpp_load_log.txt +++ /dev/null @@ -1,380 +0,0 @@ -Loading Parameters (size 2622750720): embedder_input_embedding -Loading Parameters (size 10240): final_norm_scale -Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=0, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=0, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=0, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=0, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=0, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=0, size 40960) griffin_conv_w -Loading Parameters: (layer=0, size 10240) griffin_conv_biases -Loading Parameters: (layer=0, size 5242880) griffin_gate_w -Loading Parameters: (layer=0, size 20480) 
griffin_gate_biases -Loading Parameters: (layer=0, size 10240) griffin_a -Loading Parameters: (layer=0, size 157286400) gating_einsum_w -Loading Parameters: (layer=0, size 78643200) linear_w -Loading Parameters: (layer=0, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=0, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=0, size 61440) ffw_gating_biases -Loading Parameters: (layer=0, size 10240) ffw_output_biases -Loading Parameters: (layer=1, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=1, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=1, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=1, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=1, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=1, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=1, size 40960) griffin_conv_w -Loading Parameters: (layer=1, size 10240) griffin_conv_biases -Loading Parameters: (layer=1, size 5242880) griffin_gate_w -Loading Parameters: (layer=1, size 20480) griffin_gate_biases -Loading Parameters: (layer=1, size 10240) griffin_a -Loading Parameters: (layer=1, size 157286400) gating_einsum_w -Loading Parameters: (layer=1, size 78643200) linear_w -Loading Parameters: (layer=1, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=1, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=1, size 61440) ffw_gating_biases -Loading Parameters: (layer=1, size 10240) ffw_output_biases -Loading Parameters: (layer=2, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=2, size 78643200) qkv_einsum_w -Loading Parameters: (layer=2, size 157286400) gating_einsum_w -Loading Parameters: (layer=2, size 78643200) linear_w -Loading Parameters: (layer=2, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=2, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=2, size 61440) ffw_gating_biases -Loading Parameters: (layer=2, size 10240) 
ffw_output_biases -Loading Parameters: (layer=2, size 10240) attention_output_biases -Loading Parameters: (layer=3, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=3, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=3, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=3, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=3, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=3, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=3, size 40960) griffin_conv_w -Loading Parameters: (layer=3, size 10240) griffin_conv_biases -Loading Parameters: (layer=3, size 5242880) griffin_gate_w -Loading Parameters: (layer=3, size 20480) griffin_gate_biases -Loading Parameters: (layer=3, size 10240) griffin_a -Loading Parameters: (layer=3, size 157286400) gating_einsum_w -Loading Parameters: (layer=3, size 78643200) linear_w -Loading Parameters: (layer=3, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=3, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=3, size 61440) ffw_gating_biases -Loading Parameters: (layer=3, size 10240) ffw_output_biases -Loading Parameters: (layer=4, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=4, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=4, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=4, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=4, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=4, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=4, size 40960) griffin_conv_w -Loading Parameters: (layer=4, size 10240) griffin_conv_biases -Loading Parameters: (layer=4, size 5242880) griffin_gate_w -Loading Parameters: (layer=4, size 20480) griffin_gate_biases -Loading Parameters: (layer=4, size 10240) griffin_a -Loading Parameters: (layer=4, size 157286400) gating_einsum_w -Loading Parameters: (layer=4, size 78643200) linear_w -Loading Parameters: (layer=4, 
size 10240) pre_attention_norm_scale -Loading Parameters: (layer=4, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=4, size 61440) ffw_gating_biases -Loading Parameters: (layer=4, size 10240) ffw_output_biases -Loading Parameters: (layer=5, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=5, size 78643200) qkv_einsum_w -Loading Parameters: (layer=5, size 157286400) gating_einsum_w -Loading Parameters: (layer=5, size 78643200) linear_w -Loading Parameters: (layer=5, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=5, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=5, size 61440) ffw_gating_biases -Loading Parameters: (layer=5, size 10240) ffw_output_biases -Loading Parameters: (layer=5, size 10240) attention_output_biases -Loading Parameters: (layer=6, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=6, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=6, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=6, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=6, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=6, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=6, size 40960) griffin_conv_w -Loading Parameters: (layer=6, size 10240) griffin_conv_biases -Loading Parameters: (layer=6, size 5242880) griffin_gate_w -Loading Parameters: (layer=6, size 20480) griffin_gate_biases -Loading Parameters: (layer=6, size 10240) griffin_a -Loading Parameters: (layer=6, size 157286400) gating_einsum_w -Loading Parameters: (layer=6, size 78643200) linear_w -Loading Parameters: (layer=6, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=6, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=6, size 61440) ffw_gating_biases -Loading Parameters: (layer=6, size 10240) ffw_output_biases -Loading Parameters: (layer=7, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=7, size 10240) griffin_linear_x_biases -Loading 
Parameters: (layer=7, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=7, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=7, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=7, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=7, size 40960) griffin_conv_w -Loading Parameters: (layer=7, size 10240) griffin_conv_biases -Loading Parameters: (layer=7, size 5242880) griffin_gate_w -Loading Parameters: (layer=7, size 20480) griffin_gate_biases -Loading Parameters: (layer=7, size 10240) griffin_a -Loading Parameters: (layer=7, size 157286400) gating_einsum_w -Loading Parameters: (layer=7, size 78643200) linear_w -Loading Parameters: (layer=7, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=7, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=7, size 61440) ffw_gating_biases -Loading Parameters: (layer=7, size 10240) ffw_output_biases -Loading Parameters: (layer=8, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=8, size 78643200) qkv_einsum_w -Loading Parameters: (layer=8, size 157286400) gating_einsum_w -Loading Parameters: (layer=8, size 78643200) linear_w -Loading Parameters: (layer=8, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=8, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=8, size 61440) ffw_gating_biases -Loading Parameters: (layer=8, size 10240) ffw_output_biases -Loading Parameters: (layer=8, size 10240) attention_output_biases -Loading Parameters: (layer=9, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=9, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=9, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=9, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=9, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=9, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=9, size 40960) griffin_conv_w -Loading Parameters: (layer=9, size 10240) 
griffin_conv_biases -Loading Parameters: (layer=9, size 5242880) griffin_gate_w -Loading Parameters: (layer=9, size 20480) griffin_gate_biases -Loading Parameters: (layer=9, size 10240) griffin_a -Loading Parameters: (layer=9, size 157286400) gating_einsum_w -Loading Parameters: (layer=9, size 78643200) linear_w -Loading Parameters: (layer=9, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=9, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=9, size 61440) ffw_gating_biases -Loading Parameters: (layer=9, size 10240) ffw_output_biases -Loading Parameters: (layer=10, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=10, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=10, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=10, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=10, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=10, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=10, size 40960) griffin_conv_w -Loading Parameters: (layer=10, size 10240) griffin_conv_biases -Loading Parameters: (layer=10, size 5242880) griffin_gate_w -Loading Parameters: (layer=10, size 20480) griffin_gate_biases -Loading Parameters: (layer=10, size 10240) griffin_a -Loading Parameters: (layer=10, size 157286400) gating_einsum_w -Loading Parameters: (layer=10, size 78643200) linear_w -Loading Parameters: (layer=10, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=10, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=10, size 61440) ffw_gating_biases -Loading Parameters: (layer=10, size 10240) ffw_output_biases -Loading Parameters: (layer=11, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=11, size 78643200) qkv_einsum_w -Loading Parameters: (layer=11, size 157286400) gating_einsum_w -Loading Parameters: (layer=11, size 78643200) linear_w -Loading Parameters: (layer=11, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=11, 
size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=11, size 61440) ffw_gating_biases -Loading Parameters: (layer=11, size 10240) ffw_output_biases -Loading Parameters: (layer=11, size 10240) attention_output_biases -Loading Parameters: (layer=12, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=12, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=12, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=12, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=12, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=12, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=12, size 40960) griffin_conv_w -Loading Parameters: (layer=12, size 10240) griffin_conv_biases -Loading Parameters: (layer=12, size 5242880) griffin_gate_w -Loading Parameters: (layer=12, size 20480) griffin_gate_biases -Loading Parameters: (layer=12, size 10240) griffin_a -Loading Parameters: (layer=12, size 157286400) gating_einsum_w -Loading Parameters: (layer=12, size 78643200) linear_w -Loading Parameters: (layer=12, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=12, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=12, size 61440) ffw_gating_biases -Loading Parameters: (layer=12, size 10240) ffw_output_biases -Loading Parameters: (layer=13, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=13, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=13, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=13, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=13, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=13, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=13, size 40960) griffin_conv_w -Loading Parameters: (layer=13, size 10240) griffin_conv_biases -Loading Parameters: (layer=13, size 5242880) griffin_gate_w -Loading Parameters: (layer=13, size 20480) griffin_gate_biases -Loading Parameters: (layer=13, size 
10240) griffin_a -Loading Parameters: (layer=13, size 157286400) gating_einsum_w -Loading Parameters: (layer=13, size 78643200) linear_w -Loading Parameters: (layer=13, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=13, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=13, size 61440) ffw_gating_biases -Loading Parameters: (layer=13, size 10240) ffw_output_biases -Loading Parameters: (layer=14, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=14, size 78643200) qkv_einsum_w -Loading Parameters: (layer=14, size 157286400) gating_einsum_w -Loading Parameters: (layer=14, size 78643200) linear_w -Loading Parameters: (layer=14, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=14, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=14, size 61440) ffw_gating_biases -Loading Parameters: (layer=14, size 10240) ffw_output_biases -Loading Parameters: (layer=14, size 10240) attention_output_biases -Loading Parameters: (layer=15, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=15, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=15, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=15, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=15, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=15, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=15, size 40960) griffin_conv_w -Loading Parameters: (layer=15, size 10240) griffin_conv_biases -Loading Parameters: (layer=15, size 5242880) griffin_gate_w -Loading Parameters: (layer=15, size 20480) griffin_gate_biases -Loading Parameters: (layer=15, size 10240) griffin_a -Loading Parameters: (layer=15, size 157286400) gating_einsum_w -Loading Parameters: (layer=15, size 78643200) linear_w -Loading Parameters: (layer=15, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=15, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=15, size 61440) ffw_gating_biases -Loading 
Parameters: (layer=15, size 10240) ffw_output_biases -Loading Parameters: (layer=16, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=16, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=16, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=16, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=16, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=16, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=16, size 40960) griffin_conv_w -Loading Parameters: (layer=16, size 10240) griffin_conv_biases -Loading Parameters: (layer=16, size 5242880) griffin_gate_w -Loading Parameters: (layer=16, size 20480) griffin_gate_biases -Loading Parameters: (layer=16, size 10240) griffin_a -Loading Parameters: (layer=16, size 157286400) gating_einsum_w -Loading Parameters: (layer=16, size 78643200) linear_w -Loading Parameters: (layer=16, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=16, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=16, size 61440) ffw_gating_biases -Loading Parameters: (layer=16, size 10240) ffw_output_biases -Loading Parameters: (layer=17, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=17, size 78643200) qkv_einsum_w -Loading Parameters: (layer=17, size 157286400) gating_einsum_w -Loading Parameters: (layer=17, size 78643200) linear_w -Loading Parameters: (layer=17, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=17, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=17, size 61440) ffw_gating_biases -Loading Parameters: (layer=17, size 10240) ffw_output_biases -Loading Parameters: (layer=17, size 10240) attention_output_biases -Loading Parameters: (layer=18, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=18, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=18, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=18, size 10240) griffin_linear_y_biases -Loading Parameters: 
(layer=18, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=18, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=18, size 40960) griffin_conv_w -Loading Parameters: (layer=18, size 10240) griffin_conv_biases -Loading Parameters: (layer=18, size 5242880) griffin_gate_w -Loading Parameters: (layer=18, size 20480) griffin_gate_biases -Loading Parameters: (layer=18, size 10240) griffin_a -Loading Parameters: (layer=18, size 157286400) gating_einsum_w -Loading Parameters: (layer=18, size 78643200) linear_w -Loading Parameters: (layer=18, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=18, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=18, size 61440) ffw_gating_biases -Loading Parameters: (layer=18, size 10240) ffw_output_biases -Loading Parameters: (layer=19, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=19, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=19, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=19, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=19, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=19, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=19, size 40960) griffin_conv_w -Loading Parameters: (layer=19, size 10240) griffin_conv_biases -Loading Parameters: (layer=19, size 5242880) griffin_gate_w -Loading Parameters: (layer=19, size 20480) griffin_gate_biases -Loading Parameters: (layer=19, size 10240) griffin_a -Loading Parameters: (layer=19, size 157286400) gating_einsum_w -Loading Parameters: (layer=19, size 78643200) linear_w -Loading Parameters: (layer=19, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=19, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=19, size 61440) ffw_gating_biases -Loading Parameters: (layer=19, size 10240) ffw_output_biases -Loading Parameters: (layer=20, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=20, size 78643200) 
qkv_einsum_w -Loading Parameters: (layer=20, size 157286400) gating_einsum_w -Loading Parameters: (layer=20, size 78643200) linear_w -Loading Parameters: (layer=20, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=20, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=20, size 61440) ffw_gating_biases -Loading Parameters: (layer=20, size 10240) ffw_output_biases -Loading Parameters: (layer=20, size 10240) attention_output_biases -Loading Parameters: (layer=21, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=21, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=21, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=21, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=21, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=21, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=21, size 40960) griffin_conv_w -Loading Parameters: (layer=21, size 10240) griffin_conv_biases -Loading Parameters: (layer=21, size 5242880) griffin_gate_w -Loading Parameters: (layer=21, size 20480) griffin_gate_biases -Loading Parameters: (layer=21, size 10240) griffin_a -Loading Parameters: (layer=21, size 157286400) gating_einsum_w -Loading Parameters: (layer=21, size 78643200) linear_w -Loading Parameters: (layer=21, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=21, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=21, size 61440) ffw_gating_biases -Loading Parameters: (layer=21, size 10240) ffw_output_biases -Loading Parameters: (layer=22, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=22, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=22, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=22, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=22, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=22, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=22, size 40960) 
griffin_conv_w -Loading Parameters: (layer=22, size 10240) griffin_conv_biases -Loading Parameters: (layer=22, size 5242880) griffin_gate_w -Loading Parameters: (layer=22, size 20480) griffin_gate_biases -Loading Parameters: (layer=22, size 10240) griffin_a -Loading Parameters: (layer=22, size 157286400) gating_einsum_w -Loading Parameters: (layer=22, size 78643200) linear_w -Loading Parameters: (layer=22, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=22, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=22, size 61440) ffw_gating_biases -Loading Parameters: (layer=22, size 10240) ffw_output_biases -Loading Parameters: (layer=23, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=23, size 78643200) qkv_einsum_w -Loading Parameters: (layer=23, size 157286400) gating_einsum_w -Loading Parameters: (layer=23, size 78643200) linear_w -Loading Parameters: (layer=23, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=23, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=23, size 61440) ffw_gating_biases -Loading Parameters: (layer=23, size 10240) ffw_output_biases -Loading Parameters: (layer=23, size 10240) attention_output_biases -Loading Parameters: (layer=24, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=24, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=24, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=24, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=24, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=24, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=24, size 40960) griffin_conv_w -Loading Parameters: (layer=24, size 10240) griffin_conv_biases -Loading Parameters: (layer=24, size 5242880) griffin_gate_w -Loading Parameters: (layer=24, size 20480) griffin_gate_biases -Loading Parameters: (layer=24, size 10240) griffin_a -Loading Parameters: (layer=24, size 157286400) gating_einsum_w -Loading Parameters: 
(layer=24, size 78643200) linear_w -Loading Parameters: (layer=24, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=24, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=24, size 61440) ffw_gating_biases -Loading Parameters: (layer=24, size 10240) ffw_output_biases -Loading Parameters: (layer=25, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=25, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=25, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=25, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=25, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=25, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=25, size 40960) griffin_conv_w -Loading Parameters: (layer=25, size 10240) griffin_conv_biases -Loading Parameters: (layer=25, size 5242880) griffin_gate_w -Loading Parameters: (layer=25, size 20480) griffin_gate_biases -Loading Parameters: (layer=25, size 10240) griffin_a -Loading Parameters: (layer=25, size 157286400) gating_einsum_w -Loading Parameters: (layer=25, size 78643200) linear_w -Loading Parameters: (layer=25, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=25, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=25, size 61440) ffw_gating_biases -Loading Parameters: (layer=25, size 10240) ffw_output_biases diff --git a/compression/python/pytree/ml_model_transforms.py b/compression/python/pytree/ml_model_transforms.py deleted file mode 100644 index 3605c07..0000000 --- a/compression/python/pytree/ml_model_transforms.py +++ /dev/null @@ -1,371 +0,0 @@ -"""Transformations for python-trees representing the parameters of a ML model. - -Important: This module assumes that byte-order is the same on the -machine that serializes data and the machine that deserializes -data. If, for example, numpy-data gets dumped, respectively loaded, -with a dtype-specification of numpy.float32, on-file byte-order -will be host byte order. 
- -""" - -import ast -import hashlib -import itertools -import pprint -import sys -import time -from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar - -import numpy -import pytree_transforms - - -NT = TypeVar('NT') - - -def ml_model_leaf_summary(path, x, sep=', '): - """Produces a textual summary of a leaf-node and its path. - - Args: - path: The path-to-root, as a reverse-order recursive - pair of path-components, with `()` as root. - x: The leaf-object. - sep: the separator between description-elements. - Default ', ' allows for convenient line-by-line processing - (such as via grep, perl -ne, etc.), but using e.g. sep=',\n ' - might be more useful for human consumption. - - Returns: - A human-readable string providing information about the node. - """ - # Using `repr` for path-components to get a faithful presentation. - # (...which still however would be somewat painful to correctly - # split into components.) - path_str = ','.join(map(repr, - pytree_transforms.linearize_revtuple_path(path))) - tx = type(x) - mod = tx.__module__ # Either a module or a string like 'builtins'. - modname = mod if isinstance(mod, str) else mod.__name__ - type_str = f'{modname}.{tx.__qualname__}' - try: - # `numpy.ndarray` instances have a `.data` property that gives access - # to a buffer via which we can hashlib-fingerprint the numerical - # contents. We here simply try to produce a fingerprint and also look - # up the .dtype of the object. Technically, there is a somewhat-unsound - # assumption here that if these operations succeed, we are indeed looking - # at a ndarray or sufficiently similar object for these operations to - # make sense. As the output is declared "for human consumption", this - # fishiness is not a problem. 
- fp = hashlib.sha256(x.data).hexdigest() - start = list(itertools.islice(x.flat, 5)) - stats_str = ( - f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, ' - f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}') - return (f'{path_str:60s}: <{type_str}{sep}' - f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, ' - f'dtype={x.dtype}{sep}start={start}>') - except (AttributeError, ValueError, TypeError): - # Fallback - trying to include information about the data-content - # of a likely-numerical-array failed. - return f'{path_str:60s}: {type_str}({repr(x)})' - - -# A specialized node-handler. -# Interface follows node-handler expectations defined in pytree_transforms. -def _ml_model_tree_node_handler(path: tuple, node : NT) -> ( - None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], - Iterator[tuple[Any, NT]]]): - """Processes a tree-node as required by pytree-iteration and -mapping. - - Args: - path: revtuple path to the current node. - node: a tree-node in a ML-model tree that is recursively - built out of `numpy.ndarray` leaf-values and dicts mapping - node-name string-keys to other such nodes representing subtrees - - and nothing else. - - Returns: - `None` if the tree-node is to be regarded as a leaf, otherwise - a pair `(rebuilder, iterator)`, where `iterator` iterates - over the data-content of the node, each item represented as a pair - of `(lookup_path_item, value_item)`, and `rebuilder` is a function - which, when applied to `iterator` or any iterable with the same - elements, returns a node that is equivalent to the original. - - Raises: - NotAMLModelTreeNodeError: If the tree contains a node that is neither - a `dict` nor a `numpy.ndarray` instance. - """ - # The astute reader will notice that we are doing something fishy - # here - this code could not be translated to Haskell as-is, since - # `NT` cannot actually be a proper type-variable in the sense - # of parametric polymorphism. - del path # Unused. 
- if isinstance(node, dict): - return dict, iter(node.items()) - if isinstance(node, numpy.ndarray): - return None - raise pytree_transforms.NotAMLModelTreeNodeError( - f'Type of bad node: {type(node)}') - - -def _ml_model_extract_leaf_transform( - path: pytree_transforms.RevTuplePath, - leaf: Any): - """Maps an array-leaf to a pair `(full_path, lambda: array)`. - - The computation that produces the leaf-value is lazified underneath - a `lambda`, since if we e.g. performed a memory-expensive - transformation (such as some dtype-changes) directly at this point, - then going from an iterator over tree-items for one-by-one - consumption to a list of these items would have all the - dtype-transformed values around simultaneously. We want to avoid - situations where we can do nothing about having multiple variants - of the data simultaneously in memory. - """ - # Hack: If we are encountering a `bfloat16` numpy-array, - # we pretend to have the data as a numpy.float32 array, - # since that's about all that contemporary CPUs can process - # efficiently here. - linearized_path = pytree_transforms.linearize_revtuple_path(path) - try: - # We have to use some trickery to detect `bfloat16`. - if leaf.dtype.descr[-1] == ('', ' Any: - """Performs perl-style autovivification on a nested-dict tree. - - Args: - keys_and_vals: An iterable of pairs `(key_path, value)`, where - `key_path` is a sequence of keys to be used to navigate to - the result via iterative dict-lookup, left-to-right. - Must not have duplicate keys, and must not more than one key if - an empty-sequence key is present. If this iterable is an - iterator, it will be fully exhausted on successful execution. - - Returns: - An object representing a nested-dict structure such that - for every `key_path` from `keys_and_vals`, recursive-dict-lookup - on the elements of that path starting from this object will - produce the corresponding value. An empty `keys_and_vals` - set will return `{}`. 
Every dict in the nested return-value - that has been populated by autovivification is newly allocated. - """ - # Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)]) - # having to evaluate to x and not a dict. - # There may be ways to prettify/simplify this. - result = None - empty = {} - for linear_path, val in keys_and_vals: - if linear_path == (): - if result is not None: - raise ValueError('Root-value seen alongside other values.') - result = val - else: - if result is None: - result = {} - elif type(result) is not dict: - # We already did encounter a root-value. - raise ValueError('Root-value seen alongside other values.') - cursor = result - for n in range(len(linear_path) - 1): - cursor = cursor.setdefault(linear_path[n], empty) - if cursor is empty: - # Regenerate `empty` if we just used it up. - empty = {} - cursor[linear_path[-1]] = val - return {} if result is None else result - - -def model_overview(tree, out=None) -> None: - """Prints a human-readable overview to `(out or sys.stdout)`.""" - actual_out = out or sys.stdout - for line in pytree_transforms.pytree_leaf_iter( - tree, ml_model_leaf_summary, - _ml_model_tree_node_handler): - print(line, file=actual_out) - - -def model_contents(tree) -> Mapping[tuple[str, ...], Any]: - """Maps a model to a {pytree_keys: data_array} mapping. - - Args: - tree: The ML-model parameter-tree, built recursively out of - dict-instances with numpy.ndarray instances as leaves. - - Returns: - A mapping from linearized pytree-key-sequence tuple to the corresponding - leaf-value. 
- """ - def leaf_transform(revtuple_path, leaf): - return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf - return dict( - pytree_transforms.pytree_leaf_iter( - tree, leaf_transform, _ml_model_tree_node_handler)) - - -def _fn_identity(x): return x - - -def model_save(tree, - filepath_stem: str, - data_suffix: str = '.data', - manifest_suffix: str | None = '.manifest', - key: Callable[[tuple[str, ...]], Any] | None = None, - array_transform_by_pytree_key: ( - Mapping[tuple[str, ...], - Callable[[numpy.ndarray], numpy.ndarray]] | - None) = None, - report: Callable[[str], None] | None = None, - byte_align: int = 8) -> tuple[int, float]: - """Saves the content of a ML-model parameter-tree to filesystem. - - After successful execution, the file f"{filepath_stem}.data" - will hold the combined numerical model-parameters, and - f"{filepath_stem}.manifest" will contain the key for interpreting - (and rebuilding) the data. - - Args: - tree: The ML-model parameter-tree, built recursively out of - dict-instances with numpy.ndarray instances as leaves. - filepath_stem: Filesystem location for data. - data_suffix: Suffix to use for the data file. - manifest_suffix: Either `None`, in which case no manifest-file - will get written, or the suffix for the manifest-file. - key: `None` or a key-function that will be applied to the linear model-path - and used for sorting the data arrays by increasing value of the - key-function. If the key-function returns `None` on an item, - then this item is not included. - array_transform_by_pytree_key: Optional mapping from pytree-key - to an array-to-array transformation function to apply to the array - prior to serialization. - report: Optional callable for logging progress-reports. - byte_align: byte-alignment to use for numerical array data. - Numerical arrays whose size in bytes is not a multiple of this - will get padded to the next full multiple. 
- - Returns: - A pair of `(size, time_sec)`, where `size` is the total byte-size - of the `.data` file and `time_sec` is the elapsed time - for saving the model, in seconds. - """ - time0 = time.monotonic() - if array_transform_by_pytree_key is None: - array_transform_by_pytree_key = {} - model_lazy_items = ( - pytree_transforms.pytree_leaf_iter( - tree, _ml_model_extract_leaf_transform, - _ml_model_tree_node_handler)) - if key is not None: - to_write = [ - nkv[1:] for nkv in sorted( - (nkv for nkv in ((key(path), path, v) - for path, v in model_lazy_items) - if nkv[0] is not None), key=lambda nkv: nkv[0])] - else: - to_write = list(model_lazy_items) - # - def lazy_arr_path_shape_dtype_size(path_and_lazy_arr): - path, lazy_arr = path_and_lazy_arr - arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr()) - return path, arr.shape, arr.dtype, arr.data.nbytes - arrs_path_shape_dtype_nbytes = list( - map(lazy_arr_path_shape_dtype_size, to_write)) - # We need to know the total size of all the data. - bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes] - padded_bytesizes = [-(-bytesize // byte_align * byte_align) - for bytesize in bytesizes] - offsets = numpy.cumsum([0] + padded_bytesizes) - membuf = numpy.memmap(filepath_stem + data_suffix, - mode='w+', shape=offsets[-1]) - try: - for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip( - arrs_path_shape_dtype_nbytes, offsets, to_write): - # Note that if getting the array from the lazy lambda involved some - # computation, such as a copying dtype-change, that computation would - # end up being done multiple times here - including once above, to compute - # byte-sizes, and once more here. 
- transformed_arr = array_transform_by_pytree_key.get( - path, - _fn_identity)(lazy_arr()) - membuf[offset : offset + nbytes] = numpy.frombuffer( - transformed_arr.ravel().data, 'u1') - if report is not None: - samples = ', '.join(map(str, transformed_arr.ravel()[:5])) - report(f'# Adding: {path!r}\n bytes: {nbytes:10d}, ' - f'shape: {shape!r:30},\n start: [{samples}, ...]') - transformed_arr = None # Drop memory references to numerical arrays ASAP. - finally: - if membuf is not None: - membuf.flush() - # NumPy wart: the memory-buffer is a resource that conceptually - # should be .close()able - since mmap()ing holds on to a - # file descriptor. However, it looks as if that clean-up were done - # in the "finalizer", despite that having meanwhile been widely - # understood as dubious practice. So, the best we can do here is - # to explicitly and clearly remove our reference to the instance. - del membuf - if manifest_suffix is not None: - # We still have to serialize the data that allows us to reconstruct - # a tree that is equivalent to the original. - manifest_data = [ - dict(path=path, - dtype=dtype.descr[-1][-1], - shape=shape, - nbytes=nbytes, - offset=offset) - for (path, shape, dtype, nbytes), offset in zip( - arrs_path_shape_dtype_nbytes, offsets)] - with open(filepath_stem + '.manifest', 'wt') as h_manifest: - pprint.pprint(manifest_data, stream=h_manifest) - time_taken = time.monotonic() - time0 - return offsets[-1], time_taken - - -def model_load(filepath_stem, mmapped=True): - """Loads a model saved by `model_save`. - - Tries to load the model from f"{filepath_stem}.data" - and f"{filepath_stem}.manifest". - - Args: - filepath_stem: The model location on the filesystem. - mmapped: Whether data-arrays will be slices of a - `numpy.memmap` mapped buffer, to be paged in - on demand only, or in-memory copies of the data. - Returns: - A dict/numpy.ndarray tree representation of the model, - equivalent to the original model. 
- """ - with open(filepath_stem + '.manifest', 'rt') as h_manifest: - manifest = ast.literal_eval(h_manifest.read()) - membuf = numpy.memmap(filepath_stem + '.data', mode='r+') - paths_and_arrays = [] - for item in manifest: - path = item['path'] - dtype = numpy.dtype(item['dtype']) - shape = item['shape'] - nbytes = item['nbytes'] - offset = item['offset'] - data_array = numpy.frombuffer(membuf[offset : offset + nbytes].data, - dtype=dtype).reshape(shape) - paths_and_arrays.append( - (path, - data_array if mmapped else data_array.copy())) - # At this point, the memory-buffer is no longer needed. Still, if - # data-arrays retain references to the underlying data - # (i.e. when mmapped=False), this should keep the mapping - # - and hence file descriptor - open. We then are in a somewhat - # undesirable situation of clean-up of a resource that happens in a - # hard-to-predict way releasing a file descriptor. - del membuf - return revtuple_autovifify_from_linear(paths_and_arrays) diff --git a/compression/python/pytree/ml_model_transforms_test.py b/compression/python/pytree/ml_model_transforms_test.py deleted file mode 100644 index 9495c87..0000000 --- a/compression/python/pytree/ml_model_transforms_test.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Basic tests for 'algebraic data type based pytree' transformations.""" - - -import io -import os -import tempfile -import unittest - -import numpy - -import ml_model_transforms - - -def _get_model(prefix): - return { - prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32), - prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32), - prefix + 'b1': { - prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8), - prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64) - }} - - -class MLModeltransformsTest(unittest.TestCase): - """Basic correctness validation tests for ML-model transformations.""" - - def test_ml_model_leaf_summary(self): - """Tests guarantees 
given by `ml_model_leaf_summary`.""" - summary = ml_model_transforms.ml_model_leaf_summary( - ('a', ()), - numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16), - sep='##') - self.assertIn('##', summary) # Separator is respected. - self.assertIn('(6, 4)', summary) # Shape is mentioned somewhere. - self.assertIn('int16', summary) # dtype is mentioned somewhere. - - def test_revtuple_autovivify_from_linear(self): - """Tests guarantees given by `revtuple_autovifify_from_linear`.""" - with self.subTest(guarantee='empty'): - self.assertEqual( - ml_model_transforms.revtuple_autovifify_from_linear([]), - {}) - with self.subTest(guarantee='generic'): - keys_vals = [(('a', 'b1', 'c1'), 1001), - (('a', 'b2'), 1002), - (('a2',), 1003), - ] - self.assertEqual( - ml_model_transforms.revtuple_autovifify_from_linear(keys_vals), - {'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003}) - - def test_model_overview(self): - """Tests guarantees given by `model_overview`.""" - model = _get_model('xyz') - out_io = io.StringIO() - ml_model_transforms.model_overview(model, out=out_io) - overview = out_io.getvalue() - self.assertIn('xyz', overview) - - def test_model_contents(self): - """Tests guarantees given by `model_contents`.""" - model = _get_model('pq_') - contents = ml_model_transforms.model_contents(model) - fingerprints = {k: (a.shape, a.ravel()[:3].tolist()) - for k, a in contents.items()} - self.assertEqual(fingerprints, - {('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]), - ('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]), - ('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]), - ('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])}) - - def test_model_save_load_basic(self): - """Tests basic guarantees given by `model_save` and `model_load`.""" - # What we care about here is that the round trip works - so - # it makes more sense to test saving and loading as one unit. 
- model_orig = _get_model('model_') - with tempfile.TemporaryDirectory() as tempdir: - filepath_stem = os.path.join(tempdir, 'the_model') - total_size, total_time = ml_model_transforms.model_save(model_orig, - filepath_stem) - self.assertGreater(total_size, 0) - self.assertGreater(total_time, 0) - model_reloaded = ml_model_transforms.model_load(filepath_stem) - contents_orig = ml_model_transforms.model_contents(model_orig) - contents_reloaded = ml_model_transforms.model_contents(model_reloaded) - self.assertEqual( - {k: v.tolist() for k, v in contents_orig.items()}, - {k: v.tolist() for k, v in contents_reloaded.items()}) - - -if __name__ == '__main__': - unittest.main() diff --git a/compression/python/pytree/pytree_transforms.py b/compression/python/pytree/pytree_transforms.py deleted file mode 100644 index 7e065af..0000000 --- a/compression/python/pytree/pytree_transforms.py +++ /dev/null @@ -1,508 +0,0 @@ -"""Tools for transforming "nested python object" tree data structures. - -# Context - -The motivation for this module came from ML applications that ought to -be based on a principled handling of nested Python data structures. -Having such principled pytree-transforming code available solves -some other problems, such as doing away with a need to abuse -tree-mapping for-side-effect-only and having to use a hope-and-pray -approach to processing very deeply nested values which with a recursive -approach might trigger a RecursionError. 
- -We specifically want to cover the use case of having ML model -parameters that are available in a nested Python data structure for -which there "almost" is a unique-up-to-unique-isomorphism mapping from -and to this Algebraic Data Type: - -`data ModelParams a = Array a | Node [(String, ModelParams a)]` - -In this correspondence, `a` is some array-type (perhaps -`numpy.ndarray`, `jax.numpy.ndarray`, `tf.tensor`, etc.), but the -data-processing code is effectively entirely agnostic to this, and a -`Node` is "almost" an associative-list of (key, value) pairs -representing a Python dict. (Note: The "almost" here is mostly about -the conceptual wart that assoc-lists can in principle have key -duplicates, but Python dicts can not. This is however not a problem -since all we need is the transformation in one direction, -i.e. whatever data-processing `f` we want to express on the -model-parameters-pytree, we can express by specifying a "faithful" -mapping `m` into the above algebraic data type through which every -such pytree data transform factorizes, i.e. for every `f` we can find -a `g` such that `f(p) = g(m(p))`.) - -## Components - -The main workhorse in this module is the `pytree_iter` function that -maps a "PyTree (such as representing `ModelParams`)" to an iterator -over values obtained by applying a mapping-function to the "key-path" -and leaf-value for every leaf, where the "key-path" contains a -linked-list representation of the reversed sequence of keys from the -tree-root, with list-nodes being represented by pairs -`(latest_dict_key, rest_path)`, and the empty path being represented -by `()`. - -For the sake of genericity, `pytree_iter` is built in such a way that -it actually can handle any kind of traversal of PyTree-trees that do -represent algebraic data types (note however that some some do not) - -but for this to make sense, the user must have a way to define how to -interpret tree-nodes, in particular identify leaves. 
This requires -providing a function `node_handler` with the same signature and -behavior as described below for "node handlers". - -Additionally, this module provides mapping-over-pytrees via -`pytree_map`, which is also built in such a way that it makes the -correspondence between an algebraic data type and its Python -nested-tree representation explicit. Despite being powerful and -flexible, this, however, may in general require a bit more effort to -wire up, since node-rebuilding can be fairly nontrivial. - -Furthermore, as a prominent application, this module provides a simple -deep-freezing function that translates a nested Python data structure -to deeply-immutable form. - -## Concepts and Conventions - -"revtuple representation": - - As we iterate over a tree, we will have to keep track of the - path-to-tree-root. Naturally, two sibling nodes `n1` and `n2` - will share the same parent-path (being siblings), so it makes - sense to use a linked-list-with-shared-tail representation. - Python does not have a natural notion for that, so we use - recursively-constructed tuples `(node_tag, parent_path)` - that represent the path-from-root in-reverse-order, i.e. - for a non-empty path `p`, `p[0]` is the node-tag at the - deepest nesting level. We call this a "revtuple representation" - of the path. - -"node handler": - - A node-handler classifies a tree-node as "leaf or other node", and - for non-leaf nodes provides information about both its children and - how to rebuild it. The behavior of a node-handler function must be - in alignment with this docstring: - - '''Processes a tree-node as required by pytree-iteration and -mapping. - - Args: - revtuple_path: Revtuple-representation of the path-from-root - to the current node. - node: a tree-node in a ML-model tree that is recursively - built out of leaf-values and other nodes. 
- - Returns: - `None` if the tree-node is to be regarded as a leaf, otherwise - a pair `(rebuilder, iterator)`, where `iterator` iterates - over the data-content of the node, each item represented as a pair - of `(lookup_path_item, value_item)`, and `rebuilder` is a function - which, when applied to an iterable of the aforementioned value-items - (or some transformation thereof) returns a node that is equivalent - to the original (or up to a transformation of the contents). - - Raises: - InvalidTreeNodeError: If the tree contains a node of a kind - that is not expected to show up. - ''' - - Examples: - - (The behavior of a node-handler is somewhat nontrivial, so covering - two very common cases via examples is in order.) - - This node-handler would allow descending into (nested) - instances of `list` (but not subclass instances thereof): - - ```def list_node_handler(revtuple_path, obj): - ''' ... ''' - if type(obj) is list: - return list, enumerate(obj) - else: - return None - ``` - - This node-handler would allow descending into (nested) mappings, - which upon rebuilding would get turned into `dict` instances: - - ```def mapping_node_handler(revtuple_path, obj): - ''' ... ''' - if isinstance(obj, collections.abc.Mapping): - # For generic mappings, we cannot rely on key- and item-iteration - # being guaranteed to use identical iteration-order. - items = list(obj.items()) - keys = [kv[0] for kv in items] - return (lambda values: dict(zip(keys, values))), items - else: - return None - ``` - - A dict/mapping node-handler can of course rename keys, add or remove - entries, make decisions based on the item-path, or map a dict to - an associative list, etc. - -## Further Design Notes - -The `pytree_map` function requests the leaf-transform and node-handler -to be side-effect-free functions. 
This is both required to leave -implementation-side flexibility, and also follows the general LISP -recommendation to not abuse mapping (which should be a pure -data-transformation) for imperative data processing. Overall, if -a need for more general "nested datastructures" processing becomes -pressing, it is for the better if this leads to a proper articulation -of the specific needs, to be addressed with appropriate design, rather -than abuse of functional data-transforms becoming "a bad idiom -that turned into established practice". - -""" - -import collections.abc -import immutabledict - -import numpy - -from typing import Any, Callable, Iterable, Iterator, TypeVar - - -T = TypeVar('T') -U = TypeVar('U') - -KT = TypeVar('KT') -NT = TypeVar('NT') - - -## Type of the reverse-order-keys-to-root path. -# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.) -RevTuplePath = tuple - -## Type of the `leaf_transform` function-argument used for tree-iteration. -# -# This would be the correct type we would have to specify here but cannot, -# since the design of Python's static typing at the time of this writing -# is too broken for that: -# -# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R] -# -# Instead, we have to settle for...: -LeafTransformFunc = Callable[[RevTuplePath, Any], Any] - - -## Type of the `tree_node_handler` function-argument used for -## tree-iteration and tree-mapping. 
-# -# Again, this is the correct type we would have to put here but cannot: -# -# type NodeHandlerFunc[KT] = ( -# Callable[[NT], -# None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT], -# Iterator[tuple[KT, NT]]]]) -# -# ...so, we have to instead settle for: -NodeHandlerFunc = ( - Callable[[RevTuplePath, NT], - None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], - Iterator[tuple[Any, NT]]]]) - - -Predicate = Callable[[object], bool] - - -class InvalidTreeNodeError(ValueError): - """Encountered a tree-node of invalid type.""" - - -def linearize_revtuple_path( - revtuple_path: RevTuplePath, - present_as: Callable[[Iterator[T]], U] = tuple) -> U: - """Translates a revtuple path to (typically) linear form. - - With default `present_as`, this will map a path of the form - `(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple - (root, ..., key_{N-1}, key_{N}). - - Args: - revtuple_path: A linked-list-as-recursive-pairs - reverse-order tuple-representation of the path. - Path-root is `()`, and node-key `x` relative to - earlier path `p` is represented as `(x, p)`. - present_as: Callable that consumes an iterator over - path-pieces - with the deepest-nesting level coming last - - turning it into a linearized path. Defaults to `tuple`. - - Returns: - Linearized presentation of all the node-keys in the - recursive-path in order, deepest-down path component coming last. - """ - pieces = [] - todo = revtuple_path - while todo: - node, todo = todo - pieces.append(node) - return present_as(reversed(pieces)) - - -# This function itself has type `NodeHandlerFunc`, but Python does not -# allow us to here simply type-annotate it like this. We cannot even -# introduce an abbreviation for the complicated output-type, -# since that would have to be parametric in node-type `NT` (and `KT`). 
-def everything_is_a_leaf_node_handler( - revtuple_path: tuple, - node : NT) -> ( - None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], - Iterator[tuple[Any, NT]]]): - """Processes a tree-node as required by pytree-iteration and -mapping. - - Interface and signature are in alignment with the requirements for a - "node handler" function explained in the module-docstring. - - Args: - revtuple_path: the path-to-root for this node. - node: a tree-node. - - Returns: - `None`, i.e. classifying any kind of node as a leaf-node. - """ - del revtuple_path, node # Unused. - return None - - -def leaf_summary(path: RevTuplePath, x: object): - """Produces a human-readable summary-string for a leaf-node. - - Args: - path: revtuple representation of the path-to-root. - x: The leaf-value. - """ - del path # Ignored here. - tx = type(x) - mod = tx.__module__ - modname = mod if isinstance(mod, str) else mod.__name__ - type_str = f'{modname}.{tx.__qualname__}' - repr_str = repr(x) - repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...' - # On str, int, float, etc. `{type_str}(repr(x))` would actually still be - # a (non-literal) Python-expression that would evaluate to the original value. - # However, we make no promises beyond "human-readable". - return f'{type_str}({repr_abbrev})' - - -# With respect to static type annotations, the limitations of Python's -# approach to static typing really become prominently visible here. -# -# Different arguments have type-parameters, but since there is no way -# to have parametric abbreviations such as `LeafTransformFunc[L, R]`, -# the only way we would have available to express relations between -# type-parameters would be to substitute in the not-abbreviated form of -# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous. 
-# We instead here settle for "we cannot express that `tree` must -# have the same type as the input-type to `tree_node_handler` and use `Any`, -# and likewise for leaf_transform and the output. -def pytree_leaf_iter( - tree: Any, - leaf_transform: LeafTransformFunc, - node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler, - ) -> Iterator[Any]: - # ...actual return type would be `Iterator[{what leaf_transform returns}]`. - """Iterates over the leaves of a tree. - - Args: - tree: The tree to iterate over. - leaf_transform: A callable `f` that will get applied - as `f(revtuple_path, leaf)`, where `revtuple_path` - is the revtuple representation of the path to the - leaf from the root. - node_handler: A "node handler" (see module docstring) - that processes nodes encountered during iterative traversal. - - Yields: - Value of `leaf_transform(p, x)`, where `x` is the current leaf - and `p` is its revtuple-path to the root. - """ - # Note: Exit points for the code below are in non-obvious places - # and hence marked with " # ***EXIT***". - # - # Doing iteration properly is slightly nontrivial. - # One may be tempted to go for a very simple recursive implementation - # (with an extra pre-final `path` argument to `pytree_iter`): - # - # maybe_substructure = node_handler(path, tree) - # if maybe_substructure is None: - # # We are looking at a leaf-node. - # yield leaf_transform(path, tree) - # else: - # _, contents_iter = maybe_substructure - # for k, v in contents_iter: - # yield from pytree_iter(v, leaf_transform, (k, path), node_handler) - # - # That, however, would be flawed, since there is no a priori reason - # why a pytree may not be a very deeply nested structure - such as a - # long linked list. That would then risk raising `RecursionError`, - # and since Python by design(!) does not perform tail call elimination - # or any other kind of advanced CPS transforms, there is no recursive - # solution here. 
So, to do this properly, we have to do this iteratively. - # - # We are facing an annoying situation here: If `tree` itself is a leaf, - # we have two options: (a) wrapping it up in a one-node tree - # and processing that, or (b) special-casing "root is a leaf". - # Option (b) leads to some mild node-processing code-duplication - # for a single node (the root). - # Option (a) requires having special cases for node-processing that - # get looked at for every tree node. We go with option (b) here. - maybe_substructure = node_handler((), tree) - if maybe_substructure is None: - # The tree itself is a leaf. - yield leaf_transform((), tree) - return # ***EXIT*** - # Otherwise, we are looking at a tree. - _, contents_iter = maybe_substructure - current_revtuple_path = () - work_to_do = [contents_iter] - # Otherwise-unreachable sentinel for reliably identifying - # iterator-exhaustion without using exceptions: - sentinel = object() - while True: - current_iter = work_to_do[-1] - maybe_next_item = next(current_iter, sentinel) - if maybe_next_item is sentinel: - # We are done at this level. - work_to_do.pop() - if not work_to_do: return # ***EXIT*** - current_revtuple_path = current_revtuple_path[1] - else: - path_piece, subtree = maybe_next_item - extended_revtuple_path = (path_piece, current_revtuple_path) - maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree) - if maybe_subtree_substructure is None: # Case: subtree is a leaf. - yield leaf_transform(extended_revtuple_path, subtree) - else: # Case: subtree is a tree. - current_revtuple_path = (path_piece, current_revtuple_path) - _, items_iter = maybe_subtree_substructure - work_to_do.append(items_iter) - - -# The current design approach here would be appropriate for -# applying leaf-transforms while retaining the structure of the tree - -# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor. 
-# -# It is not entirely clear whether this is the abstraction that we should -# consider as being appropriately generic to flesh out explicitly - rather -# than starting from a more general approach of which this then is a special -# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme -# -# On the other hand, there is a lot of flexibility via whatever -# node-rebuilder a node-handler produces - this can do quite some reshaping -# of a tree, including dropping or duplicating nodes. -def pytree_map( - tree: Any, - leaf_transform, - node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler, - ): - """Maps a (potentially nested) Python value to another such value. - - Args: - tree: The Python-object to be mapped. - leaf_transform: A callable `f` that will get applied - as `f(revtuple_path, leaf)`, where `revtuple_path` - is the revtuple representation of the path to the - leaf from the root. Must be side effect free. - node_handler: A "node handler" (see module docstring) - that processes nodes encountered during iterative traversal. - Must be side effect free. - - Returns: - The outcome of translating `tree`. - """ - # Note: Exit points for the code below are in non-obvious places - # and hence marked with " # ***EXIT***". - # - # Otherwise-inaccessible sentinel object, for reliably identifying - # missing-values via identity-check against sentinel lookup-default. - sentinel = object() - # Code structure mostly follows pytree_leaf_iter. - maybe_substructure = node_handler((), tree) - if maybe_substructure is None: - return leaf_transform((), tree) # ***EXIT*** - rebuilder, items_iter = maybe_substructure - current_revtuple_path = () - # Per-level, we have a triplet of: - # (rebuilder, remaining_items_to_iterate_over, processed). 
- parts_for_assembly = [(rebuilder, items_iter, [])] - while True: - this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1] - maybe_next_item = next(this_items_iter, sentinel) - if maybe_next_item is sentinel: - # We are done with all the items for this level. - parts_for_assembly.pop() - built_iter = this_rebuilder(this_done_pieces) - if not parts_for_assembly: # No outer structure, so at-top-level. - return built_iter # ***EXIT*** - else: # We have outer structure. - parts_for_assembly[-1][-1].append(built_iter) - current_revtuple_path = current_revtuple_path[1] - continue # ...with next is-the-final-item-complete-check. - else: - # More constituents of the current item. - path_piece, subtree_item = maybe_next_item - extended_revtuple_path = (path_piece, current_revtuple_path) - maybe_subtree_substructure = node_handler( - extended_revtuple_path, - subtree_item) - if maybe_subtree_substructure is None: - this_done_pieces.append( - leaf_transform(extended_revtuple_path, subtree_item)) - else: - # We have a subtree. - subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure - current_revtuple_path = (path_piece, - current_revtuple_path) - parts_for_assembly.append( - (subtree_rebuilder, subtree_items_iter, [])) - - -def deep_freeze( - tree, - *, - is_mapping : Predicate = lambda x: isinstance(x, collections.abc.Mapping), - is_set : Predicate = lambda x: isinstance(x, collections.abc.Set), - is_sequence : Predicate = lambda x: isinstance(x, (list, tuple)), - leaf_fn: Callable[[Any], Any] = lambda x: x, - ): - """Recursively freezes Set/Mapping/List/Tuple structures. - - Args: - tree: The potentially deeply-nested object to deep-freeze. - is_mapping: Callable that decides whether a sub-object is a mapping. - Defaults to an `isinstance()` check for `collections.abc.Mapping`. - is_set: Callable that decides whether a sub-object is a set. - Defaults to an `isinstance()` check for `collections.abc.Set`. 
- is_sequence: Callable that decides whether a sub-object is a sequence. - Defaults to a check for being a `tuple` or `list` instance. - leaf_fn: Function to use for translating non-mapping/set/sequence - instances. - - Returns: - Translated-to-deeply-immutable form of `tree`. - """ - idict = immutabledict.immutabledict - def freeze_node_handler(path, x): - if is_set(x): - return frozenset, ((None, y) for y in x) - if is_mapping(x): - # Mappings already have hashable, so - # (should-be-)deeply-immutable keys. - # Hence, we only need to deep-freeze the values. - # - # Note that non-`dict` mappings might not guarantee - # to respect iteration-order, so we have to be careful here: - items = list(x.items()) - keys = [kv[0] for kv in items] - values = [kv[1] for kv in items] - return ((lambda ys: idict(zip(keys, ys))), - iter(items)) - if is_sequence(x): - return tuple, enumerate(iter(x)) - # Otherwise, this should not be traversed. - return None - def leaf_transform(revtuple_path, value): - del revtuple_path # Unused. - return leaf_fn(value) - return pytree_map(tree, leaf_transform, freeze_node_handler) diff --git a/compression/python/pytree/pytree_transforms_test.py b/compression/python/pytree/pytree_transforms_test.py deleted file mode 100644 index fdaec71..0000000 --- a/compression/python/pytree/pytree_transforms_test.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Basic tests for 'algebraic data type based pytree' transformations.""" - - -import collections.abc -import sys -import unittest - -import pytree_transforms - - -def _get_deep_pytree(packaging_fn, bottom, depth): - current = bottom - for n in reversed(range(depth)): - current = packaging_fn(n, current) - return current - - -def _dict_node_handler(p, d): - del p # Unused. 
- if isinstance(d, dict): - keys = d.keys() - newdict = lambda vals: dict(zip(keys, vals)) - return (newdict, iter(d.items())) - else: - return None - - -class PyTreeTest(unittest.TestCase): - """Basic correctness validation tests for PyTree transformations.""" - - def test_linearize_revtuple_path(self): - """Tests guarantees given by `linearize_revtuple_path`.""" - linearize_revtuple_path = pytree_transforms.linearize_revtuple_path - with self.subTest(guarantee='empty'): - self.assertEqual(linearize_revtuple_path(()), ()) - with self.subTest(guarantee='typical'): - self.assertEqual(linearize_revtuple_path((30, (20, (10, ())))), - (10, 20, 30)) - with self.subTest(guarantee='present_as'): - self.assertEqual( - linearize_revtuple_path( - (30, (20, (10, ()))), present_as=list), - [10, 20, 30]) - - def test_everything_is_a_leaf_node_handler(self): - """Tests guarantees given by `everything_is_a_leaf_node_handler`.""" - everything_is_a_leaf_node_handler = ( - pytree_transforms.everything_is_a_leaf_node_handler) - self.assertEqual(everything_is_a_leaf_node_handler((), 'abc'), - None) - self.assertEqual(everything_is_a_leaf_node_handler(('b', ()), - dict(a=3)), - None) - - def test_leaf_summary(self): - """Tests guarantees given by `leaf_summary`.""" - # Since the docstring only guarantees "a human-readable presentation", - # we can and should only do loose checks. 
- thing = (5678, 9531) - summary = pytree_transforms.leaf_summary(('key', ()), thing) - self.assertIsInstance(summary, str) - self.assertIn(str(thing[0]), summary) - self.assertIn(str(thing[1]), summary) - - def test_pytree_leaf_iter(self): - """Tests guarantees given by `pytree_leaf_iter`.""" - pytree_leaf_iter = pytree_transforms.pytree_leaf_iter - def leaf_transform(path, leaf): - return repr(leaf) if path and path[0].startswith('R') else leaf - with self.subTest(guarantee='returns_iterator'): - result = pytree_leaf_iter(7, leaf_transform, _dict_node_handler) - self.assertIsInstance(result, collections.abc.Iterator) - with self.subTest(guarantee='totally_empty'): - result = list(pytree_leaf_iter({}, leaf_transform, _dict_node_handler)) - self.assertEqual(result, []) - with self.subTest(guarantee='no_leaves'): - result = list(pytree_leaf_iter(dict(a={}), - leaf_transform, _dict_node_handler)) - self.assertEqual(result, []) - with self.subTest(guarantee='is_leaf'): - result = list(pytree_leaf_iter(777, leaf_transform, _dict_node_handler)) - self.assertEqual(result, [777]) - with self.subTest(guarantee='generic'): - result = list(pytree_leaf_iter( - dict(n0=dict(n01=dict(n012=1002, - n013=1003, - Rn014=1004, - ), - n02=1005), - n5=1006), - leaf_transform, _dict_node_handler)) - self.assertEqual(result, [1002, 1003, '1004', 1005, 1006]) - with self.subTest(guarantee='with_keys'): - result = list(pytree_leaf_iter( - dict(n0=dict(n01=dict(n012=1002, - n013=1003)), - n1=1004), - lambda p, s: (pytree_transforms.linearize_revtuple_path(p), s), - _dict_node_handler)) - self.assertEqual(result, - [(('n0', 'n01', 'n012'), 1002), - (('n0', 'n01', 'n013'), 1003), - (('n1',), 1004)]) - - def test_pytree_map(self): - """Tests guarantees given by `pytree_map`.""" - pytree_map = pytree_transforms.pytree_map - leaf_transform = lambda p, s: repr(s) - tree1 = dict(t0=dict(t10=1001, - t11=dict(t110=1002, - t111=1003), - t12=dict(t120=1004, - t121=1005, - t122=1006)), - t1=1007) - 
with self.subTest(guarantee='no_leaves'): - result = pytree_map(dict(a={}), - leaf_transform, - _dict_node_handler) - self.assertEqual(result, dict(a={})) - with self.subTest(guarantee='is_leaf'): - result = pytree_map(777, leaf_transform, _dict_node_handler) - self.assertEqual(result, '777') - with self.subTest(guarantee='generic'): - result = pytree_map(tree1, leaf_transform, _dict_node_handler) - self.assertEqual(result['t0']['t10'], '1001') - - def test_deeply_nested(self): - """Tests correct behavior on deeply-nested data structures.""" - pytree_leaf_iter = pytree_transforms.pytree_leaf_iter - pytree_map = pytree_transforms.pytree_map - # - depth = max(10**5, sys.getrecursionlimit() + 100) - deep_tree = _get_deep_pytree(lambda n, t: {n: t}, - 'leaf', depth) - with self.subTest(function='pytree_leaf_iter'): - leaves = list(pytree_leaf_iter(deep_tree, - lambda p, s: s.upper(), - _dict_node_handler)) - self.assertEqual(leaves, ['LEAF']) - with self.subTest(function='pytree_map'): - mapped_deep_tree = pytree_map(deep_tree, - lambda p, s: s, - _dict_node_handler) - self.assertIsInstance(mapped_deep_tree, dict) - with self.subTest(function='combined'): - leaves = list( - pytree_leaf_iter( - pytree_map(deep_tree, - lambda p, s: s.capitalize(), - _dict_node_handler), - lambda p, s: s + s, - _dict_node_handler)) - self.assertEqual(leaves, ['LeafLeaf']) - - def test_deep_freeze(self): - """Tests guarantees given by `deep_freeze`.""" - frozen = pytree_transforms.deep_freeze( - dict(a=[1001, 1002, dict(b=(1003, [1004, {1005, 1006}]))])) - self.assertIsInstance(frozen, collections.abc.Mapping) - self.assertNotIsInstance(frozen, collections.abc.MutableMapping) - self.assertIsInstance(frozen['a'], tuple) - # `frozen` is hashable, and hashes to an integer. 
- self.assertIsInstance(hash(frozen), int) - - -if __name__ == '__main__': - unittest.main() diff --git a/compression/python/pytree/requirements.txt b/compression/python/pytree/requirements.txt deleted file mode 100644 index 90c3f39..0000000 --- a/compression/python/pytree/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -immutabledict>=4.2.0 -numpy>=1.26.4 -orbax-checkpoint>=0.0.0 - diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 26313c1..04eb20a 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -158,9 +158,6 @@ TEST_F(GemmaTest, CrossEntropySmall) { float entropy = s_env->CrossEntropy(kSmall); fprintf(stderr, "per-token entropy: %f\n", entropy); switch (config.model) { - case gcpp::Model::GRIFFIN_2B: - EXPECT_NEAR(entropy, 2.61f, 0.02f); - break; case gcpp::Model::GEMMA2_2B: EXPECT_NEAR(entropy, 1.14f, 0.02f); break; diff --git a/gemma/activations.h b/gemma/activations.h index 63b3153..67e1eba 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -31,32 +31,6 @@ namespace gcpp { -struct GriffinActivations { - GriffinActivations(const ModelConfig& config, size_t batch_size, - const Allocator& allocator) - : griffin_x( - MatFactory("griffin_x", batch_size, config.model_dim, allocator)), - griffin_y( - MatFactory("griffin_y", batch_size, config.model_dim, allocator)), - griffin_gate_x(MatFactory("griffin_gate_x", batch_size, - config.model_dim, allocator)), - griffin_multiplier(MatFactory("griffin_mul", batch_size, - config.model_dim, allocator)) {} - - void SetBatchSize(size_t batch_size) { - if (griffin_x.Rows() == 0) return; - griffin_x.OverrideRows(batch_size); - griffin_y.OverrideRows(batch_size); - griffin_gate_x.OverrideRows(batch_size); - griffin_multiplier.OverrideRows(batch_size); - } - - MatStorageT griffin_x; - MatStorageT griffin_y; - MatStorageT griffin_gate_x; - MatStorageT griffin_multiplier; -}; - struct AttentionActivations { // Returns the scale value to use for the query in the attention computation. 
// Also called by ops_test. @@ -143,7 +117,7 @@ struct AttentionActivations { MatStorageT inv_timescale_global; hwy::Divisor div_seq_len; - // Unfortunately, some models (Griffin) have non-power-of-two heads. + // Unfortunately, some models have had non-power-of-two heads. hwy::Divisor div_heads; float query_scale; }; @@ -169,9 +143,7 @@ struct Activations { MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), attention(config, layer_config, batch_size, seq_len, ctx.allocator, - row_ptrs), - griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0, - ctx.allocator) { + row_ptrs) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. @@ -199,7 +171,6 @@ struct Activations { ffw_out.OverrideRows(batch_size); attention.SetBatchSize(batch_size); - griffin.SetBatchSize(batch_size); } const LayerConfig& layer_config; @@ -215,7 +186,6 @@ struct Activations { MatStorageT ffw_out; AttentionActivations attention; - GriffinActivations griffin; }; } // namespace gcpp diff --git a/gemma/attention.cc b/gemma/attention.cc index 8afd561..31ed4d1 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -327,6 +327,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, MatMulEnv& env) { PROFILER_ZONE("Gen.Attention.SumHeads"); const LayerConfig& layer_config = layer.layer_config; + (void)layer_config; // For HWY_DASSERT // att_weights and att_out are concatenated heads, each of length // layer_config.qkv_dim. Thus the [num_interleaved, // layer_config.model_dim] matmul output is the sum over heads. Compare @@ -334,10 +335,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, // encoded) HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 && layer_config.qkv_dim != 0); - const float* add = layer_config.softmax_attn_output_biases - ? 
layer.attention_output_biases.PackedScale1() - : nullptr; - CallMatMul(activations.att_out, layer.att_weights, add, env, + CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env, activations.att_sums); } diff --git a/gemma/configs.cc b/gemma/configs.cc index f19d30d..8856203 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -133,78 +133,6 @@ static ModelConfig ConfigGemma2_2B() { return config; } -static LayerConfig LayerConfigGemmaTiny(size_t model_dim) { - LayerConfig config; - config.model_dim = model_dim; - config.ff_hidden_dim = 256; - config.heads = 4; - config.kv_heads = 1; - config.qkv_dim = 16; - return config; -} - -static ModelConfig ConfigGemmaTiny() { - ModelConfig config = ConfigNoSSM(); - config.display_name = "GemmaTiny"; - config.model = Model::GEMMA_TINY; - config.wrapping = PromptWrapping::GEMMA_IT; - config.model_dim = 32; - config.vocab_size = 32; // at least two f32 vectors - config.max_seq_len = 32; - LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim); - config.num_layers = 2; - config.layer_configs = {config.num_layers, layer_config}; - config.query_scale = QueryScaleType::SqrtKeySize; - config.attention_window_sizes = FixedAttentionWindowSizes<2>(32); - config.att_cap = 50.0f; - config.final_cap = 30.0f; - config.eos_id = 11; - config.secondary_eos_id = 11; - return config; -} - -static LayerConfig LayerConfigGriffin2B(size_t model_dim) { - LayerConfig config; - config.model_dim = model_dim; - config.griffin_dim = model_dim; - config.ff_hidden_dim = 7680; - config.heads = 10; - config.kv_heads = 1; - config.qkv_dim = 256; - config.conv1d_width = 4; - HWY_DASSERT(config.conv1d_width <= kMaxConv1DWidth); - config.ff_biases = true; - config.softmax_attn_output_biases = true; - config.optimized_gating = false; - config.type = LayerAttentionType::kGriffinRecurrentBlock; - config.activation = ActivationType::Gelu; - config.post_qk = PostQKType::HalfRope; - return config; -} - -static ModelConfig 
ConfigGriffin2B() { - ModelConfig config = ConfigNoSSM(); - config.display_name = "Griffin2B"; - config.model = Model::GRIFFIN_2B; - // Griffin uses local attention, so max_seq_len is actually the local - // attention window. - config.model_dim = 2560; - config.vocab_size = kVocabSize; - config.max_seq_len = 2048; - LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim); - config.num_layers = 26; - config.layer_configs = {config.num_layers, layer_config}; - for (size_t i = 2; i < config.num_layers; i += 3) { - config.layer_configs[i].type = LayerAttentionType::kGemma; - config.layer_configs[i].griffin_dim = 0; - } - config.attention_window_sizes = - FixedAttentionWindowSizes<26>(config.max_seq_len); - config.use_local_attention = true; - config.final_cap = 0.0f; - return config; -} - static LayerConfig LayerConfigVit(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; @@ -510,10 +438,6 @@ static ModelConfig ConfigFromModel(Model model) { return ConfigGemma2_9B(); case Model::GEMMA2_27B: return ConfigGemma2_27B(); - case Model::GRIFFIN_2B: - return ConfigGriffin2B(); - case Model::GEMMA_TINY: - return ConfigGemmaTiny(); case Model::PALIGEMMA2_3B_224: return ConfigPaliGemma2_3B_224(); case Model::PALIGEMMA2_3B_448: @@ -547,10 +471,6 @@ const char* ModelPrefix(Model model) { return "9b"; case Model::GEMMA2_27B: return "27b"; - case Model::GRIFFIN_2B: - return "gr2b"; - case Model::GEMMA_TINY: - return "tiny"; case Model::PALIGEMMA2_3B_224: return "paligemma2-3b-224"; case Model::PALIGEMMA2_3B_448: @@ -750,13 +670,10 @@ bool ModelConfig::OverwriteWithCanonical() { Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { switch (layers) { - case 2: - return Model::GEMMA_TINY; case 18: return Model::GEMMA3_270M; case 26: - if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B; if (layer_types & kDeducedViT) return Model::GEMMA3_1B; return Model::GEMMA2_2B; case 27: diff --git a/gemma/configs.h b/gemma/configs.h index 
0c93e30..e4a26b8 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -68,14 +68,11 @@ static inline bool EnumValid(PromptWrapping wrapping) { enum class LayerAttentionType { kGemma, - kGriffinRecurrentBlock, kVit, }; static inline bool EnumValid(LayerAttentionType type) { - return type == LayerAttentionType::kGemma || - type == LayerAttentionType::kGriffinRecurrentBlock || - type == LayerAttentionType::kVit; + return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit; } // Post attention and ffw normalization type. @@ -163,9 +160,8 @@ enum class Model { // 1 and 2 are obsolete. GEMMA2_9B = 3, GEMMA2_27B, - GRIFFIN_2B, - GEMMA_TINY, // for testing only - GEMMA2_2B, + // 5 and 6 are obsolete. + GEMMA2_2B = 7, // 8 and 9 are obsolete. PALIGEMMA2_3B_224 = 10, PALIGEMMA2_3B_448, @@ -199,13 +195,19 @@ static inline bool IsPaliGemma(Model model) { return false; } +static inline bool IsObsolete(Model model) { + const size_t i = static_cast(model); + if (i == 5 || i == 6 || i == 8 || i == 9) return true; + return false; +} + // Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`. template void ForEachModel(const Func& func) { for (size_t i = static_cast(Model::GEMMA2_9B); i < static_cast(Model::kSentinel); ++i) { - if (i == 8 || i == 9) continue; - func(static_cast(i)); + const Model model = static_cast(i); + if (!IsObsolete(model)) func(model); } } @@ -214,7 +216,7 @@ static inline bool EnumValid(Model model) { if (model == Model::UNKNOWN) return true; const size_t i = static_cast(model); if (i >= static_cast(Model::GEMMA2_9B) && - i < static_cast(Model::kSentinel) && i != 8 && i != 9) { + i < static_cast(Model::kSentinel) && !IsObsolete(model)) { return true; } return false; @@ -235,15 +237,20 @@ struct LayerConfig : public IFields { // Source of truth for field ordering. void VisitFields(IFieldsVisitor& visitor) override { + // Formerly used for Griffin. 
+ uint32_t unused_griffin_dim = 0; + uint32_t unused_conv1d_width = 0; + bool unused_softmax_attn_output_biases = false; + visitor(model_dim); - visitor(griffin_dim); + visitor(unused_griffin_dim); visitor(ff_hidden_dim); visitor(heads); visitor(kv_heads); visitor(qkv_dim); - visitor(conv1d_width); + visitor(unused_conv1d_width); visitor(ff_biases); - visitor(softmax_attn_output_biases); + visitor(unused_softmax_attn_output_biases); visitor(optimized_gating); visitor(post_norm); visitor(type); @@ -263,14 +270,11 @@ struct LayerConfig : public IFields { bool IsMHA() const { return heads == kv_heads; } uint32_t model_dim = 0; - uint32_t griffin_dim = 0; uint32_t ff_hidden_dim = 0; uint32_t heads = 0; uint32_t kv_heads = 0; - uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). - uint32_t conv1d_width = 0; // Griffin only + uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). bool ff_biases = false; - bool softmax_attn_output_biases = false; // for Griffin bool optimized_gating = true; // for Gemma3 PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; @@ -358,7 +362,8 @@ struct ModelConfig : public IFields { visitor(final_cap); visitor(absolute_pe); - visitor(use_local_attention); + bool unused_use_local_attention = false; // formerly used for Griffin + visitor(unused_use_local_attention); visitor(query_scale); visitor(layer_configs); visitor(attention_window_sizes); @@ -454,7 +459,6 @@ struct ModelConfig : public IFields { float final_cap = 0.0f; bool absolute_pe = false; - bool use_local_attention = false; // Griffin only QueryScaleType query_scale = QueryScaleType::SqrtKeySize; std::vector layer_configs; std::vector attention_window_sizes; @@ -478,7 +482,6 @@ struct ModelConfig : public IFields { ModelConfig GetVitConfig(const ModelConfig& config); enum DeducedLayerTypes { - kDeducedGriffin = 1, kDeducedViT = 2, kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224. 
}; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 0177c92..62288ff 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -34,7 +34,6 @@ // After highway.h #include "gemma/attention.h" // includes highway.h #include "gemma/gemma-inl.h" -#include "gemma/griffin.h" // includes highway.h #include "gemma/vit.h" // includes highway.h #ifndef GEMMA_CC_ONCE @@ -77,14 +76,6 @@ void Attention(LayerAttentionType type, const size_t num_tokens, GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, env, /*flags=*/0); - } else { - HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); - // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer, - // so map `layer` to the Griffin layer index. - const size_t griffin_layer = - activations.attention.config.NumLayersOfTypeBefore(type, layer_idx); - GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch, - env); } } @@ -484,13 +475,6 @@ static void GenerateT(const ModelConfig& config, const AesCtrEngine& engine, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { - // Griffin assumes that the recurrent block cache is zero-initialized. - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - if (qbatch.MutablePos(qi) == 0) { - qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models. - } - } - size_t max_prompt_size = 0; bool all_prefix_end_are_zero = true; size_t total_prefill_tokens = 0; // only for throughput stats. diff --git a/gemma/griffin.cc b/gemma/griffin.cc deleted file mode 100644 index 35bf29a..0000000 --- a/gemma/griffin.cc +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2025 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "compression/types.h" // GEMMA_DISABLED_TARGETS -#ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS -#endif // HWY_DISABLED_TARGETS - -#include "gemma/activations.h" -#include "gemma/gemma.h" -#include "gemma/gemma_args.h" -#include "gemma/weights.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/profiler.h" - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT -// clang-format on -#include "hwy/foreach_target.h" // IWYU pragma: keep -#include "hwy/highway.h" -// After highway.h -#include "ops/matvec-inl.h" -#include "ops/ops-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, - const LayerWeightsPtrs* layer_weights, - Activations& activations, QBatch& qbatch, - MatMulEnv& env) { - PROFILER_ZONE("Gen.Griffin"); - hwy::ThreadPool& pool = env.ctx.pools.Pool(0); - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - const D df; - - const size_t model_dim = layer_weights->layer_config.model_dim; - HWY_DASSERT(model_dim % hn::Lanes(df) == 0); - - const size_t heads = layer_weights->layer_config.heads; - const size_t conv_1d_width = layer_weights->layer_config.conv1d_width; - HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even"); - const size_t kHeadDim = model_dim / heads; - 
const size_t kMatrixSize = kHeadDim * kHeadDim; - - const size_t num_interleaved = num_tokens * qbatch.Size(); - const hwy::Divisor div_qbatch(static_cast(qbatch.Size())); - GriffinActivations& griffin = activations.griffin; - - // X / Y linear layers. - // TODO: MatMul - HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows()); - HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows()); - CallUpcastedSame( - &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w, - [&](const auto* wx, const auto* wy) { - for (size_t r = 0; r < num_interleaved; ++r) { - float* HWY_RESTRICT y = griffin.griffin_y.Row(r); - float* HWY_RESTRICT x = griffin.griffin_x.Row(r); - TwoMatVecAdd( - *wx, *wy, 0, model_dim, model_dim, - activations.attention.pre_att_rms_out.Row(r), - /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), - /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), - /*out0=*/x, /*out1=*/y, pool); - Gelu(y, model_dim); - } - }); - - // Conv1D. - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const size_t qi = div_qbatch.Remainder(interleaved_idx); - const size_t batch_idx = div_qbatch.Divide(interleaved_idx); - const size_t pos = qbatch.Pos(qi) + batch_idx; - float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); - - // cache[i] = input at time t-i. 
- float* HWY_RESTRICT cache[kMaxConv1DWidth]; - cache[0] = x; - for (size_t i = 1; i < conv_1d_width; i++) { - cache[i] = - qbatch.KV(qi).conv1d_cache.Row(griffin_layer) + - ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; - } - for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { - auto xv = hn::Load(df, x + i); - auto accum0 = - hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i); - auto accum1 = hn::Zero(df); - for (size_t l = 0; 2 * l < conv_1d_width; l++) { - auto wv0 = - hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + - (conv_1d_width - 1 - 2 * l) * model_dim + i); - auto wv1 = - hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + - (conv_1d_width - 2 - 2 * l) * model_dim + i); - accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); - accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); - } - hn::Store(hn::Add(accum0, accum1), df, x + i); - hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i); - } - } - - // RGLRU - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const size_t qi = div_qbatch.Remainder(interleaved_idx); - const size_t batch_idx = div_qbatch.Divide(interleaved_idx); - const size_t pos = qbatch.Pos(qi) + batch_idx; - - float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); - float* HWY_RESTRICT y = griffin.griffin_y.Row(qi); - float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi); - float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi); - float* HWY_RESTRICT rnn_state = - qbatch.KV(qi).rglru_cache.Row(griffin_layer); - - pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - size_t head_offset = head * kHeadDim; - CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) { - TwoOfsMatVecAddLoop( - *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim, - kHeadDim, x + head_offset, - /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + - head_offset, - 
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + - model_dim + head_offset, - /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); - }); - Sigmoid(gate_x + head_offset, kHeadDim); - Sigmoid(a + head_offset, kHeadDim); - const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) - HWY_ATTR { return hn::Mul(x, gate_x); }; - hn::Transform1(D(), a + head_offset, kHeadDim, - layer_weights->griffin.a.PackedScale1() + head_offset, - fn_mul); - hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, - fn_mul); - // RNN scan - HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); - for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { - auto log_a = hn::Load(df, a + head_offset + i); - auto gated_x = hn::Load(df, x + head_offset + i); - auto rnn = hn::Load(df, rnn_state + head_offset + i); - auto a = hn::Exp(df, log_a); - auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); - if (pos == 0) { - x_multiplier = hn::Set(df, 1.0f); - } - auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); - hn::Store(new_x, df, rnn_state + head_offset + i); - - // Join branches. - auto yv = hn::Load(df, y + head_offset + i); - auto pre_out = hn::Mul(yv, new_x); - hn::Store(pre_out, df, x + head_offset + i); - } - }); - } // interleaved_idx - - // Final linear layer. - CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w, - layer_weights->griffin.linear_out_biases.PackedScale1(), env, - activations.attention.att_sums); -} // GriffinRecurrent - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); diff --git a/gemma/griffin.h b/gemma/griffin.h deleted file mode 100644 index 0ba6a23..0000000 --- a/gemma/griffin.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2025 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ - -// Declares GriffinRecurrent for all SIMD targets. - -#include - -#include "gemma/gemma.h" -#include "hwy/highway.h" - -namespace gcpp { - -// Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \ - const LayerWeightsPtrs* layer_weights, \ - Activations& activations, QBatch& qbatch, \ - MatMulEnv& env); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ - } // namespace NAMESPACE - -// Function declarations for each SIMD target. Allows direct call from the -// per-target namespace. We may later replace this with dynamic dispatch if -// the overhead is acceptable. 
-HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN) - -#undef GEMMA_DECL_GRIFFIN - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 9d107e8..ca814f4 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -24,26 +24,6 @@ namespace gcpp { -void KVCache::ZeroGriffinCache() { - if (conv1d_cache.Rows() == 0) return; - ZeroInit(conv1d_cache); - ZeroInit(rglru_cache); -} - -static size_t GriffinLayers(const ModelConfig& config) { - return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock); -} - -static size_t GriffinConv1dCols(const ModelConfig& config) { - size_t conv1d_width = 0; - for (const auto& layer_config : config.layer_configs) { - conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width); - } - // The row offset, in blocks of model_dim is computed mod (conv1d_width - 1), - // hence allocate conv1d_width * model_dim total columns. - return conv1d_width * config.model_dim; -} - // Number of rows for KV cache. Note that both rows and cols are u32, and // the total number of elements can exceed 2^32. 
static size_t CappedSeqLen(const ModelConfig& config, @@ -56,30 +36,18 @@ static size_t CappedSeqLen(const ModelConfig& config, return inference_args.seq_len; } -KVCache::KVCache(const Extents2D& conv1d_extents, - const Extents2D& rglru_extents, const Extents2D& kv_extents, - const Allocator& allocator) - : conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd), - rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd), - kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), +KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator) + : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), allocator_(allocator) {} KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const Allocator& allocator) : KVCache( - Extents2D(GriffinLayers(config), GriffinConv1dCols(config)), - Extents2D(GriffinLayers(config), config.model_dim), Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), allocator) {} KVCache KVCache::Copy() { - KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(), - kv_cache.Extents(), allocator_); - - if (conv1d_cache.Rows() != 0) { - CopyMat(conv1d_cache, copy.conv1d_cache); - CopyMat(rglru_cache, copy.rglru_cache); - } + KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 7b5b88d..31e964b 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -35,24 +35,15 @@ struct KVCache { // copy ctor to make the cost explicit. KVCache Copy(); - // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache - // and rglru_cache. 
- void ZeroGriffinCache(); - size_t SeqLen() const { return kv_cache.Rows(); } - // [griffin_layers, griffin_conv1d_cols * model_dim] - MatStorageT conv1d_cache; - MatStorageT rglru_cache; // [griffin_layers, model_dim] - MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] private: const Allocator& allocator_; // For use by other ctor and Copy() - KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents, - const Extents2D& kv_extents, const Allocator& allocator); + KVCache(const Extents2D& kv_extents, const Allocator& allocator); }; } // namespace gcpp diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 2aab1f5..a20caf2 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -221,9 +221,6 @@ static int DeduceLayerTypes(const BlobReader& reader) { int layer_types = 0; for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) { const std::string& key = reader.Keys()[key_idx]; - if (key.find("gr_conv_w") != std::string::npos) { // NOLINT - return kDeducedGriffin; - } if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT layer_types |= kDeducedViT; } @@ -293,7 +290,7 @@ static std::vector ReadScales(BlobReader& reader, const ModelConfig& config) { std::vector scales; // Check first to prevent `CallWithSpan` from printing a warning. This blob is - // optional even in pre-2025 format; Griffin was the first to include it. + // optional even in pre-2025 format. 
if (reader.Find(kDecoratedScalesName)) { HWY_ASSERT(reader.CallWithSpan( kDecoratedScalesName, diff --git a/gemma/tensor_info.cc b/gemma/tensor_info.cc index de93cf9..05f829b 100644 --- a/gemma/tensor_info.cc +++ b/gemma/tensor_info.cc @@ -277,122 +277,6 @@ void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config, }); } -void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config, - const size_t layer_idx) { - const std::string suffix = LayerSuffix(layer_idx); - Add(suffix, { - .base_name = "gr_lin_x_w", - .source_names = {"recurrent_block/linear_x/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }); - Add(suffix, { - .base_name = "gr_lin_x_b", - .source_names = {"recurrent_block/linear_x/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_lin_y_w", - .source_names = {"recurrent_block/linear_y/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }); - Add(suffix, { - .base_name = "gr_lin_y_b", - .source_names = {"recurrent_block/linear_y/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_lin_out_w", - .source_names = {"recurrent_block/linear_out/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }); - Add(suffix, { - .base_name = "gr_lin_out_b", - .source_names = {"recurrent_block/linear_out/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, - { - .base_name = "gr_conv_w", - .source_names = {"recurrent_block/conv_1d/w"}, - .axes = {0, 1}, - .shape = {layer_config.conv1d_width, layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_conv_b", - .source_names = {"recurrent_block/conv_1d/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = 
Type::kF32, - }); - Add(suffix, { - .base_name = "gr1_gate_w", - .source_names = {"recurrent_block/rg_lru/input_gate/w"}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - .concat_names = {"gr_gate_w", "gr2_gate_w"}, - }); - Add(suffix, { - .base_name = "gr2_gate_w", - .source_names = {"recurrent_block/rg_lru/a_gate/w"}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - .concat_names = {""}, - }); - Add(suffix, { - .base_name = "gr_gate_w", - .source_names = {"recurrent_block/rg_lru/gate/w"}, - .axes = {0, 2, 1}, - .shape = {2 * layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - }); - Add(suffix, { - .base_name = "gr1_gate_b", - .source_names = {"recurrent_block/rg_lru/input_gate/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .concat_names = {"gr_gate_b", "gr2_gate_b"}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr2_gate_b", - .source_names = {"recurrent_block/rg_lru/a_gate/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .concat_names = {""}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_gate_b", - .source_names = {"recurrent_block/rg_lru/input_gate/b"}, - .axes = {0, 1}, - .shape = {2 * layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_a", - .source_names = {"recurrent_block/rg_lru/a_param"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - .scaled_softplus = true, - }); -} - void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config, const LayerConfig& layer_config, const size_t layer_idx) { @@ -553,10 +437,6 @@ void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config, .shape = {config.model_dim, layer_config.heads, 
layer_config.qkv_dim}, .cols_take_extra_dims = true, }); - - if (config.model == Model::GRIFFIN_2B) { - AddGriffinLayerTensors(layer_config, layer_idx); - } } TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) { diff --git a/gemma/tensor_info.h b/gemma/tensor_info.h index c8252a4..d2b25d9 100644 --- a/gemma/tensor_info.h +++ b/gemma/tensor_info.h @@ -124,8 +124,6 @@ class TensorInfoRegistry { void AddModelTensors(const ModelConfig& config); void AddLayerTensors(const ModelConfig& config, const LayerConfig& layer_config, size_t layer_idx); - void AddGriffinLayerTensors(const LayerConfig& layer_config, - size_t layer_idx); void AddImageLayerTensors(const ModelConfig& config, const LayerConfig& layer_config, diff --git a/gemma/weights.cc b/gemma/weights.cc index 3d1d43e..8191bd9 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -88,7 +88,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector& mat_owners, // For FFN. Fast, only updates pointers. void LayerWeightsPtrs::SplitW1() { - // Used for Gemma and Griffin layers; FFWVit uses different tensors. + // Used for Gemma layers; FFWVit uses different tensors. if (layer_config.type == LayerAttentionType::kVit) return; // Files have both or neither of w1 and w2. diff --git a/gemma/weights.h b/gemma/weights.h index de3652a..06c0186 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -57,8 +57,7 @@ struct TensorArgs { // the _w1/_w2 tensors are not always present. kMaybeRead = 1, - // Avoid padding tensor rows when reading. Used for some Griffin tensors - // whose index computations do not use Row() accessors. + // Avoid padding tensor rows when reading. 
kPacked = 2, }; const int flags; @@ -102,17 +101,6 @@ struct LayerWeightsPtrs { qkv_einsum_w1(finder_("qkv1_w")), qkv_einsum_w2(finder_("qkv2_w")), attention_output_biases(finder_("attn_ob")), - griffin({.linear_x_w = finder_("gr_lin_x_w"), - .linear_x_biases = finder_("gr_lin_x_b"), - .linear_y_w = finder_("gr_lin_y_w"), - .linear_y_biases = finder_("gr_lin_y_b"), - .linear_out_w = finder_("gr_lin_out_w"), - .linear_out_biases = finder_("gr_lin_out_b"), - .conv_w = finder_("gr_conv_w"), - .conv_biases = finder_("gr_conv_b"), - .gate_w = finder_("gr_gate_w"), - .gate_biases = finder_("gr_gate_b"), - .a = finder_("gr_a")}), // MultiHeadDotProductAttention. vit({.attn_out_w = finder_("attn_out_w"), .attn_out_b = finder_("attn_out_b"), @@ -156,20 +144,6 @@ struct LayerWeightsPtrs { MatPtr qkv_einsum_w2; MatPtrT attention_output_biases; - struct { - MatPtr linear_x_w; - MatPtrT linear_x_biases; - MatPtr linear_y_w; - MatPtrT linear_y_biases; - MatPtr linear_out_w; - MatPtrT linear_out_biases; - MatPtrT conv_w; - MatPtrT conv_biases; - MatPtr gate_w; - MatPtrT gate_biases; - MatPtrT a; - } griffin; - struct { // MultiHeadDotProductAttention. MatPtr attn_out_w; // at least BF16. @@ -244,20 +218,6 @@ struct LayerWeightsPtrs { func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead)); func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead)); func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead)); - } else { - func(TENSOR_ARGS(griffin.linear_x_w, kMustRead)); - func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead)); - func(TENSOR_ARGS(griffin.linear_y_w, kMustRead)); - func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead)); - func(TENSOR_ARGS(griffin.linear_out_w, kMustRead)); - func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead)); - // conv_w and gate_w are not accessed via Row(), hence must not be padded. - // Note that *biases are 1D, hence packing/padding does not matter. 
- func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked)); - func(TENSOR_ARGS(griffin.conv_biases, kMustRead)); - func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked)); - func(TENSOR_ARGS(griffin.gate_biases, kMustRead)); - func(TENSOR_ARGS(griffin.a, kMustRead)); } { func(TENSOR_ARGS(gating_einsum_w, kMaybeRead)); @@ -281,11 +241,6 @@ struct LayerWeightsPtrs { func(TENSOR_ARGS(ffw_gating_biases, kMustRead)); func(TENSOR_ARGS(ffw_output_biases, kMustRead)); } - - if (layer_config.softmax_attn_output_biases && - layer_config.type == LayerAttentionType::kGemma) { - func(TENSOR_ARGS(attention_output_biases, kMustRead)); - } } // `ForEachTensor` // Zero-initializes all allocated tensors in the layer. diff --git a/python/configs.cc b/python/configs.cc index f8121bf..086c691 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -57,8 +57,6 @@ PYBIND11_MODULE(configs, py_module) { enum_(py_module, "LayerAttentionType") .value("kGemma", LayerAttentionType::kGemma) - .value("kGriffinRecurrentBlock", - LayerAttentionType::kGriffinRecurrentBlock) .value("kVit", LayerAttentionType::kVit); enum_(py_module, "PostNormType") @@ -84,8 +82,6 @@ PYBIND11_MODULE(configs, py_module) { .value("UNKNOWN", Model::UNKNOWN) .value("GEMMA2_9B", Model::GEMMA2_9B) .value("GEMMA2_27B", Model::GEMMA2_27B) - .value("GRIFFIN_2B", Model::GRIFFIN_2B) - .value("GEMMA_TINY", Model::GEMMA_TINY) .value("GEMMA2_2B", Model::GEMMA2_2B) .value("PALIGEMMA2_3B_224", Model::PALIGEMMA2_3B_224) .value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224) @@ -121,15 +117,11 @@ PYBIND11_MODULE(configs, py_module) { class_(py_module, "LayerConfig") .def(init()) .def_readwrite("model_dim", &LayerConfig::model_dim) - .def_readwrite("griffin_dim", &LayerConfig::griffin_dim) .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim) .def_readwrite("heads", &LayerConfig::heads) .def_readwrite("kv_heads", &LayerConfig::kv_heads) .def_readwrite("qkv_dim", &LayerConfig::qkv_dim) - 
.def_readwrite("conv1d_width", &LayerConfig::conv1d_width) .def_readwrite("ff_biases", &LayerConfig::ff_biases) - .def_readwrite("softmax_attn_output_biases", - &LayerConfig::softmax_attn_output_biases) .def_readwrite("optimized_gating", &LayerConfig::optimized_gating) .def_readwrite("post_norm", &LayerConfig::post_norm) .def_readwrite("type", &LayerConfig::type) @@ -166,7 +158,6 @@ PYBIND11_MODULE(configs, py_module) { .def_readwrite("att_cap", &ModelConfig::att_cap) .def_readwrite("final_cap", &ModelConfig::final_cap) .def_readwrite("absolute_pe", &ModelConfig::absolute_pe) - .def_readwrite("use_local_attention", &ModelConfig::use_local_attention) .def_readwrite("query_scale", &ModelConfig::query_scale) .def_readwrite("layer_configs", &ModelConfig::layer_configs) .def_readwrite("attention_window_sizes", From ad7d7a2713f435e0a51645ff195dff9c7c5fb958 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 5 Sep 2025 05:49:35 -0700 Subject: [PATCH 30/65] Further adjust dot_test threshold (numerics) PiperOrigin-RevId: 803428406 --- ops/dot_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 3cb565c..d93b210 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -851,8 +851,8 @@ class DotStats { ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f); // Naive and OnlyTwoProd are considerably higher than others - ASSERT_INSIDE(kNaive, 1.5E-8f, s_rels[kNaive].Max(), 3080.f); - ASSERT_INSIDE(kOnlyTwoProd, 1.5E-8f, s_rels[kNaive].Max(), 3080.f); + ASSERT_INSIDE(kNaive, 1.5E-8f, s_rels[kNaive].Max(), 1.4E4f); + ASSERT_INSIDE(kOnlyTwoProd, 1.5E-8f, s_rels[kNaive].Max(), 1.4E4f); // Kahan (FastTwoSum) is not much better here! 
ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f); From cbe24eac51c089f6a4f126f62683489c0fe794f0 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 5 Sep 2025 07:23:33 -0700 Subject: [PATCH 31/65] 1.15x speedup: parallel sampling, enabled by new RNG Also pass pos to SampleFunc, for seeding the RNG. PiperOrigin-RevId: 803453518 --- evals/cross_entropy.cc | 12 +++--- gemma/activations.h | 3 ++ gemma/gemma.cc | 85 +++++++++++++++++++++++++++--------------- gemma/gemma_args.h | 9 +++-- 4 files changed, 68 insertions(+), 41 deletions(-) diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index b7abb10..49acb50 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -99,14 +99,15 @@ HWY_EXPORT(CallSoftmax); float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, const std::vector& prompt, KVCache& kv_cache, MatMulEnv& env, int verbosity) { - const StreamFunc stream_token = [](int, float) { return true; }; + const BatchStreamFunc stream_token = [](size_t, size_t, int, float) { + return true; + }; const int vocab_size = gemma.Config().vocab_size; float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s) - size_t pos = 1; - const SampleFunc sample_token = [&](size_t qi, - Logits logits) -> TokenAndProb { + const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits, + size_t /*worker*/) -> TokenAndProb { // input is logits, not yet probabilities HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler); // We are called for each token, but pos starts at 1. 
Clamping @@ -128,7 +129,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1, cross_entropy / std::log(2.0) / (pos + 1)); } - ++pos; return TokenAndProb{.token = token, .prob = prob}; }; @@ -138,7 +138,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, .max_generated_tokens = max_generated_tokens - 1, .temperature = 0.0f, .verbosity = verbosity, - .stream_token = stream_token, + .batch_stream_token = stream_token, .sample_func = sample_token, }; TimingInfo timing_info; diff --git a/gemma/activations.h b/gemma/activations.h index 67e1eba..21e5e58 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -132,6 +132,7 @@ struct Activations { x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)), logits( MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)), + sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)), pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, config.model_dim, ctx.allocator)), @@ -164,6 +165,7 @@ struct Activations { x.OverrideRows(batch_size); x_bf.OverrideRows(batch_size); logits.OverrideRows(batch_size); + sampled.OverrideRows(batch_size); pre_ffw_rms_out.OverrideRows(batch_size); C1.OverrideRows(batch_size); @@ -178,6 +180,7 @@ struct Activations { MatStorageT x; // input MatStorageT x_bf; // output of final RMSNorm, input to EmbeddingMatmul MatStorageT logits; + MatStorageT sampled; // batch_size x 3 (padded) // Gated FFW MatStorageT pre_ffw_rms_out; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 62288ff..785bd87 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -377,14 +377,13 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent // `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the // query is at the end of its sequence. 
-static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, - const ModelConfig& config, +static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token, + const float prob, const ModelConfig& config, const RuntimeConfig& runtime_config, - QBatch& qbatch, bool pos_plus_1, bool update_pos, + QBatch& qbatch, bool update_pos, hwy::BitSet4096<>& non_eos) { HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called. - const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0); if (HWY_UNLIKELY( !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) { // User decided to stop: set token to primary EOS to trigger IsEOS below. @@ -402,11 +401,13 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, // Must be called after Transformer: either after prefill, or during decode. // Computes logits, samples and streams the token. -static void SampleAndStream( - const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, const SampleFunc& sample_token, - Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env, - hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) { +static void SampleAndStream(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, + const SampleFunc& sample_token, + Activations& activations, QBatch& qbatch, + MatMulEnv& env, hwy::BitSet4096<>& non_eos, + TimingInfo& timing_info) { HWY_DASSERT(qbatch.Size() == activations.x.Rows()); RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf, @@ -429,16 +430,33 @@ static void SampleAndStream( timing_info.NotifyGenerated(non_eos.Count()); - // TODO: parallelize - non_eos.Foreach([&](size_t qi) { - const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi)); + ParallelFor( + ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, + /*cluster_idx=*/0, [&](size_t qi, size_t worker) { + if (!non_eos.Get(qi)) return; - // We streamed all 
prefill tokens, but pos is still one behind because we - // started generation at pos = prompt.size() - 1. We want the pos argument - // to match the number of calls to `StreamToken`, as expected by the caller. - const bool pos_plus_1 = true; - StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch, - pos_plus_1, update_pos, non_eos); + // We streamed all prefill tokens, but pos is still one behind + // because we started generation at pos = prompt.size() - 1. + // We want the pos argument to match the number of calls to + // `StreamToken`, as expected by the caller. + const size_t pos = qbatch.Pos(qi) + 1; + + const TokenAndProb tp = + sample_token(qi, pos, activations.logits.RowSpan(qi), worker); + // `sampled` is padded, which prevents false sharing. + activations.sampled.Row(qi)[0] = static_cast(pos); + activations.sampled.Row(qi)[1] = static_cast(tp.token); + activations.sampled.Row(qi)[2] = hwy::BitCastScalar(tp.prob); + }); + + // Sequentially, because `StreamToken` is not yet thread-safe. + non_eos.Foreach([&](size_t qi) { + const size_t pos = activations.sampled.Row(qi)[0]; + const int token = static_cast(activations.sampled.Row(qi)[1]); + const float prob = + hwy::BitCastScalar(activations.sampled.Row(qi)[2]); + StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch, + /*update_pos=*/true, non_eos); }); } @@ -448,21 +466,25 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; - static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1"); - const size_t worker = 0; // TODO: parallelize + static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1"); + static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general"); // Fast path for top-1 with no accept_token. 
if (runtime_config.top_k == 1 && !runtime_config.accept_token) { - return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE3(ctx.profiler, worker, zone); - return Top1OfSoftmax(logits); - }; + return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker) + HWY_ATTR -> TokenAndProb { + PROFILER_ZONE3(ctx.profiler, worker, zone_top1); + return Top1OfSoftmax(logits); + }; } // General case: Softmax with top-k sampling. - return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE("Gen.Sample general"); - RngStream gen(engine, qi); + return [&](size_t qi, size_t pos, Logits logits, + size_t worker) HWY_ATTR -> TokenAndProb { + PROFILER_ZONE3(ctx.profiler, worker, zone_topK); + // We want a different sequence for each batch element and position. + const uint64_t stream = (static_cast(qi) << 32) | pos; + RngStream gen(engine, stream); return FusedSoftmaxAndSampleTopK( logits, runtime_config.top_k, gen, runtime_config.temperature, runtime_config.accept_token, ctx.profiler, worker); @@ -524,12 +546,13 @@ static void GenerateT(const ModelConfig& config, // Stream the last prompt token from each query, fill activations.gen_tokens. for (size_t qi = 0; qi < qbatch.Size(); ++qi) { const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); - const bool pos_plus_1 = false; // during prefill, pos is still correct. + + const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct. // In autoregressive mode, we have not prefilled the last token, so do // not advance. 
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); - StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, - runtime_config, qbatch, pos_plus_1, update_pos, non_eos); + StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, + config, runtime_config, qbatch, update_pos, non_eos); } size_t max_gen_steps = runtime_config.max_generated_tokens; @@ -546,7 +569,7 @@ static void GenerateT(const ModelConfig& config, for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { Transformer(config, runtime_config, weights, activations, qbatch, env); SampleAndStream(config, runtime_config, weights, sample_token, activations, - qbatch, /*update_pos=*/true, env, non_eos, timing_info); + qbatch, env, non_eos, timing_info); } timing_info.NotifyGenerateDone(); } diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 59e3a6c..b2d19ff 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -89,10 +89,10 @@ using BatchStreamFunc = std::function; // If not empty, AcceptFunc is called with token. It should return false for // tokens you don't want to generate and true for tokens you want to generate. using AcceptFunc = std::function; -// If not empty, SampleFunc is called with the query_idx and logits for the -// next token, which it may modify/overwrite. It returns the next generated -// token together with its probability. -using SampleFunc = std::function; +// If not empty, SampleFunc is called concurrently from worker thread(s) with +// query_idx, pos, logits for the next token (which it may modify/overwrite), +// and worker. It returns the next generated token and its probability. +using SampleFunc = std::function; // If not empty, LayersOutputFunc is called for layer outputs, specified with: // - index of query within containing batch (if any); zero otherwise. 
// - position in the tokens sequence @@ -115,6 +115,7 @@ using ActivationsObserverFunc = struct RuntimeConfig { // If non-null, `batch_stream_token` is called for each token in the batch, // otherwise `stream_token`. `query_idx` is absolute, not batch-relative. + // This is called sequentially from the main thread. bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const { PROFILER_ZONE("Gen.StreamToken"); if (batch_stream_token) { From 6e52a835c67f0592204c3dd2ece665534c236d17 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Sun, 7 Sep 2025 22:50:01 -0700 Subject: [PATCH 32/65] Faster startup on tsan: use hierarchical parallelism for BF16 conversion Also re-enable profiler zones PiperOrigin-RevId: 804273899 --- gemma/weights.cc | 66 ++++++++++++++++++++++++++---------------------- ops/matmul-inl.h | 7 +++-- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/gemma/weights.cc b/gemma/weights.cc index 8191bd9..b71e6b7 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -383,39 +383,45 @@ static void ReadAllToBF16(const std::vector& tensors, const BlobReader& reader, ThreadingContext& ctx) { static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16"); - ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) { - PROFILER_ZONE3(ctx.profiler, thread, zone); - const TensorToRead& tensor = tensors[task]; - MatPtr& mat = *tensor.mat; + // Especially TSAN is slow enough to warrant hierarchical parallelism. + const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD + ? 
ParallelismStrategy::kHierarchical + : ParallelismStrategy::kFlat; + ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0, + [&](uint64_t task, size_t thread) { + PROFILER_ZONE3(ctx.profiler, thread, zone); + const TensorToRead& tensor = tensors[task]; + MatPtr& mat = *tensor.mat; - if (tensor.keep_type) { - HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes, - mat.Packed())); - return; - } + if (tensor.keep_type) { + HWY_ASSERT(reader.file().Read( + tensor.range.offset, tensor.range.bytes, mat.Packed())); + return; + } - // Read to a temporary buffer. - const hwy::AlignedFreeUniquePtr buf = - hwy::AllocateAligned(tensor.range.bytes); - HWY_ASSERT( - reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get())); + // Read to a temporary buffer. + const hwy::AlignedFreeUniquePtr buf = + hwy::AllocateAligned(tensor.range.bytes); + HWY_ASSERT(reader.file().Read(tensor.range.offset, + tensor.range.bytes, buf.get())); - if constexpr (GEMMA_ENABLE_NUQ) { - if (tensor.prev_type == Type::kNUQ) { - return DecompressToBF16(*tensor.mat, buf); - } - } - switch (tensor.prev_type) { - case Type::kF32: - return DecompressToBF16(*tensor.mat, buf); - case Type::kBF16: - return DecompressToBF16(*tensor.mat, buf); - case Type::kSFP: - return DecompressToBF16(*tensor.mat, buf); - default: - HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type)); - } - }); + if constexpr (GEMMA_ENABLE_NUQ) { + if (tensor.prev_type == Type::kNUQ) { + return DecompressToBF16(*tensor.mat, buf); + } + } + switch (tensor.prev_type) { + case Type::kF32: + return DecompressToBF16(*tensor.mat, buf); + case Type::kBF16: + return DecompressToBF16(*tensor.mat, buf); + case Type::kSFP: + return DecompressToBF16(*tensor.mat, buf); + default: + HWY_ABORT("Unsupported type %s", + TypeName(tensor.prev_type)); + } + }); } // Mode == kRead: diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 8b9c011..737feb6 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -770,11 
+770,9 @@ class MMState { HWY_NOINLINE void DispatchParallelism(const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { - /* Disabled due to unknown thread-safety issue: static const auto zone = args_.env->ctx.profiler.AddZone("MM.DispatchParallelism"); PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone); - */ MMImpl::DispatchParallelism( args_.options.parallelism, @@ -1053,6 +1051,11 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C, MMOptions options = MMOptions()) { + static const auto zone = env.ctx.profiler.AddZone("MM.MatMul"); + PROFILER_ZONE3(env.ctx.profiler, + options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), + zone); + const Allocator& allocator = env.ctx.allocator; HWY_DASSERT(options.cluster_idx < env.row_ptrs.size()); MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx]; From 06e5da1e22f4524e5605e0943fc7ef1cf34951a3 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 8 Sep 2025 02:23:29 -0700 Subject: [PATCH 33/65] Cleanup: split CacheInfo from Allocator, MatMul helper functions Lift DecompressA out of main autotuner to prevent interference Also use kMaxNR / kNR constants instead of extra args Fix: only require vector alignment, not cache alignment PiperOrigin-RevId: 804333769 --- evals/benchmark_helper.cc | 4 +- ops/matmul-inl.h | 141 +++++++++++++++++++++----------------- ops/matmul.cc | 42 +++++------- ops/matmul.h | 12 ++-- util/allocator.cc | 12 ++-- util/allocator.h | 59 +++++++++------- util/threading_context.cc | 3 +- util/threading_context.h | 5 +- 8 files changed, 152 insertions(+), 126 deletions(-) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 55e99cf..e9fdafb 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -241,8 +241,8 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, dt, cpu100, static_cast(threading.bind), 
ctx.topology.TopologyString(), ctx.pools.PinString(), CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()), - ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED, - ctx.allocator.TotalMiB()); + ctx.cache_info.VectorBytes() * 8, CompiledConfig(), + PROFILER_ENABLED, ctx.allocator.TotalMiB()); } } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 737feb6..65dc185 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -21,7 +21,7 @@ #include "compression/types.h" #include "ops/matmul.h" // IWYU pragma: export -#include "util/allocator.h" +#include "util/allocator.h" // CacheInfo #include "util/basics.h" #include "util/mat.h" #include "util/threading_context.h" @@ -566,7 +566,7 @@ class MMKernel { }; // Miscellaneous stateless helper functions. -struct MMImpl { +class MMImpl { // Returns existing entry for the given key or -1. static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { const hwy::Span all_keys = keys.Keys(); @@ -596,6 +596,63 @@ struct MMImpl { return -1; } + public: + static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, + size_t vector_bytes, + MatMulEnv::PerCluster& per_cluster) { + const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); + intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys); + // First time we see this shape/key. + if (HWY_UNLIKELY(index < 0)) { + per_cluster.keys.Append(key, vector_bytes); + + // Invalidates `MMAutoTune::Best()`. + std::vector& per_keys = per_cluster.per_key; + index = per_keys.size(); + per_keys.push_back(MMPerKey()); + } + return per_cluster.per_key[index]; + } + + static void NotifyAutotuneResult(size_t M, size_t K, size_t N, double t0, + const MMConfig& cfg, MatMulEnv& env, + MMAutoTune& tuner) { + const uint64_t t1 = + env.have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); + const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / + hwy::platform::InvariantTicksPerSecond(); + const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA + if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { + fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9, + min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(), + StringFromOrder(cfg.Order()), cfg.InnerTasks()); + } + if (HWY_UNLIKELY(env.print_best && tuner.Best())) { + const auto ratio = [&tuner](uint64_t ticks) -> double { + return static_cast(ticks) / + static_cast(tuner.BestTicks()); + }; + const MMConfig& best = *tuner.Best(); + fprintf(stderr, + "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", + M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), + best.KC(), best.NC(), StringFromOrder(best.Order()), + best.InnerTasks(), ratio(tuner.WorstMinTicks()), + ratio(tuner.FirstConfigTicks())); + } + } + + static void EnsureAligned(const MatPtr& A, const size_t vector_bytes) { + // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are + // reliable: the latter returns true for single rows, and the former may + // match `Cols` if the width matches the padding. + // Note that B is packed in matmul_test, but otherwise generally padded. + HWY_ASSERT(hwy::IsAligned(A.RowBytes(0), vector_bytes)); + if (A.Rows() > 1) { + HWY_ASSERT(hwy::IsAligned(A.RowBytes(1), vector_bytes)); + } + } + static size_t Worker(const MMArgs& args) { return args.options.cluster_idx * args.env->ctx.pools.MaxWorkersPerCluster(); @@ -753,14 +810,14 @@ struct MMImpl { // loops over the inner KC and MC. Member variables avoid long argument lists. 
class MMState { public: - MMState(const Extents2D A, const size_t B_rows, const MMArgs& args, + MMState(size_t M, size_t K, size_t N, const MMArgs& args, const MMConfig& config) : args_(args), - range_n_(0, B_rows), + range_n_(0, N), mr_(config.MR()), - ranges_mc_(config.RangesOfMC(A.rows)), - ranges_kc_(config.RangesOfKC(A.cols)), - ranges_nc_(config.RangesOfNC(B_rows)), + ranges_mc_(config.RangesOfMC(M)), + ranges_kc_(config.RangesOfKC(K)), + ranges_nc_(config.RangesOfNC(N)), order_(config.Order()), inner_tasks_(config.InnerTasks()) {} @@ -783,7 +840,7 @@ class MMState { // Compute size of per-worker storage for `kNR` row ranges of B. Stack // allocation avoids passing a worker index. static constexpr size_t B_stride_max_ = - kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16); + kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); static constexpr size_t B_storage_max_ = kNR * B_stride_max_; // Granularity of `ForN`. B rows produce C columns, so we @@ -1056,88 +1113,48 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone); - const Allocator& allocator = env.ctx.allocator; HWY_DASSERT(options.cluster_idx < env.row_ptrs.size()); - MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx]; RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]); const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); - const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); - intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys); - // First time we see this shape/key. 
- if (HWY_UNLIKELY(index < 0)) { - per_cluster.keys.Append(key, allocator); - // invalidates `MMAutoTune::Best()` - std::vector& per_keys = per_cluster.per_key; - index = per_keys.size(); - per_keys.push_back(MMPerKey()); - } - MMPerKey& per_key = per_cluster.per_key[index]; + const CacheInfo& cache = env.ctx.cache_info; + MMPerKey& per_key = MMImpl::FindOrAddPerKey( + M, K, N, cache.VectorBytes(), env.per_cluster[options.cluster_idx]); MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add, options); if (HWY_LIKELY(tuner.Best())) { - const MMState state(A.Extents(), B.Rows(), args, *tuner.Best()); + const MMState state(M, K, N, args, *tuner.Best()); const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); state.DispatchParallelism(A_view, B, C_rows); return &per_key; } - // From here, CPU time is negligible except DoMatMul. - - // First call: enumerate all feasible configs. + // Autotuning, first call: enumerate all feasible configs. if (HWY_UNLIKELY(!tuner.HasCandidates())) { - // Ensure matrix dimensions match each other. + // Ensure matrix dimensions match each other (off the hot path). HWY_ASSERT(K == B.Cols()); HWY_ASSERT(M <= kMaxBatchSize); HWY_ASSERT(K <= MMStorage::kMaxK); HWY_ASSERT(N % kNR == 0); - // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are - // reliable: the latter returns true for single rows, and the former may - // match `Cols` if the width matches the padding. - // Note that B is packed in matmul_test, but otherwise generally padded. 
- HWY_ASSERT(hwy::IsAligned(A.Row(0), env.ctx.allocator.LineBytes())); - if (A.Rows() > 1) { - HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes())); - } - - tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, - kNR, env.print_config)); + MMImpl::EnsureAligned(A, cache.VectorBytes()); + tuner.SetCandidates( + MMCandidates(cache, M, K, N, sizeof(TC), env.print_config)); } + // (Also auto-tunes, hence outside the timed section to prevent interference.) + const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); + const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMState state(A.Extents(), B.Rows(), args, cfg); - const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); + MMState state(M, K, N, args, cfg); state.DispatchParallelism(A_view, B, C_rows); - const uint64_t t1 = - env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); - const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / - hwy::platform::InvariantTicksPerSecond(); - const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA - if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { - fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9, - min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(), - StringFromOrder(cfg.Order()), cfg.InnerTasks()); - } - if (HWY_UNLIKELY(env.print_best && tuner.Best())) { - const auto ratio = [per_key](uint64_t ticks) -> double { - return static_cast(ticks) / - static_cast(per_key.autotune.BestTicks()); - }; - const MMConfig& best = *tuner.Best(); - fprintf(stderr, - "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", M, - K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), - best.KC(), best.NC(), StringFromOrder(best.Order()), - best.InnerTasks(), ratio(tuner.WorstMinTicks()), - ratio(tuner.FirstConfigTicks())); - } + MMImpl::NotifyAutotuneResult(M, K, N, t0, cfg, env, tuner); return &per_key; } diff --git 
a/ops/matmul.cc b/ops/matmul.cc index 35887a5..00330e5 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -62,22 +62,19 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, // and holds most of their arguments in member variables. class GenerateCandidates { public: - GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, - size_t sizeof_TC, size_t max_mr, size_t nr, - bool print_config) - : allocator_(allocator), + GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N, + size_t sizeof_TC, bool print_config) + : cache_(cache), M_(M), K_(K), N_(N), sizeof_TC_(sizeof_TC), - max_mr_(max_mr), - nr_(nr), // These influence kc/nc, but are also stored in `MMConfig` for // `RangesOf*`. Must be a vector multiple. The previous/next cache line // is likely still in L1, but we expect K > 1000 and might as well round // up to the line size. Both A and B are BF16. - kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))), - nc_multiple_(allocator.StepBytes() / sizeof_TC), + kc_multiple_(HWY_MIN(K, cache.LineBytes() / sizeof(BF16))), + nc_multiple_(cache.StepBytes() / sizeof_TC), print_config_(print_config) {} std::vector operator()() const { @@ -127,10 +124,10 @@ class GenerateCandidates { SizeVec all_mr; all_mr.reserve(3); // AVX2's 16 registers are not enough for four rows, but SSE4 may benefit. - if (M_ >= max_mr_ && !is_avx2) all_mr.push_back(max_mr_); + if (M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR); // Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also // enable if not enough rows for 4. - if (M_ >= 2 && (M_ < max_mr_ || (!is_sse && !is_wasm))) { + if (M_ >= 2 && (M_ < kMaxMR || (!is_sse && !is_wasm))) { all_mr.push_back(size_t{2}); } // Even SSE4 usually prefers 2 rows; only enable for single rows. @@ -172,8 +169,8 @@ class GenerateCandidates { // size. 
This results in an overestimate, and the loop below will propose // the next few smaller values for the autotuner to evaluate. const size_t bytes_ab = - allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream)); - const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); + cache_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream)); + const size_t col_bytes = rows_a * sizeof(BF16) + kNR * sizeof(BF16); size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); kc_max = RoundDownWithFloor(HWY_MIN(kc_max, kMaxKC), kc_multiple_); kc_max = HWY_MIN(kc_max, K_); @@ -213,14 +210,14 @@ class GenerateCandidates { SizeVec MC(size_t mr, size_t kc, MMOrder order) const { // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because // it is typically inclusive. - const size_t bytes_b = nr_ * kc * (sizeof(SfpStream) + sizeof(BF16)); + const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16)); // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the // packed B. We want `mc * kc` elements of A to fit in L2, alongside // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of // partial. - const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes(); - size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc); + const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes(); + size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc); mc_max = HWY_MIN(mc_max, kMaxBatchSize); HWY_DASSERT(mc_max != 0); mc_max = HWY_MIN(mc_max, M_); @@ -261,7 +258,7 @@ class GenerateCandidates { // Otherwise, leave it unbounded. 
if (M_ > mr) { const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); - nc_max = HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), N_); + nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), N_); } HWY_DASSERT(nc_max != 0); nc_max = RoundDownWithFloor(nc_max, nc_multiple_); @@ -328,15 +325,12 @@ class GenerateCandidates { return inner_tasks; } - const Allocator& allocator_; + const CacheInfo& cache_; const size_t M_; const size_t K_; const size_t N_; const size_t sizeof_TC_; - const size_t max_mr_; - const size_t nr_; - const size_t kc_multiple_; const size_t nc_multiple_; @@ -346,12 +340,10 @@ class GenerateCandidates { } // namespace // Facade to avoid exposing `GenerateCandidates` in the header. -std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, +std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, + size_t N, size_t sizeof_TC, bool print_config) { - return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr, - print_config)(); + return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)(); } MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { diff --git a/ops/matmul.h b/ops/matmul.h index 8c7d724..641dad9 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -477,9 +477,9 @@ class MMConfig { static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) -std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, bool print_config); +std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, + size_t N, size_t sizeof_TC, + bool print_config); // State machine for choosing the best `TConfig`, which is `MMConfig` for the // main MatMul autotuner. @@ -619,11 +619,11 @@ class MMKeys { } // Must only be called if not already present in `Keys()`. 
- void Append(Key key, const Allocator& allocator) { + void Append(Key key, size_t vector_bytes) { // Dynamic allocation because the test checks many more dimensions than // would be reasonable to pre-allocate. DIY for alignment and padding. if (HWY_UNLIKELY(num_unique_ >= capacity_)) { - const size_t NU64 = allocator.VectorBytes() / sizeof(Key); + const size_t NU64 = vector_bytes / sizeof(Key); // Start at one vector so the size is always a multiple of N. if (HWY_UNLIKELY(capacity_ == 0)) { capacity_ = hwy::DivCeil(NU64, 2); // will be doubled below @@ -704,7 +704,7 @@ struct MMArgs { scale(scale), add(add), options(options), - line_bytes(env.ctx.allocator.LineBytes()) {} + line_bytes(env.ctx.cache_info.LineBytes()) {} MatMulEnv* env; MMPerKey* per_key; diff --git a/util/allocator.cc b/util/allocator.cc index f8bfdd5..f99586e 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -130,7 +130,7 @@ size_t DetectTotalMiB(size_t page_bytes) { } // namespace -Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { +CacheInfo::CacheInfo(const BoundedTopology& topology) { line_bytes_ = DetectLineBytes(); // Ensure MaxLineBytes() is an upper bound. HWY_ASSERT(MaxLineBytes() >= LineBytes()); @@ -138,8 +138,6 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { vector_bytes_ = hwy::VectorBytes(); step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); - base_page_bytes_ = DetectPageSize(); - quantum_bytes_ = step_bytes_; // may overwrite below const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0); if (const hwy::Cache* caches = hwy::DataCaches()) { @@ -153,8 +151,14 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { if (l3_bytes_ == 0) { l3_bytes_ = (cluster.SharedKiB() ? 
cluster.SharedKiB() : 1024) << 10; } +} - total_mib_ = DetectTotalMiB(base_page_bytes_); +Allocator::Allocator(const BoundedTopology& topology, + const CacheInfo& cache_info, bool enable_bind) + : line_bytes_(cache_info.LineBytes()), + base_page_bytes_(DetectPageSize()), + total_mib_(DetectTotalMiB(base_page_bytes_)) { + quantum_bytes_ = cache_info.StepBytes(); // may overwrite below // Prerequisites for binding: // - supported by the OS (currently Linux only), diff --git a/util/allocator.h b/util/allocator.h index 42e261c..086b6e9 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -77,27 +77,49 @@ using AlignedPtr = std::unique_ptr; template using AlignedClassPtr = std::unique_ptr; -// Both allocation, binding, and row accessors depend on the sizes of memory -// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we -// wrap this in a singleton. A monostate requires explicit initialization, -// which we prefer to avoid because there are many main() functions. -class Allocator { +// Holds cache line size/capacity and vector size. Stored in `ThreadingContext`. +class CacheInfo { public: - // Must be called at least once before any other function. Not thread-safe, - // hence only call this from the main thread. - Allocator(const BoundedTopology& topology, bool enable_bind); + CacheInfo(const BoundedTopology& topology); // Bytes per cache line, or a reasonable guess if unknown. Used to choose // ranges such that there will be no false sharing. size_t LineBytes() const { return line_bytes_; } // Upper bound on `LineBytes()`, for stack allocations. static constexpr size_t MaxLineBytes() { return 256; } + // Bytes per full vector. Used to compute loop steps. size_t VectorBytes() const { return vector_bytes_; } // Work granularity that avoids false sharing and partial vectors. // = HWY_MAX(LineBytes(), VectorBytes()) size_t StepBytes() const { return step_bytes_; } + // L1 and L2 are typically per core. 
+ size_t L1Bytes() const { return l1_bytes_; } + size_t L2Bytes() const { return l2_bytes_; } + // Clusters often share an L3. We return the total size per package. + size_t L3Bytes() const { return l3_bytes_; } + + private: + size_t line_bytes_; + size_t vector_bytes_; + size_t step_bytes_; + + size_t l1_bytes_ = 0; + size_t l2_bytes_ = 0; + size_t l3_bytes_ = 0; +}; + +// NUMA-aware allocation and memory binding. Stored in `ThreadingContext`. +class Allocator { + public: + Allocator(const BoundedTopology& topology, const CacheInfo& cache_info, + bool enable_bind); + + // Used by `AllocateFor`, which only takes an `Allocator` argument, + // hence copy from `CacheInfo`. + size_t LineBytes() const { return line_bytes_; } + // File size multiple required for memory mapping. Also used when binding // memory to NUMA nodes (see `BindB/BindC`). size_t BasePageBytes() const { return base_page_bytes_; } @@ -105,12 +127,6 @@ class Allocator { // Desired allocator alignment: Either StepBytes, or BasePageBytes if NUMA. size_t QuantumBytes() const { return quantum_bytes_; } - // L1 and L2 are typically per core. - size_t L1Bytes() const { return l1_bytes_; } - size_t L2Bytes() const { return l2_bytes_; } - // Clusters often share an L3. We return the total size per package. 
- size_t L3Bytes() const { return l3_bytes_; } - size_t TotalMiB() const { return total_mib_; } size_t FreeMiB() const; @@ -159,18 +175,11 @@ class Allocator { bool BindMemory(void* p, size_t bytes, size_t node) const; private: - size_t line_bytes_; - size_t vector_bytes_; - size_t step_bytes_; - size_t base_page_bytes_; + const size_t line_bytes_; + const size_t base_page_bytes_; + const size_t total_mib_; + size_t quantum_bytes_; - - size_t l1_bytes_ = 0; - size_t l2_bytes_ = 0; - size_t l3_bytes_ = 0; - - size_t total_mib_; - bool should_bind_ = false; }; diff --git a/util/threading_context.cc b/util/threading_context.cc index 81155c5..90a64d1 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -76,7 +76,8 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args) topology(BoundedSlice(args.skip_packages, args.max_packages), BoundedSlice(args.skip_clusters, args.max_clusters), BoundedSlice(args.skip_lps, args.max_lps)), - allocator(topology, args.bind != Tristate::kFalse), + cache_info(topology), + allocator(topology, cache_info, args.bind != Tristate::kFalse), pools(topology, allocator, args.max_threads, args.pin) { PROFILER_ZONE("Startup.ThreadingContext autotune"); TunePool(pools.AllPackages()); diff --git a/util/threading_context.h b/util/threading_context.h index 6bd6936..41d0811 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -105,7 +105,10 @@ struct ThreadingContext { // will be 1 regardless of the actual system topology. BoundedTopology topology; - // Ctor depends on `topology` for deciding whether to enable NUMA. + // Ctor depends on `topology` for per-cluster cache sizes. + CacheInfo cache_info; + + // Ctor depends on `topology` (for NUMA) and `cache_info` (for step size). Allocator allocator; // Per-package/cluster/within cluster pools of threads, matching `topology`. 
From a5ab99e4bab8f10a7372f5cc0f2d14f4ff9fe8b1 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 9 Sep 2025 05:32:20 -0700 Subject: [PATCH 34/65] Memory use reduction: smaller/single MMStorage PiperOrigin-RevId: 804865029 --- gemma/activations.h | 2 +- gemma/gemma.h | 2 +- ops/matmul-inl.h | 7 ++++--- ops/matmul.cc | 6 +----- ops/matmul.h | 8 ++++---- ops/ops-inl.h | 1 + util/basics.h | 2 +- 7 files changed, 13 insertions(+), 15 deletions(-) diff --git a/gemma/activations.h b/gemma/activations.h index 21e5e58..71523e4 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -179,7 +179,7 @@ struct Activations { MatStorageT x; // input MatStorageT x_bf; // output of final RMSNorm, input to EmbeddingMatmul - MatStorageT logits; + MatStorageT logits; // TODO: BF16 after Softmax supports that. MatStorageT sampled; // batch_size x 3 (padded) // Gated FFW diff --git a/gemma/gemma.h b/gemma/gemma.h index 2f06ab8..491999d 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -127,7 +127,7 @@ class QBatch { max_size_(max_size), queries_(queries), size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) { - HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`. + HWY_ASSERT(max_size_ <= kMaxBatchSize); HWY_DASSERT(size_ != 0); HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 65dc185..74deb78 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -797,9 +797,10 @@ class MMImpl { return View(A, 0, 0, A.Cols()); } else { // Always decompress. To reduce code size/compile time, we no longer - // support a separate F32 kernel; most A are already BF16. - const StridedViewBF A_view = - args.env->storage[args.options.cluster_idx].A(A.Extents()); + // support a separate F32 kernel; most A are already BF16. We also only + // have a single MMStorage. 
+ HWY_ASSERT(args.options.cluster_idx == 0); + const StridedViewBF A_view = args.env->storage.A(A.Extents()); DecompressA(A, A_view, args); return A_view; } diff --git a/ops/matmul.cc b/ops/matmul.cc index 00330e5..66ce0df 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -346,14 +346,10 @@ std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)(); } -MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { - // Create storage per cluster. This only applies to in-cluster parallelism. - // For nested and sequential parallelism, a single MMStorage is used. +MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) { const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); per_cluster.resize(num_clusters); - storage.reserve(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { - storage.push_back(MMStorage(ctx.allocator)); row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } diff --git a/ops/matmul.h b/ops/matmul.h index 641dad9..c86ecc3 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -306,11 +306,11 @@ using StridedViewD = StridedView; class MMStorage { public: // Compile-time bounds on matrix columns to enable pre-allocating storage - // and reusing it across `MatMul` calls. - static constexpr size_t kMaxK = 64 * 1024; + // and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B. + static constexpr size_t kMaxK = 36 * 1024; MMStorage(const Allocator& allocator) - // 0.5 GiB. Must be padded, see `DoDecompressA`. + // 288 MiB. Must be padded, see `DoDecompressA`. : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd) {} @@ -673,7 +673,7 @@ struct MatMulEnv { // Whether to print the best config immediately after autotuning finished. 
bool print_best = false; - std::vector storage; + MMStorage storage; struct PerCluster { MMKeys keys; diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 18ee40f..0c6bd50 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -614,6 +614,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( } // See below for a specialized version for top-1 sampling. +// TODO: support bf16 logits using Decompress2. static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, const size_t worker, float temperature = 1.0f) { diff --git a/util/basics.h b/util/basics.h index 7b1c7d3..0211a0e 100644 --- a/util/basics.h +++ b/util/basics.h @@ -30,7 +30,7 @@ namespace gcpp { -// TODO: extend to 16k after updating non_eos. +// For hwy::BitSet4096. Note that KVs are extremely large for such batches. HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 }; From 34ceee6c308a362ce8502d19580baa65256f4b34 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 9 Sep 2025 05:56:57 -0700 Subject: [PATCH 35/65] Update MatMul comments, removing mention of partial. PiperOrigin-RevId: 804872289 --- ops/matmul-inl.h | 11 +++++------ ops/matmul.h | 9 ++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 74deb78..f2e9c49 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -132,7 +132,7 @@ class MMStoreHorizontalSumsIntoC { // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in // the elements of one V4. We have four independent rows `r`, hence the // code is effectively unrolled, which increases throughput. - // Store to four elements per row of `partial`. + // Store to four elements per row of `C`. // No loop is required because vectors are at least 4*32 bits. 
const D4 d4; sum0 = MaybeLoad<0>(d4, N, buf); @@ -370,7 +370,7 @@ class MMKernel { // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). - // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`. + // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`. // `A` and `B` are always BF16, `C` can be F32 or BF16. template static HWY_INLINE void LoopKC(const StridedViewBF A_view, @@ -966,8 +966,7 @@ class MMState { const size_t B_stride = Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); - // Sequential loop over NC/MC/KC, similar to `loop_nc` below - // except for the profiler strings and `out_tag`. + // Similar to `loop_nc` below except for the profiler zone and `MMSetC`. parallel.ForRangesMC_NC( args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, @@ -990,7 +989,7 @@ class MMState { } // Parallel loops over mc/nc blocks of M/range_np, sequential K. - // Fills `mc x nc` sections of `partial`, then `C`, in parallel. + // Accumulates into `mc x nc` sections of `C`. template HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { @@ -1001,7 +1000,7 @@ class MMState { Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read - // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. + // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `C`. 
const auto loop_nc = [&](const StridedViewBF B_storage_view, const IndexRange& range_mc, const IndexRange& range_kc, diff --git a/ops/matmul.h b/ops/matmul.h index c86ecc3..946673a 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -332,8 +332,7 @@ class MMStorage { // Autotuning // Naming convention: outer loop first, T suffix means threaded. This refers to -// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost -// `ranges_np` loop across packages is implicit and applies to all of these. +// the loops *around* `A2C0`, which contains loops over mc/kc. // // Parallelizing across K (A/B columns) is undesirable because the resulting // partial dot products require synchronization or reduction across threads. @@ -341,18 +340,18 @@ enum class MMOrder : uint8_t { // Single M, parallel N, sequential K (inside the parallel section to // reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K. kNT_K, - // Specialization of `kNT_K` for a single K task with `kDirect`. + // Specialization of `kNT_K` for a single K task with `MMSetC`. kNT, // Parallelize over blocks of M and N: good when both are large. We no longer // support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as // fast on Zen4. kNT_MT_K, - kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `kDirect`. + kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `MMSetC`. // Resident C (`kK_M_NT`) should be good for large K relative to M and N. // However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are - // no kN* because we expect M (batch size) to be small relative to K and N. + // no kM* because we expect M (batch size) to be small relative to K and N. 
}; static inline bool IsBlock(MMOrder order) { From 461a9c7d1b18a269d505769e934af4dc68eaae4d Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 9 Sep 2025 07:13:03 -0700 Subject: [PATCH 36/65] Matmul refactoring towards fusion MMLoops: move dispatch code out, use overloads split build target into matmul_env (for MatMulEnv/MMOptions) weights: no longer call BindB Fix potential out of bounds in gemma_batch_bench PiperOrigin-RevId: 804895985 --- .github/workflows/build.yml | 1 + BUILD.bazel | 41 +- evals/gemma_batch_bench.cc | 3 +- examples/simplified_gemma/BUILD.bazel | 2 +- gemma/gemma_args.h | 3 +- gemma/weights.cc | 2 - ops/matmul-inl.h | 541 +++++++++++--------------- ops/matmul.h | 180 ++++++--- ops/matvec-inl.h | 1 - util/mat.h | 3 +- 10 files changed, 388 insertions(+), 389 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2052a82..1512548 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -46,6 +46,7 @@ jobs: -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} -D CMAKE_C_COMPILER_LAUNCHER=ccache -D CMAKE_CXX_COMPILER_LAUNCHER=ccache + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 - name: Build run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4 diff --git a/BUILD.bazel b/BUILD.bazel index 52c2df3..dbe52b7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -238,7 +238,6 @@ cc_library( ":configs", ":gemma_args", ":mat", - ":matmul", ":model_store", ":tensor_info", ":threading_context", @@ -271,14 +270,33 @@ test_suite( ) cc_library( - name = "matmul", + name = "matmul_env", srcs = ["ops/matmul.cc"], hdrs = ["ops/matmul.h"], + deps = [ + ":allocator", + ":basics", + ":configs", + ":mat", + ":threading", + ":threading_context", + "@highway//:bit_set", + "@highway//:hwy", + "@highway//:nanobenchmark", + "@highway//:profiler", + ], +) + +cc_library( + name = "matmul", + # allow depending only on this target, without also matmul_env. 
+ hdrs = ["ops/matmul.h"], textual_hdrs = ["ops/matmul-inl.h"], deps = [ ":allocator", ":basics", ":mat", + ":matmul_env", ":threading", ":threading_context", "//compression:compress", @@ -310,6 +328,7 @@ cc_library( ":basics", ":mat", ":matmul", + ":matmul_env", ":threading_context", "//compression:compress", "//compression:types", @@ -333,11 +352,12 @@ cc_library( ":allocator", ":basics", ":mat", - ":matmul", + ":matmul_env", # MMOptions ":matmul_static", ":threading_context", "//compression:compress", "@highway//:algo", + "@highway//:bit_set", "@highway//:hwy", "@highway//:math", "@highway//:matvec", @@ -434,7 +454,7 @@ cc_test( deps = [ ":basics", ":mat", - ":matmul", + ":matmul_env", ":matmul_static", ":ops", ":threading_context", @@ -462,7 +482,8 @@ cc_test( ], deps = [ ":basics", - ":matmul", + ":matmul_env", + ":matmul_static", ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", @@ -495,7 +516,6 @@ cc_library( ":args", ":basics", ":mat", - ":matmul", "//io", "@highway//:hwy", "@highway//:profiler", @@ -523,13 +543,12 @@ cc_library( "gemma/gemma-inl.h", ], deps = [ - ":allocator", ":basics", ":configs", ":gemma_args", ":kv_cache", ":mat", - ":matmul", + ":matmul_env", ":model_store", ":ops", ":threading", @@ -569,7 +588,7 @@ cc_library( ":cross_entropy", ":gemma_args", ":gemma_lib", - ":matmul", + ":matmul_env", ":ops", ":threading_context", ":tokenizer", @@ -600,7 +619,7 @@ cc_library( ":gemma_args", ":gemma_lib", ":kv_cache", - ":matmul", + ":matmul_env", ":threading", ":threading_context", ":tokenizer", @@ -661,7 +680,7 @@ cc_binary( ":benchmark_helper", ":gemma_args", ":gemma_lib", - ":matmul", + ":matmul_env", ":tokenizer", "//compression:types", "//paligemma:image", diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 6d97c61..135c2bb 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -93,7 +93,8 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) { if (qpos == 
questions.size()) qpos = 0; } std::vector responses = BatchGemmaReply(inputs); - for (size_t i = 0; i < hwy::Unpredictable1() * 3; ++i) { + for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); + ++i) { fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); } diff --git a/examples/simplified_gemma/BUILD.bazel b/examples/simplified_gemma/BUILD.bazel index 740ec7d..811906f 100644 --- a/examples/simplified_gemma/BUILD.bazel +++ b/examples/simplified_gemma/BUILD.bazel @@ -15,7 +15,7 @@ cc_library( deps = [ "//:gemma_args", "//:gemma_lib", - "//:matmul", + "//:matmul_env", "//:threading_context", "//:tokenizer", "@highway//:hwy", diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index b2d19ff..3135f50 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -24,8 +24,7 @@ #include #include -#include "io/io.h" // Path -#include "ops/matmul.h" // MMStorage::kMax* +#include "io/io.h" // Path #include "util/args.h" #include "util/basics.h" // Tristate #include "util/mat.h" diff --git a/gemma/weights.cc b/gemma/weights.cc index b71e6b7..425a752 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -30,7 +30,6 @@ #include "gemma/gemma_args.h" #include "gemma/model_store.h" #include "io/blob_store.h" -#include "ops/matmul.h" // MMParallel #include "util/mat.h" #include "util/threading_context.h" #include "hwy/base.h" @@ -338,7 +337,6 @@ static void AllocateAndBindAll(std::vector& tensors, owners[start + task].AllocateFor(*tensor.mat, ctx.allocator, tensor.padding); - BindB(ctx, *tensor.mat, tensor.mat->ElementBytes()); }); } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index f2e9c49..21ac14b 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -152,10 +152,10 @@ class MMStoreHorizontalSumsIntoC { // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is // `MMSetC`, the vectors are written as-is (first call, or small K). // Otherwise, they are partial sums and are accumulated into C. 
- template , class Tag, typename TC> + template , class Tag, class CRows> HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag, const size_t row_c, const size_t col_c, - const MMArgs& args, RowPtrs C_rows) const { + const MMArgs& args, CRows C_rows) const { const V4 vscale = hn::Set(d4, args.scale); HWY_ALIGN static constexpr float kZero[4] = {}; const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero); @@ -219,18 +219,24 @@ class MMStoreHorizontalSumsIntoC { } }; // MMStoreHorizontalSumsIntoC -// Stateless, wraps member functions. +// Stateless, wraps member functions. Contains the innermost 2-4 loops. class MMKernel { + // Compute size of per-worker storage for `kNR` row ranges of B. Stack + // allocation avoids passing a worker index. + static constexpr size_t B_stride_max = + kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); + public: // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. - template + // Called by B3A2C0 and by callers that hoist `A_view`. + template static HWY_INLINE void A2C0(const StridedViewBF A_view, const StridedViewBF B_view, size_t mr, const IndexRange& range_mc, const size_t row_b, size_t kc, Tag tag, const MMArgs& args, - RowPtrs C_rows) { + CRows C_rows) { HWY_DASSERT(1 <= mr && mr <= kMaxMR); const size_t row0 = range_mc.begin(); const size_t mc = range_mc.Num(); @@ -280,6 +286,90 @@ class MMKernel { HWY_DASSERT(imc == mc); } + static constexpr size_t B_storage_max = kNR * B_stride_max; + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
+ template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + + // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, + // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` + // thanks to its large table lookups, and less so on other targets. + template + static StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, + const IndexRange& range_kc, + const StridedViewBF B_view) { + const hn::ScalableTag dbf; + HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); + + // Neither A nor B require padding because `LoopKC` handles remainders. + if constexpr (hwy::IsSame()) { + return View(B, row_b, range_kc.begin(), range_kc.Num()); + } + + const PackedSpan B_span = B.PaddedSpan(); + + const size_t kc = range_kc.Num(); + const size_t col0 = range_kc.begin(); + + for (size_t r = 0; r < kNR; ++r) { + const size_t packed_ofs = (row_b + r) * B.Stride() + col0; + BF16* HWY_RESTRICT to = B_view.Row(r); + DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc); + // Verify that we zero-padded. + if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { + HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); + } + } + } + return B_view; + } + + // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads + // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by + // `ForeachKC` and when there is only a single KC task. 
+ template + static void B3A2C0(const StridedViewBF A, const MatPtrT& B, + const MMArgs& args, const IndexRange& range_mc, + const IndexRange& range_kc, const IndexRange& range_nc, + size_t mr, Tag out_tag, CRows C_rows) { + HWY_ALIGN BF16 B_storage[B_storage_max]; + + const size_t kc = range_kc.Num(); + const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc); + + const size_t B_stride = + Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes); + const StridedViewBF B_storage_view(B_storage, kc, B_stride); + + for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); + row_b += kNR) { + StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); + A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows); + } + } + + template + static void ForeachKC(const StridedViewBF A, const MatPtrT& B, + const MMArgs& args, const IndexRange& range_mc, + const IndexRangePartition& ranges_kc, + const IndexRange& range_nc, size_t mr, CRows C_rows) { + // Peel off the first iteration of the kc loop: avoid zero-initializing `C` + // by writing directly into it, and later accumulating into it. + ranges_kc.VisitFirst([&](const IndexRange& range_kc) { + B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMSetC(), C_rows); + }); + ranges_kc.VisitRemaining([&](const IndexRange& range_kc) { + B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMAddC(), C_rows); + }); + } + private: // Element-wise multiplies a vector from one row of A with `kNR` vectors, // each from a row of transposed B, and adds them to `kNR` fp32 `Cc` @@ -372,11 +462,11 @@ class MMKernel { // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`. // `A` and `B` are always BF16, `C` can be F32 or BF16. 
- template + template static HWY_INLINE void LoopKC(const StridedViewBF A_view, const StridedViewBF B_view, size_t row_ac, size_t imc, size_t col_c, size_t kc, Tag tag, - const MMArgs& args, RowPtrs C_rows) { + const MMArgs& args, CRows C_rows) { const hn::ScalableTag dbf; using VBF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); @@ -601,7 +691,7 @@ class MMImpl { size_t vector_bytes, MatMulEnv::PerCluster& per_cluster) { const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); - intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys); + intptr_t index = IndexOfKey(key, per_cluster.keys); // First time we see this shape/key. if (HWY_UNLIKELY(index < 0)) { per_cluster.keys.Append(key, vector_bytes); @@ -614,9 +704,9 @@ class MMImpl { return per_cluster.per_key[index]; } - static void NotifyAutotuneResult(size_t M, size_t K, size_t N, double t0, - const MMConfig& cfg, MatMulEnv& env, - MMAutoTune& tuner) { + static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N, + double t0, MMAutoTune& tuner, + const MMConfig& cfg) { const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / @@ -653,39 +743,16 @@ class MMImpl { } } - static size_t Worker(const MMArgs& args) { - return args.options.cluster_idx * - args.env->ctx.pools.MaxWorkersPerCluster(); - } - - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
- template - static StridedView View(const MatPtrT& AB, size_t r, size_t c, - size_t cols) { - HWY_DASSERT(c < AB.Cols()); - HWY_DASSERT(cols <= AB.Cols() - c); - return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); - } - - template - static void DispatchParallelism(ParallelismStrategy parallelism, - const Func& func) { - switch (parallelism) { - case ParallelismStrategy::kHierarchical: - return func(MMParallelHierarchical()); - case ParallelismStrategy::kNone: - return func(MMParallelNone()); - case ParallelismStrategy::kWithinCluster: - return func(MMParallelWithinCluster()); - default: - HWY_UNREACHABLE; - } + static size_t Worker(const MatMulEnv& env, size_t cluster_idx) { + return cluster_idx * env.ctx.pools.MaxWorkersPerCluster(); } // Decompresses all `M x K` from `A` into padded BF16 `A_view`. static HWY_NOINLINE void DoDecompressA(const MatPtrT& A, const StridedViewBF A_view, - MMParA par_a, const MMArgs& args) { + MMAutoTune& autotune, + MMParA par_a, const MatMulEnv& env, + const MMOptions& options) { const IndexRange all_M(0, A.Rows()); const IndexRange all_K(0, A.Cols()); HWY_DASSERT(all_K.Num() == A_view.Cols()); @@ -693,13 +760,13 @@ class MMImpl { const hn::ScalableTag dbf; const size_t NBF = hn::Lanes(dbf); - static const auto zone = args.env->ctx.profiler.AddZone("MM.DecompressA"); + static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA"); const auto do_range = [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) HWY_ATTR { MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, args); + mm_zone.MaybeEnter(worker, zone, env, &autotune); const size_t col0 = range_K.begin(); const size_t cols = range_K.Num(); @@ -722,7 +789,7 @@ class MMImpl { switch (par_a) { case MMParA::kNone: - do_range(all_M, all_K, MMImpl::Worker(args)); + do_range(all_M, all_K, Worker(env, options.cluster_idx)); break; case MMParA::kK1: @@ -732,27 +799,26 @@ class MMImpl { // At least one vector, otherwise DecompressAndZeroPad will add 
// padding, which might overwrite neighboring tasks. Also a whole cache // line to avoid false sharing. - const size_t multiple_K = HWY_MAX(NBF, args.line_bytes / sizeof(BF16)); + const size_t multiple_K = + HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16)); - DispatchParallelism( - args.options.parallelism, [&](const auto& parallel) { - parallel.ForN(args.env->ctx, all_K, multiple_K, inner_tasks, - args.options.cluster_idx, - [&](const IndexRange& range_K, size_t worker) { - do_range(all_M, range_K, worker); - }); - }); + DispatchParallelism(options.parallelism, [&](const auto& parallel) { + parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks, + options.cluster_idx, + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); + }); + }); break; } case MMParA::kM: - DispatchParallelism( - args.options.parallelism, [&](const auto& parallel) { - parallel.ForRangeMC( - args.env->ctx, all_M, args.options.cluster_idx, - [&](size_t row_a, size_t worker) { - do_range(IndexRange(row_a, row_a + 1), all_K, worker); - }); - }); + DispatchParallelism(options.parallelism, [&](const auto& parallel) { + parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx, + [&](size_t row_a, size_t worker) { + do_range(IndexRange(row_a, row_a + 1), all_K, + worker); + }); + }); break; } } @@ -760,11 +826,11 @@ class MMImpl { // Autotuning wrapper for `DoDecompressA`. static HWY_INLINE void DecompressA(const MatPtrT& A, const StridedViewBF A_view, - const MMArgs& args) { - MMAutoTune& autotune = args.per_key->autotune_par_a; - + MMAutoTune& autotune, + const MatMulEnv& env, + const MMOptions& options) { if (HWY_LIKELY(autotune.Best())) { - return DoDecompressA(A, A_view, *autotune.Best(), args); + return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options); } // First call: generate candidates. 
@@ -777,11 +843,11 @@ class MMImpl { const MMParA& par_a = autotune.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - DoDecompressA(A, A_view, par_a, args); + DoDecompressA(A, A_view, autotune, par_a, env, options); const uint64_t t1 = - args.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); + env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); - if (HWY_UNLIKELY(args.env->print_measurement && autotune.ShouldPrint())) { + if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) { fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), static_cast(min_elapsed) / hwy::platform::InvariantTicksPerSecond() * 1E6); @@ -790,299 +856,148 @@ class MMImpl { template static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT& A, - const MMArgs& args) { + MMAutoTune& autotune, + const MatMulEnv& env, + MMOptions options) { if constexpr (IsBF16()) { // We can use a view, regardless of columns/padding, because `LoopKC` // supports non-vector multiples. - return View(A, 0, 0, A.Cols()); + return MMKernel::View(A, 0, 0, A.Cols()); } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. We also only // have a single MMStorage. - HWY_ASSERT(args.options.cluster_idx == 0); - const StridedViewBF A_view = args.env->storage.A(A.Extents()); - DecompressA(A, A_view, args); + HWY_ASSERT(options.cluster_idx == 0); + const StridedViewBF A_view = env.storage.A(A.Extents()); + DecompressA(A, A_view, autotune, env, options); return A_view; } } }; -// Contains several variants of the outer M/N/K loops, and calls `A2C0` which -// loops over the inner KC and MC. Member variables avoid long argument lists. -class MMState { +// Defines several variants of the outer M/N/K loops (see `MMOrder`). 
+class MMLoops { public: - MMState(size_t M, size_t K, size_t N, const MMArgs& args, - const MMConfig& config) - : args_(args), - range_n_(0, N), - mr_(config.MR()), - ranges_mc_(config.RangesOfMC(M)), - ranges_kc_(config.RangesOfKC(K)), - ranges_nc_(config.RangesOfNC(N)), - order_(config.Order()), - inner_tasks_(config.InnerTasks()) {} - // Called from `MatMul` from two places: either with the next autotune config, // or with the best config. template - HWY_NOINLINE void DispatchParallelism(const StridedViewBF A, - const MatPtrT& B, - RowPtrs C_rows) const { - static const auto zone = - args_.env->ctx.profiler.AddZone("MM.DispatchParallelism"); - PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone); + static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT& B, + RowPtrs C_rows, const MMArgs& args) { + static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch"); + PROFILER_ZONE3(args.env.ctx.profiler, + MMImpl::Worker(args.env, args.options.cluster_idx), zone); - MMImpl::DispatchParallelism( - args_.options.parallelism, - [&](const auto& parallel) { DispatchOrder(parallel, A, B, C_rows); }); + DispatchParallelism( + args.options.parallelism, [&](const auto& parallel) HWY_ATTR { + DispatchOrder(args.order, [&](const auto& order) HWY_ATTR { + Loop(order, parallel, A, B, C_rows, args); + }); + }); } private: - // Compute size of per-worker storage for `kNR` row ranges of B. Stack - // allocation avoids passing a worker index. - static constexpr size_t B_stride_max_ = - kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); - static constexpr size_t B_storage_max_ = kNR * B_stride_max_; - // Granularity of `ForN`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. 
- size_t MultipleN(size_t sizeof_TC) const { - return HWY_MAX(kNR, args_.line_bytes / sizeof_TC); - } - - // B is decompressed several call layers lower, but not all member functions - // depend on `TB`, so pass it as an argument instead of templating the class. - template - HWY_NOINLINE void DispatchOrder(const ParallelT& parallel_policy, - const StridedViewBF A, const MatPtrT& B, - RowPtrs C_rows) const { - switch (order_) { - case MMOrder::kNT: - return DoNT(parallel_policy, A, B, C_rows); - case MMOrder::kNT_K: - return DoNT_K(parallel_policy, A, B, C_rows); - case MMOrder::kNT_MT: - return DoNT_MT(parallel_policy, A, B, C_rows); - case MMOrder::kNT_MT_K: - return DoNT_MT_K(parallel_policy, A, B, C_rows); - default: - HWY_UNREACHABLE; - } + static size_t MultipleN(size_t sizeof_TC, size_t line_bytes) { + return HWY_MAX(kNR, line_bytes / sizeof_TC); } // Single M and K ranges, parallel N. Fills all of C directly. - template - HWY_INLINE void DoNT(ParallelT parallel, const StridedViewBF A, - const MatPtrT& B, RowPtrs C_rows) const { - static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); - HWY_DASSERT(ranges_mc_.NumTasks() == 1); - HWY_DASSERT(ranges_kc_.NumTasks() == 1); - const IndexRange& range_M = ranges_mc_.Range(0); - const IndexRange& range_K = ranges_kc_.Range(0); + template + static HWY_INLINE void Loop(MMOrderNT, Parallel parallel, + const StridedViewBF A, const MatPtrT& B, + RowPtrs C_rows, const MMArgs& args) { + static const auto zone = args.env.ctx.profiler.AddZone("MM.NT"); + HWY_DASSERT(args.ranges_mc.NumTasks() == 1); + HWY_DASSERT(args.ranges_kc.NumTasks() == 1); + const IndexRange& range_M = args.ranges_mc.Range(0); + const IndexRange& range_K = args.ranges_kc.Range(0); const size_t K = range_K.Num(); const StridedViewBF A_view = A.View(range_M.begin(), 0, K); const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); + Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes); - // Similar to `loop_nc` 
below, but here we hoisted `A_view`. + // Similar to `B3A2C0`, but here we hoisted `A_view`. parallel.ForN( - args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_, - args_.options.cluster_idx, + args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes), + args.inner_tasks, args.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, args_); + mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS + HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max]; // TLS const StridedViewBF B_storage_view(B_storage, K, B_stride); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = - DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), - args_, C_rows); + MMKernel::DecompressB(B, row_b, range_K, B_storage_view); + MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(), + args, C_rows); } }); } // Single M range, parallel N, sequential K. Sets C, then accumulates. - template - HWY_INLINE void DoNT_K(ParallelT parallel, const StridedViewBF A, - const MatPtrT& B, RowPtrs C_rows) const { - static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); - HWY_DASSERT(ranges_mc_.NumTasks() == 1); - const IndexRange& range_mc = ranges_mc_.Range(0); + template + static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel, + const StridedViewBF A, const MatPtrT& B, + RowPtrs C_rows, const MMArgs& args) { + static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K"); + HWY_DASSERT(args.ranges_mc.NumTasks() == 1); + const IndexRange& range_mc = args.ranges_mc.Range(0); - // Loop over NC/MC/KC, called from the outer loops over K/N. - // C++14 generic lambda enables hoisting branches via template - // argument, while also capturing to avoid long argument lists. 
- const auto loop_nc = [&](BF16* B_storage, const IndexRange& range_kc, - const IndexRange& range_nc, - auto out_tag) HWY_ATTR { - const size_t kc = range_kc.Num(); - const StridedViewBF A_view = - A.View(range_mc.begin(), range_kc.begin(), kc); - const StridedViewBF B_storage_view( - B_storage, kc, - Stride(MatPadding::kOdd, kc, sizeof(BF16), args_.line_bytes)); - - for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); - row_b += kNR) { - StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, - C_rows); - } - }; - - parallel.ForN( - args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_, - args_.options.cluster_idx, - [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { - MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, args_); - - HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - - // Peel off the first iteration of the kc loop: avoid - // zero-initializing `partial` by writing into it. - ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { - loop_nc(B_storage, range_kc, range_nc, MMSetC()); - }); - ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { - loop_nc(B_storage, range_kc, range_nc, MMAddC()); - }); - }); + parallel.ForN(args.env.ctx, args.range_n, + MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks, + args.options.cluster_idx, + [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { + MMZone mm_zone; + mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); + MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, + range_nc, args.mr, C_rows); + }); } // Parallel loops over mc/nc blocks of M/range_n, single K. // Fills `mc x nc` sections of C directly, in parallel. 
- template - HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A, - const MatPtrT& B, RowPtrs C_rows) const { - static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); - HWY_DASSERT(ranges_kc_.NumTasks() == 1); - const IndexRange& range_K = ranges_kc_.Range(0); - const size_t K = range_K.Num(); - const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); + template + static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel, + const StridedViewBF A, const MatPtrT& B, + RowPtrs C_rows, const MMArgs& args) { + static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT"); + HWY_DASSERT(args.ranges_kc.NumTasks() == 1); + const IndexRange& range_K = args.ranges_kc.Range(0); - // Similar to `loop_nc` below except for the profiler zone and `MMSetC`. parallel.ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, + args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, args_); - - const StridedViewBF A_view = A.View(range_mc.begin(), 0, K); - HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const StridedViewBF B_storage_view(B_storage, K, B_stride); - - for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); - row_b += kNR) { - const StridedViewBF B_view = - DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), - args_, C_rows); - } + mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); + MMKernel::B3A2C0(A, B, args, range_mc, range_K, range_nc, args.mr, + MMSetC(), C_rows); }); } // Parallel loops over mc/nc blocks of M/range_np, sequential K. // Accumulates into `mc x nc` sections of `C`. 
- template - HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A, - const MatPtrT& B, RowPtrs C_rows) const { - static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); - const size_t kc_max = ranges_kc_.TaskSize(); - HWY_DASSERT(kc_max <= kMaxKC); - const size_t B_stride = - Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes); - // Sequential loop over NC/MC/KC, for when the M/N loops are - // already parallel. This is B3A2C0 in MOMMS terminology: we read - // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `C`. - const auto loop_nc = [&](const StridedViewBF B_storage_view, - const IndexRange& range_mc, - const IndexRange& range_kc, - const IndexRange& range_nc, - auto out_tag) HWY_ATTR { - const size_t kc = range_kc.Num(); - const StridedViewBF A_view = - A.View(range_mc.begin(), range_kc.begin(), kc); + template + static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel, + const StridedViewBF A, const MatPtrT& B, + RowPtrs C_rows, const MMArgs& args) { + static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K"); - for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); - row_b += kNR) { - StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, - C_rows); - } - }; // loop_nc parallel.ForRangesMC_NC( - args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, + args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, args_); - - HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const StridedViewBF B_storage_view(B_storage, kc_max, B_stride); - - // Peel off the first iteration of the kc loop: avoid - // zero-initializing `C` by writing into it. 
- ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { - loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMSetC()); - }); - ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { - loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMAddC()); - }); + mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); + MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, range_nc, + args.mr, C_rows); }); } - - // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, - // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` - // thanks to its large table lookups, and less so on other targets. - template - HWY_INLINE StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, - const IndexRange& range_kc, - const StridedViewBF B_view) const { - const hn::ScalableTag dbf; - HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); - - // Neither A nor B require padding because `LoopKC` handles remainders. - if constexpr (hwy::IsSame()) { - return MMImpl::View(B, row_b, range_kc.begin(), range_kc.Num()); - } - - const PackedSpan B_span = B.PaddedSpan(); - - const size_t kc = range_kc.Num(); - const size_t col0 = range_kc.begin(); - - for (size_t r = 0; r < kNR; ++r) { - const size_t packed_ofs = (row_b + r) * B.Stride() + col0; - BF16* HWY_RESTRICT to = B_view.Row(r); - DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc); - // Verify that we zero-padded. - if constexpr (HWY_IS_DEBUG_BUILD) { - for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { - HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); - } - } - } - return B_view; - } - - const MMArgs args_; // copy for locality - - const IndexRange range_n_; - // From MMConfig: - const size_t mr_; - const IndexRangePartition ranges_mc_; - const IndexRangePartition ranges_kc_; - const IndexRangePartition ranges_nc_; - const MMOrder order_; - const size_t inner_tasks_; -}; // MMState +}; // MMLoops // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. 
// @@ -1109,29 +1024,30 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C, MMOptions options = MMOptions()) { static const auto zone = env.ctx.profiler.AddZone("MM.MatMul"); + const size_t cluster_idx = options.cluster_idx; + HWY_DASSERT(cluster_idx < env.row_ptrs.size()); PROFILER_ZONE3(env.ctx.profiler, - options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), - zone); + cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone); - HWY_DASSERT(options.cluster_idx < env.row_ptrs.size()); - RowPtrs C_rows = - GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]); + RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); const CacheInfo& cache = env.ctx.cache_info; - MMPerKey& per_key = MMImpl::FindOrAddPerKey( - M, K, N, cache.VectorBytes(), env.per_cluster[options.cluster_idx]); - MMAutoTune& tuner = per_key.autotune; + MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(), + env.per_cluster[cluster_idx]); - const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), - add, options); + // (Also auto-tunes, hence outside the timed section to prevent interference.) 
+ const StridedViewBF A_view = + MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options); + + MMAutoTune& tuner = per_key.autotune; if (HWY_LIKELY(tuner.Best())) { - const MMState state(M, K, N, args, *tuner.Best()); - const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); - state.DispatchParallelism(A_view, B, C_rows); + const MMArgs args(env, M, K, N, static_cast(A.Scale()) * B.Scale(), + add, options, tuner, *tuner.Best()); + MMLoops::Dispatch(A_view, B, C_rows, args); return &per_key; } @@ -1147,14 +1063,13 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMCandidates(cache, M, K, N, sizeof(TC), env.print_config)); } - // (Also auto-tunes, hence outside the timed section to prevent interference.) - const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); - const MMConfig& cfg = tuner.NextConfig(); + const MMArgs args(env, M, K, N, static_cast(A.Scale()) * B.Scale(), + add, options, tuner, cfg); + const uint64_t t0 = hwy::timer::Start(); - MMState state(M, K, N, args, cfg); - state.DispatchParallelism(A_view, B, C_rows); - MMImpl::NotifyAutotuneResult(M, K, N, t0, cfg, env, tuner); + MMLoops::Dispatch(A_view, B, C_rows, args); + MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg); return &per_key; } diff --git a/ops/matmul.h b/ops/matmul.h index 946673a..915970c 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -21,7 +21,7 @@ #include #include -#include // std::unique_ptr +#include #include // IWYU pragma: begin_exports @@ -54,13 +54,58 @@ HWY_INLINE_VAR constexpr size_t kNR = 4; // or less on ISAs with fewer registers, or for the last few rows of A. HWY_INLINE_VAR constexpr size_t kMaxMR = 4; +HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // TODO: shrink? + // Upper bound for per-worker B storage on the stack. Chosen such that one row // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. 
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; +// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. +// Also used to decompress B, hence non-const. +#pragma pack(push, 1) // power of two size +template +class StridedView { + public: + StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride) + : row0_(row0), + cols_(static_cast(cols)), + stride_(static_cast(stride)) { + HWY_DASSERT(stride >= cols); + } + + T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } + size_t Cols() const { return static_cast(cols_); } + + size_t Stride() const { return static_cast(stride_); } + void SetStride(size_t stride) { + HWY_DASSERT(stride >= Cols()); + stride_ = stride; + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. + StridedView View(size_t r, size_t c, size_t cols) const { + HWY_DASSERT(c < Cols()); + HWY_DASSERT(cols <= Cols() - c); + return StridedView(Row(r) + c, cols, stride_); + } + + private: + T* HWY_RESTRICT row0_; + uint32_t cols_; + uint32_t stride_; +}; +#pragma pack(pop) + +using StridedViewBF = StridedView; +using StridedViewD = StridedView; + +using MMFused = std::function; + struct MMOptions { uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; + + MMFused fused; }; // Policy classes for parallelism, implementing some of `ParallelismStrategy`. @@ -260,49 +305,26 @@ struct MMParallelHierarchical { } }; +template +void DispatchParallelism(ParallelismStrategy parallelism, const Func& func, + Args&&... 
args) { + switch (parallelism) { + case ParallelismStrategy::kNone: + return func(MMParallelNone(), std::forward(args)...); + case ParallelismStrategy::kWithinCluster: + return func(MMParallelWithinCluster(), std::forward(args)...); + case ParallelismStrategy::kHierarchical: + return func(MMParallelHierarchical(), std::forward(args)...); + default: + HWY_UNREACHABLE; + } +} + void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC); // C is BF16/float. void BindC(ThreadingContext& ctx, MatPtr& C); -// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. -// Also used to decompress B, hence non-const. -#pragma pack(push, 1) // power of two size -template -class StridedView { - public: - StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride) - : row0_(row0), - cols_(static_cast(cols)), - stride_(static_cast(stride)) { - HWY_DASSERT(stride >= cols); - } - - T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } - size_t Cols() const { return static_cast(cols_); } - - size_t Stride() const { return static_cast(stride_); } - void SetStride(size_t stride) { - HWY_DASSERT(stride >= Cols()); - stride_ = stride; - } - - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - StridedView View(size_t r, size_t c, size_t cols) const { - HWY_DASSERT(c < Cols()); - HWY_DASSERT(cols <= Cols() - c); - return StridedView(Row(r) + c, cols, stride_); - } - - private: - T* HWY_RESTRICT row0_; - uint32_t cols_; - uint32_t stride_; -}; -#pragma pack(pop) - -using StridedViewBF = StridedView; -using StridedViewD = StridedView; - +// For A. class MMStorage { public: // Compile-time bounds on matrix columns to enable pre-allocating storage @@ -354,6 +376,28 @@ enum class MMOrder : uint8_t { // no kM* because we expect M (batch size) to be small relative to K and N. }; +// Tag types for `DispatchOrder`. 
+struct MMOrderNT_K {}; +struct MMOrderNT {}; +struct MMOrderNT_MT_K {}; +struct MMOrderNT_MT {}; + +template +void DispatchOrder(MMOrder order, const Func& func, Args&&... args) { + switch (order) { + case MMOrder::kNT_K: + return func(MMOrderNT_K(), std::forward(args)...); + case MMOrder::kNT: + return func(MMOrderNT(), std::forward(args)...); + case MMOrder::kNT_MT_K: + return func(MMOrderNT_MT_K(), std::forward(args)...); + case MMOrder::kNT_MT: + return func(MMOrderNT_MT(), std::forward(args)...); + default: + HWY_UNREACHABLE; + } +} + static inline bool IsBlock(MMOrder order) { return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT; } @@ -693,26 +737,46 @@ struct MatMulEnv { std::vector> row_ptrs; }; -// Arguments to MatMul() that are independent of the A/B/C types. -// Reduces register pressure compared to individual values/references. +// Arguments to MatMul() that are independent of the A/B/C types. Reduces +// register pressure compared to individual values/references. Also used for +// passing through `DispatchOrder`. struct MMArgs { - MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, - const float* HWY_RESTRICT add, MMOptions options) - : env(&env), - per_key(&per_key), + MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale, + const float* HWY_RESTRICT add, MMOptions options, + const MMAutoTune& autotune, const MMConfig& config) + : env(env), + line_bytes(env.ctx.cache_info.LineBytes()), + + range_n(0, N), scale(scale), add(add), options(options), - line_bytes(env.ctx.cache_info.LineBytes()) {} - MatMulEnv* env; - MMPerKey* per_key; + autotune(autotune), + mr(config.MR()), + ranges_mc(config.RangesOfMC(M)), + ranges_kc(config.RangesOfKC(K)), + ranges_nc(config.RangesOfNC(N)), + order(config.Order()), + inner_tasks(config.InnerTasks()) {} - double scale; + MatMulEnv& env; + const size_t line_bytes; // from `env`, for `Stride`. 
+ + // MatMul arguments: + const IndexRange range_n; // entire N + const double scale; const float* HWY_RESTRICT add; + const MMOptions options; - MMOptions options; - size_t line_bytes; + const MMAutoTune& autotune; // for `MaybeEnter` + // From `MMConfig`: + const size_t mr; + const IndexRangePartition ranges_mc; + const IndexRangePartition ranges_kc; + const IndexRangePartition ranges_nc; + const MMOrder order; + const size_t inner_tasks; }; // Wrapper over hwy::Zone that is only enabled when autotuning finished. @@ -729,11 +793,12 @@ class MMZone { } } - // `name` must be a string literal. + template void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone, - const MMArgs& args) { - if (args.per_key->WantProfile()) { - new (&data_) Zone(args.env->ctx.profiler, thread, zone); + const MatMulEnv& env, const AutoTune* auto_tune) { + // Only if enabled and autotuning finished. + if (PROFILER_ENABLED && auto_tune->Best()) { + new (&data_) Zone(env.ctx.profiler, thread, zone); HWY_DASSERT(data_ != 0); } } @@ -744,7 +809,8 @@ class MMZone { }; #else struct MMZone { - void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {} + void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MatMulEnv&, + const void*) {} }; #endif // PROFILER_ENABLED diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h index 8be84ec..c8feda9 100644 --- a/ops/matvec-inl.h +++ b/ops/matvec-inl.h @@ -37,7 +37,6 @@ #include "compression/compress-inl.h" #include "ops/dot-inl.h" -#include "ops/matmul.h" #include "util/mat.h" // MatPtrT #include "hwy/contrib/math/math-inl.h" #include "hwy/contrib/matvec/matvec-inl.h" diff --git a/util/mat.h b/util/mat.h index c084e81..c8a4617 100644 --- a/util/mat.h +++ b/util/mat.h @@ -40,9 +40,10 @@ class RowPtrs { public: RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {} - T* HWY_RESTRICT operator[](size_t row_idx) const { + T* HWY_RESTRICT Row(size_t row_idx) const { return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]); } + T* HWY_RESTRICT 
operator[](size_t row_idx) const { return Row(row_idx); } private: uint8_t** row_ptrs_; From 24b1760f03f53d85bdade30a4d3f2c34122e7856 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 9 Sep 2025 07:55:39 -0700 Subject: [PATCH 37/65] Refactor: move Worker to ThreadingContext, factor out MMDecompress PiperOrigin-RevId: 804909921 --- ops/matmul-inl.h | 361 ++++++++++++++++++++------------------- ops/matmul.h | 18 +- util/threading_context.h | 14 +- 3 files changed, 199 insertions(+), 194 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 21ac14b..bf7bd68 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -219,6 +219,181 @@ class MMStoreHorizontalSumsIntoC { } }; // MMStoreHorizontalSumsIntoC +// Stateless, wraps member functions. +class MMDecompress { + public: + // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, + // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` + // thanks to its large table lookups, and less so on other targets. + template + static StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, + const IndexRange& range_kc, + const StridedViewBF B_view) { + const hn::ScalableTag dbf; + HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); + + // Neither A nor B require padding because `LoopKC` handles remainders. + if constexpr (hwy::IsSame()) { + return View(B, row_b, range_kc.begin(), range_kc.Num()); + } + + const PackedSpan B_span = B.PaddedSpan(); + + const size_t kc = range_kc.Num(); + const size_t col0 = range_kc.begin(); + + for (size_t r = 0; r < kNR; ++r) { + const size_t packed_ofs = (row_b + r) * B.Stride() + col0; + BF16* HWY_RESTRICT to = B_view.Row(r); + DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc); + // Verify that we zero-padded. 
+ if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { + HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); + } + } + } + return B_view; + } + + template + static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT& A, + MMAutoTune& autotune, + const MatMulEnv& env, + MMOptions options) { + if constexpr (IsBF16()) { + // We can use a view, regardless of columns/padding, because + // `MMKernel::LoopKC` supports non-vector multiples. + return View(A, 0, 0, A.Cols()); + } else { + // Always decompress. To reduce code size/compile time, we no longer + // support a separate F32 kernel; most A are already BF16. We also only + // have a single MMStorage. + HWY_ASSERT(options.cluster_idx == 0); + const StridedViewBF A_view = env.storage.A(A.Extents()); + AutotuneDecompressA(A, A_view, autotune, env, options); + return A_view; + } + } + + private: + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. + template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + + // Decompresses all `M x K` from `A` into padded BF16 `A_view`. 
+ static HWY_NOINLINE void DecompressA(const MatPtrT& A, + const StridedViewBF A_view, + MMAutoTune& autotune, + MMParA par_a, const MatMulEnv& env, + const MMOptions& options) { + const IndexRange all_M(0, A.Rows()); + const IndexRange all_K(0, A.Cols()); + HWY_DASSERT(all_K.Num() == A_view.Cols()); + + const hn::ScalableTag dbf; + const size_t NBF = hn::Lanes(dbf); + + static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA"); + + const auto do_range = + [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) + HWY_ATTR { + MMZone mm_zone; + mm_zone.MaybeEnter(worker, zone, env, &autotune); + + const size_t col0 = range_K.begin(); + const size_t cols = range_K.Num(); + // Must be a vector multiple, or the last range before row + // padding, otherwise `DecompressAndZeroPad` overwrites neighbors. + HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); + for (size_t row_a : range_M) { + const PackedSpan from = + MakeSpan(A.Row(row_a) + col0, cols); + BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; + DecompressAndZeroPad(dbf, from, 0, to, cols); + // Verify that we zero-padded. + if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) { + HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); + } + } + } + }; + + switch (par_a) { + case MMParA::kNone: + do_range(all_M, all_K, env.ctx.Worker(options.cluster_idx)); + break; + + case MMParA::kK1: + case MMParA::kK2: + case MMParA::kK4: { + const size_t inner_tasks = static_cast(par_a); + // At least one vector, otherwise DecompressAndZeroPad will add + // padding, which might overwrite neighboring tasks. Also a whole cache + // line to avoid false sharing. 
+ const size_t multiple_K = + HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16)); + + DispatchParallelism(options.parallelism, [&](const auto& parallel) { + parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks, + options.cluster_idx, + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); + }); + }); + break; + } + case MMParA::kM: + DispatchParallelism(options.parallelism, [&](const auto& parallel) { + parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx, + [&](size_t row_a, size_t worker) { + do_range(IndexRange(row_a, row_a + 1), all_K, + worker); + }); + }); + break; + } + } + + // Autotuning wrapper for `DoDecompressA`. + static HWY_INLINE void AutotuneDecompressA(const MatPtrT& A, + const StridedViewBF A_view, + MMAutoTune& autotune, + const MatMulEnv& env, + const MMOptions& options) { + if (HWY_LIKELY(autotune.Best())) { + return DecompressA(A, A_view, autotune, *autotune.Best(), env, options); + } + + // First call: generate candidates. + if (HWY_UNLIKELY(!autotune.HasCandidates())) { + const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM; + std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4, + other}; + autotune.SetCandidates(candidates); + } + + const MMParA& par_a = autotune.NextConfig(); + const uint64_t t0 = hwy::timer::Start(); + DecompressA(A, A_view, autotune, par_a, env, options); + const uint64_t t1 = + env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); + const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); + if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) { + fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), + static_cast(min_elapsed) / + hwy::platform::InvariantTicksPerSecond() * 1E6); + } + } +}; // MMDecompress + // Stateless, wraps member functions. Contains the innermost 2-4 loops. class MMKernel { // Compute size of per-worker storage for `kNR` row ranges of B. 
Stack @@ -288,49 +463,6 @@ class MMKernel { static constexpr size_t B_storage_max = kNR * B_stride_max; - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - template - static StridedView View(const MatPtrT& AB, size_t r, size_t c, - size_t cols) { - HWY_DASSERT(c < AB.Cols()); - HWY_DASSERT(cols <= AB.Cols() - c); - return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); - } - - // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, - // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` - // thanks to its large table lookups, and less so on other targets. - template - static StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, - const IndexRange& range_kc, - const StridedViewBF B_view) { - const hn::ScalableTag dbf; - HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); - - // Neither A nor B require padding because `LoopKC` handles remainders. - if constexpr (hwy::IsSame()) { - return View(B, row_b, range_kc.begin(), range_kc.Num()); - } - - const PackedSpan B_span = B.PaddedSpan(); - - const size_t kc = range_kc.Num(); - const size_t col0 = range_kc.begin(); - - for (size_t r = 0; r < kNR; ++r) { - const size_t packed_ofs = (row_b + r) * B.Stride() + col0; - BF16* HWY_RESTRICT to = B_view.Row(r); - DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc); - // Verify that we zero-padded. - if constexpr (HWY_IS_DEBUG_BUILD) { - for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { - HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); - } - } - } - return B_view; - } - // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by // `ForeachKC` and when there is only a single KC task. 
@@ -350,7 +482,8 @@ class MMKernel { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); + StridedViewBF B_view = + MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view); A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows); } } @@ -742,137 +875,6 @@ class MMImpl { HWY_ASSERT(hwy::IsAligned(A.RowBytes(1), vector_bytes)); } } - - static size_t Worker(const MatMulEnv& env, size_t cluster_idx) { - return cluster_idx * env.ctx.pools.MaxWorkersPerCluster(); - } - - // Decompresses all `M x K` from `A` into padded BF16 `A_view`. - static HWY_NOINLINE void DoDecompressA(const MatPtrT& A, - const StridedViewBF A_view, - MMAutoTune& autotune, - MMParA par_a, const MatMulEnv& env, - const MMOptions& options) { - const IndexRange all_M(0, A.Rows()); - const IndexRange all_K(0, A.Cols()); - HWY_DASSERT(all_K.Num() == A_view.Cols()); - - const hn::ScalableTag dbf; - const size_t NBF = hn::Lanes(dbf); - - static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA"); - - const auto do_range = - [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) - HWY_ATTR { - MMZone mm_zone; - mm_zone.MaybeEnter(worker, zone, env, &autotune); - - const size_t col0 = range_K.begin(); - const size_t cols = range_K.Num(); - // Must be a vector multiple, or the last range before row - // padding, otherwise `DecompressAndZeroPad` overwrites neighbors. - HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); - for (size_t row_a : range_M) { - const PackedSpan from = - MakeSpan(A.Row(row_a) + col0, cols); - BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; - DecompressAndZeroPad(dbf, from, 0, to, cols); - // Verify that we zero-padded. 
- if constexpr (HWY_IS_DEBUG_BUILD) { - for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) { - HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); - } - } - } - }; - - switch (par_a) { - case MMParA::kNone: - do_range(all_M, all_K, Worker(env, options.cluster_idx)); - break; - - case MMParA::kK1: - case MMParA::kK2: - case MMParA::kK4: { - const size_t inner_tasks = static_cast(par_a); - // At least one vector, otherwise DecompressAndZeroPad will add - // padding, which might overwrite neighboring tasks. Also a whole cache - // line to avoid false sharing. - const size_t multiple_K = - HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16)); - - DispatchParallelism(options.parallelism, [&](const auto& parallel) { - parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks, - options.cluster_idx, - [&](const IndexRange& range_K, size_t worker) { - do_range(all_M, range_K, worker); - }); - }); - break; - } - case MMParA::kM: - DispatchParallelism(options.parallelism, [&](const auto& parallel) { - parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx, - [&](size_t row_a, size_t worker) { - do_range(IndexRange(row_a, row_a + 1), all_K, - worker); - }); - }); - break; - } - } - - // Autotuning wrapper for `DoDecompressA`. - static HWY_INLINE void DecompressA(const MatPtrT& A, - const StridedViewBF A_view, - MMAutoTune& autotune, - const MatMulEnv& env, - const MMOptions& options) { - if (HWY_LIKELY(autotune.Best())) { - return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options); - } - - // First call: generate candidates. - if (HWY_UNLIKELY(!autotune.HasCandidates())) { - const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM; - std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4, - other}; - autotune.SetCandidates(candidates); - } - - const MMParA& par_a = autotune.NextConfig(); - const uint64_t t0 = hwy::timer::Start(); - DoDecompressA(A, A_view, autotune, par_a, env, options); - const uint64_t t1 = - env.have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); - const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); - if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) { - fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), - static_cast(min_elapsed) / - hwy::platform::InvariantTicksPerSecond() * 1E6); - } - } - - template - static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT& A, - MMAutoTune& autotune, - const MatMulEnv& env, - MMOptions options) { - if constexpr (IsBF16()) { - // We can use a view, regardless of columns/padding, because `LoopKC` - // supports non-vector multiples. - return MMKernel::View(A, 0, 0, A.Cols()); - } else { - // Always decompress. To reduce code size/compile time, we no longer - // support a separate F32 kernel; most A are already BF16. We also only - // have a single MMStorage. - HWY_ASSERT(options.cluster_idx == 0); - const StridedViewBF A_view = env.storage.A(A.Extents()); - DecompressA(A, A_view, autotune, env, options); - return A_view; - } - } }; // Defines several variants of the outer M/N/K loops (see `MMOrder`). 
@@ -885,7 +887,7 @@ class MMLoops { RowPtrs C_rows, const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch"); PROFILER_ZONE3(args.env.ctx.profiler, - MMImpl::Worker(args.env, args.options.cluster_idx), zone); + args.env.ctx.Worker(args.options.cluster_idx), zone); DispatchParallelism( args.options.parallelism, [&](const auto& parallel) HWY_ATTR { @@ -931,7 +933,7 @@ class MMLoops { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = - MMKernel::DecompressB(B, row_b, range_K, B_storage_view); + MMDecompress::DecompressB(B, row_b, range_K, B_storage_view); MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(), args, C_rows); } @@ -1026,8 +1028,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, static const auto zone = env.ctx.profiler.AddZone("MM.MatMul"); const size_t cluster_idx = options.cluster_idx; HWY_DASSERT(cluster_idx < env.row_ptrs.size()); - PROFILER_ZONE3(env.ctx.profiler, - cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone); + PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone); RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); @@ -1041,7 +1042,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, // (Also auto-tunes, hence outside the timed section to prevent interference.) 
const StridedViewBF A_view = - MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options); + MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options); MMAutoTune& tuner = per_key.autotune; if (HWY_LIKELY(tuner.Best())) { diff --git a/ops/matmul.h b/ops/matmul.h index 915970c..a85d192 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -116,7 +116,7 @@ struct MMParallelNone { size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx, const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t worker = ctx.Worker(cluster_idx); func(range_n, worker); } @@ -125,7 +125,7 @@ struct MMParallelNone { const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_nc, size_t cluster_idx, const Func& func) const { - const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t worker = ctx.Worker(cluster_idx); for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) { const IndexRange range_mc = ranges_mc.Range(i); @@ -139,7 +139,7 @@ struct MMParallelNone { template void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, size_t cluster_idx, const Func& func) const { - const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t worker = ctx.Worker(cluster_idx); for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) { func(row_a, worker); } @@ -154,7 +154,7 @@ struct MMParallelWithinCluster { const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); - const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t base = ctx.Worker(cluster_idx); const IndexRangePartition worker_ranges = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); @@ -171,7 +171,7 @@ struct MMParallelWithinCluster { const Func& func) const { const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 
cluster_idx); - const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t base = ctx.Worker(cluster_idx); // Low-batch: avoid Divide/Remainder. if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { @@ -192,7 +192,7 @@ struct MMParallelWithinCluster { size_t cluster_idx, const Func& func) const { const size_t pkg_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); - const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t base = ctx.Worker(cluster_idx); cluster.Run( range_mc.begin(), range_mc.end(), @@ -233,8 +233,7 @@ struct MMParallelHierarchical { n_ranges, all_clusters, [&](const IndexRange& n_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); - const size_t cluster_base = - cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t cluster_base = ctx.Worker(cluster_idx); // Parallel-for over sub-ranges of `cluster_range` within the cluster. const IndexRangePartition worker_ranges = StaticPartition( n_range, cluster.NumWorkers() * inner_tasks, n_multiple); @@ -284,8 +283,7 @@ struct MMParallelHierarchical { ParallelizeOneRange( ranges_nc, all_clusters, [&](const IndexRange range_nc, size_t cluster_idx) { - const size_t cluster_base = - cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t cluster_base = ctx.Worker(cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); ParallelizeOneRange(ranges_mc, cluster, [&](const IndexRange& range_mc, size_t worker) { diff --git a/util/threading_context.h b/util/threading_context.h index 41d0811..ac42526 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -97,6 +97,13 @@ class ThreadingArgs : public ArgsBase { struct ThreadingContext { explicit ThreadingContext(const ThreadingArgs& args); + // Returns a worker index compatible with those from `ParallelFor`, assuming + // the current thread is running on one thread per cluster, which happens + // 
when `ParallelismStrategy` is `kAcrossClusters`. + size_t Worker(size_t cluster_idx) const { + return cluster_idx * pools.MaxWorkersPerCluster(); + } + // Singleton; pass around a reference to reduce overhead. hwy::Profiler& profiler; @@ -158,7 +165,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, switch (parallelism) { case ParallelismStrategy::kNone: { - const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t worker = ctx.Worker(cluster_idx); for (size_t task = 0; task < num_tasks; ++task) { func(task, worker); } @@ -173,7 +180,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, case ParallelismStrategy::kWithinCluster: { // Ensure the worker argument is unique across clusters, because it is // used for TLS indexing for example in profiler.h. - const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t base = ctx.Worker(cluster_idx); return ctx.pools.Cluster(pkg_idx, cluster_idx) .Run(0, num_tasks, [&](uint64_t task, size_t worker) { func(task, base + worker); @@ -193,8 +200,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, return ctx.pools.AllClusters(pkg_idx).Run( 0, num_tasks, [&](uint64_t task, size_t cluster_idx) { - const size_t worker = - cluster_idx * ctx.pools.MaxWorkersPerCluster(); + const size_t worker = ctx.Worker(cluster_idx); func(task, worker); }); } From f10ac41a20bd31974b60253eb6c6b150442a63ac Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Tue, 9 Sep 2025 08:04:45 -0700 Subject: [PATCH 38/65] Added flash attention, with both a single-q function, and a register-tiled function. The register-tiled version achieves a speed-up by a factor of about 9.7 over the previous attention function on an AVX3-enabled machine. 
PiperOrigin-RevId: 804913784 --- BUILD.bazel | 23 ++ CMakeLists.txt | 3 + gemma/activations.h | 9 +- gemma/attention.cc | 27 +- gemma/attention.h | 8 + gemma/flash_attention.cc | 510 ++++++++++++++++++++++++++++++++++ gemma/flash_attention.h | 61 ++++ gemma/flash_attention_test.cc | 171 ++++++++++++ ops/ops-inl.h | 345 +++++++++++++++++++++++ 9 files changed, 1146 insertions(+), 11 deletions(-) create mode 100644 gemma/flash_attention.cc create mode 100644 gemma/flash_attention.h create mode 100644 gemma/flash_attention_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index dbe52b7..02c54bd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -117,6 +117,27 @@ cc_library( ], ) +cc_test( + name = "flash_attention_test", + srcs = ["gemma/flash_attention_test.cc"], + deps = [ + ":configs", + ":gemma_args", + ":gemma_lib", + ":kv_cache", + ":mat", + ":matmul", + ":ops", + ":threading_context", + ":weights", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", + "//compression:types", + "@highway//:hwy", + "@highway//:hwy_test_util", + ], +) + cc_test( name = "threading_test", srcs = ["util/threading_test.cc"], @@ -526,12 +547,14 @@ cc_library( name = "gemma_lib", srcs = [ "gemma/attention.cc", + "gemma/flash_attention.cc", "gemma/gemma.cc", "gemma/vit.cc", ], hdrs = [ "gemma/activations.h", "gemma/attention.h", + "gemma/flash_attention.h", "gemma/gemma.h", "gemma/vit.h", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bc0e80..cb2911f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,8 @@ set(SOURCES gemma/attention.h gemma/configs.cc gemma/configs.h + gemma/flash_attention.cc + gemma/flash_attention.h gemma/gemma_args.h gemma/gemma-inl.h gemma/gemma.cc @@ -216,6 +218,7 @@ set(GEMMA_TEST_FILES compression/nuq_test.cc compression/sfp_test.cc evals/gemma_test.cc + gemma/flash_attention_test.cc gemma/tensor_info_test.cc io/blob_store_test.cc io/fields_test.cc diff --git a/gemma/activations.h b/gemma/activations.h index 71523e4..9460d15 
100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -56,7 +56,11 @@ struct AttentionActivations { ? layer_config.heads * 3 * layer_config.qkv_dim : layer_config.heads * layer_config.qkv_dim, allocator)), - + q_T(MatFactory("q_T", layer_config.qkv_dim, + config.vocab_size == 0 + ? batch_size * layer_config.heads * 3 + : batch_size * layer_config.heads, + allocator)), pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, config.model_dim, allocator)), att(MatFactory("att", batch_size, layer_config.heads * seq_len, @@ -90,11 +94,13 @@ struct AttentionActivations { // If we forget any MatMul outputs here, debug builds print a warning but // fill them in each MatMul call. q.AllocateAndAttachRowPtrs(row_ptrs); + q_T.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs); } void SetBatchSize(size_t batch_size) { q.OverrideRows(batch_size); + q_T.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); @@ -105,6 +111,7 @@ struct AttentionActivations { const ModelConfig& config; MatStorageT q; // query + MatStorageT q_T; // Transposed to maximize attention speed. MatStorageT pre_att_rms_out; MatStorageT att; // attention vector diff --git a/gemma/attention.cc b/gemma/attention.cc index 31ed4d1..61d76ef 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -41,12 +41,16 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "gemma/flash_attention.h" #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +constexpr int kFlagReserved = 1; // LINTER: unused, reserved for future use. +constexpr int kUseOldAttention = 2; + // Computes Q.K scores, which are "logits" (or scores) stored to att. // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. 
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, @@ -71,11 +75,11 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, } } -static void PositionalEncodingQK(float* qk, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, - hwy::Profiler& p, const size_t worker, - const size_t pos, const float mul = 1.0f) { +void PositionalEncodingQK(float* qk, const size_t layer_idx, + const LayerWeightsPtrs& layer, + const AttentionActivations& activations, + hwy::Profiler& p, const size_t worker, + const size_t pos, const float mul) { const size_t qkv_dim = layer.layer_config.qkv_dim; const PostQKType& post_qk = layer.layer_config.post_qk; // qk is either q or k, so qkv_dim is the length we operate on. @@ -165,8 +169,7 @@ void SingleDotSoftmaxWeightedSum( // The attention window usually starts at 0 unless `pos` is larger than // the attention window size, then it is `pos` - window_size + 1. -static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config, - size_t layer_idx) { +size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) { const size_t att_window_size = config.attention_window_sizes[layer_idx]; return pos - HWY_MIN(att_window_size - 1, pos); } @@ -314,7 +317,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, } PositionalEncodingQK(kv_f32, layer_idx, layer, activations, - env.ctx.profiler, worker, pos); + env.ctx.profiler, worker, pos, /*mul=*/1.0f); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); @@ -354,8 +357,12 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, (void)layer_config; // only used in HWY_DASSERT ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); - DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, - env.ctx); + if (flags & kUseOldAttention) { + DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, 
activations, qbatch, + env.ctx); + } else { + FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); + } SumHeads(layer, activations, env); } diff --git a/gemma/attention.h b/gemma/attention.h index c69cc8f..a0af4ff 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -28,6 +28,14 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ + void PositionalEncodingQK(float* qk, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, \ + hwy::Profiler& p, size_t worker, size_t pos, \ + float mul); \ + \ + size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \ + \ void SingleDotSoftmaxWeightedSum( \ const size_t pos, const size_t start_pos, const size_t last_pos, \ float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc new file mode 100644 index 0000000..40096d1 --- /dev/null +++ b/gemma/flash_attention.cc @@ -0,0 +1,510 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include +#include +#include + +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "util/threading_context.h" +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include "gemma/activations.h" +#include "gemma/configs.h" // kMaxQKVDim +#include "gemma/gemma.h" +#include "gemma/weights.h" +#include "util/threading.h" +#include "hwy/profiler.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/flash_attention.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "gemma/attention.h" +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// Transposes q into q_t. +// Both are 4D tensors stuffed into a 2-D MatPtrT. +// q has shape [batch, qbatch][head, qkv_dim]. +// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum +// possible consecutive elements have the same KV. 
+static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, + const size_t qbatch_size, ThreadingContext& ctx) { + static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ"); + const size_t num_heads = q.Cols() / q_t.Rows(); + const size_t batch_size = q.Rows() / qbatch_size; + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + PROFILER_ZONE3(ctx.profiler, worker, zone); + float* HWY_RESTRICT qt_row = q_t.Row(task); + for (size_t qi = 0; qi < qbatch_size; ++qi) + for (size_t h = 0; h < num_heads; ++h) { + for (size_t b = 0; b < batch_size; ++b) { + qt_row[(qi * num_heads + h) * batch_size + b] = + q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task]; + } + } + }; + { + // Full parallelism is helpful, SmallParallelFor is insufficient. + HierarchicalParallelFor(q_t.Rows(), ctx.pools, func); + } +} + +// Updates q in place for RMSNorm and positional encoding. +void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, + MatPtrT& q, const size_t layer_idx, + const LayerWeightsPtrs& layer, + const AttentionActivations& activations, + ThreadingContext& ctx) { + static const auto zone = + ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding"); + const float query_scale = activations.query_scale; + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + PROFILER_ZONE3(ctx.profiler, worker, zone); + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + for (size_t h = 0; h < layer.layer_config.heads; ++h) { + const size_t tq_idx = qbatch.Size() * task + qi; + // Find the token position in the query and calculate + // the range of cache positions to attend to. + const size_t pos = qbatch.Pos(qi) + task; + float* HWY_RESTRICT q_row = + q.Row(tq_idx) + h * layer.layer_config.qkv_dim; + // Apply rope and scaling to Q. 
+ if (layer.query_norm_scale.HasPtr()) { + CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), q_row, + layer.layer_config.qkv_dim, ctx.profiler, worker); + }); + } + PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, + worker, pos, query_scale); + } + } + }; + { + // Full parallelism is helpful, SmallParallelFor is insufficient. + HierarchicalParallelFor(num_tokens, ctx.pools, func); + } +} + +// Calculates the complete attention outputs for a single row of q. +void SingleFlashAttention(const size_t start_pos, const size_t last_pos, + const float* HWY_RESTRICT q, const MatPtrT& k, + const MatPtrT& v, const size_t layer_idx, + const LayerWeightsPtrs& layer, + const AttentionActivations& activations, + float* HWY_RESTRICT att_out, hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Gen.Attention.SingleFlashAttention"); + PROFILER_ZONE3(p, worker, zone); + const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); + float m = Dot(q, k.Row(pos_mod), k.Cols()); + float d = 1.0f; + // This is just a copy of the first token. + MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker); + for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { + const size_t pos_mod = activations.div_seq_len.Remainder(pos); + float x = Dot(q, k.Row(pos_mod), k.Cols()); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. 
+ x = activations.config.att_cap * + std::tanh(x / activations.config.att_cap); + } + float m_new = std::max(m, x); + float scale = d * std::exp(m - m_new); + x = std::exp(x - m_new); + m = m_new; + d = scale + x; + float one_over_d = 1.0f / d; + x *= one_over_d; + scale *= one_over_d; + MulByConst(scale, att_out, v.Cols(), p, worker); + MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker); + } +} + +// Computes and returns a single vector of NF Q.K dot products, which represents +// the dot products of NF rows of Q for a single K timestep. +template > +VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, + const size_t k_pos, const MatPtrT& q, + const MatPtrT& k, hwy::Profiler& p, const size_t worker) { + hn::TFromD results[hn::MaxLanes(df)]; + for (size_t i = 0; i < hn::Lanes(df); ++i) { + results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols()); + } + return hn::LoadU(df, results); +} + +// Returns an 8xNF tile of Q.K dot products, in single precision. +// This is the result of NF rows of Q against 8 K timesteps, with positions +// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in +// consecutive elements, and other columns by adding q_stride. 
+template > +void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, + const MatPtrT& k, const size_t* k_pos, + hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, + VF& sum7) { + constexpr size_t kHTileSize = 8; + sum0 = hn::Zero(df); + sum1 = hn::Zero(df); + sum2 = hn::Zero(df); + sum3 = hn::Zero(df); + sum4 = hn::Zero(df); + sum5 = hn::Zero(df); + sum6 = hn::Zero(df); + sum7 = hn::Zero(df); + const float* HWY_RESTRICT k_row[kHTileSize]; + for (int i = 0; i < kHTileSize; ++i) { + k_row[i] = k.Row(k_pos[i]); + } + for (size_t i = 0; i < k.Cols(); ++i) { + VF q_vec = hn::Load(df, q); + VF k_0 = hn::Set(df, k_row[0][i]); + sum0 = hn::MulAdd(q_vec, k_0, sum0); + VF k_1 = hn::Set(df, k_row[1][i]); + sum1 = hn::MulAdd(q_vec, k_1, sum1); + VF k_2 = hn::Set(df, k_row[2][i]); + sum2 = hn::MulAdd(q_vec, k_2, sum2); + VF k_3 = hn::Set(df, k_row[3][i]); + sum3 = hn::MulAdd(q_vec, k_3, sum3); + VF k_4 = hn::Set(df, k_row[4][i]); + sum4 = hn::MulAdd(q_vec, k_4, sum4); + VF k_5 = hn::Set(df, k_row[5][i]); + sum5 = hn::MulAdd(q_vec, k_5, sum5); + VF k_6 = hn::Set(df, k_row[6][i]); + sum6 = hn::MulAdd(q_vec, k_6, sum6); + VF k_7 = hn::Set(df, k_row[7][i]); + sum7 = hn::MulAdd(q_vec, k_7, sum7); + q += q_stride; + } +} + +// Returns the element-wise maximum of 8 vectors, in a single vector. +template > +VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2, + const VF& x3, const VF& x4, const VF& x5, + const VF& x6, const VF& x7) { + VF m0 = hn::Max(x0, x1); + VF m1 = hn::Max(x2, x3); + VF m2 = hn::Max(x4, x5); + VF m3 = hn::Max(x6, x7); + m0 = hn::Max(m0, m1); + m2 = hn::Max(m2, m3); + return hn::Max(m0, m2); +} + +// Returns the element-wise sum of 8 vectors, in a single vector. 
+template > +VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, + const VF& x3, const VF& x4, const VF& x5, + const VF& x6, const VF& x7) { + VF sum0 = hn::Add(x0, x1); + VF sum1 = hn::Add(x2, x3); + VF sum2 = hn::Add(x4, x5); + VF sum3 = hn::Add(x6, x7); + sum0 = hn::Add(sum0, sum1); + sum2 = hn::Add(sum2, sum3); + return hn::Add(sum0, sum2); +} + +// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then +// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos]. +void TileFlashAttention( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const StridedView& qT, const MatPtrT& k, + const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, + const size_t min_last_pos, const size_t max_last_pos, + const MatPtrT& v, const size_t layer_idx, + const LayerWeightsPtrs& layer, const AttentionActivations& activations, + MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention"); + PROFILER_ZONE3(p, worker, zone); + constexpr int kHTileSize = 8; + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + using DI = hn::ScalableTag; + const DI di; + using VI = hn::Vec; + VI lasts = hn::LoadU(di, last_pos); + VF old_m = hn::Set(df, -std::numeric_limits::max() / 2.0f); + VF old_d = hn::Zero(df); + const float* HWY_RESTRICT qT_row = qT.Row(0); + const size_t qT_stride = qT.Stride(); + size_t position = start_pos; + while (position + kHTileSize - 1 <= min_last_pos) { + size_t k_pos[kHTileSize]; + for (size_t i = 0; i < kHTileSize; ++i) { + k_pos[i] = activations.div_seq_len.Remainder(position + i); + } + VF x0, x1, x2, x3, x4, x5, x6, x7; + QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3, + x4, x5, x6, x7); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. 
+ VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); + x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); + x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); + x4 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x4, one_over_cap))); + x5 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x5, one_over_cap))); + x6 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x6, one_over_cap))); + x7 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x7, one_over_cap))); + } + VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); + m = hn::Max(old_m, m); + x0 = hn::Exp(df, x0 - m); + x1 = hn::Exp(df, x1 - m); + x2 = hn::Exp(df, x2 - m); + x3 = hn::Exp(df, x3 - m); + x4 = hn::Exp(df, x4 - m); + x5 = hn::Exp(df, x5 - m); + x6 = hn::Exp(df, x6 - m); + x7 = hn::Exp(df, x7 - m); + VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m)); + old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); + old_d = hn::Add(scale, old_d); + old_m = m; + VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); + scale = hn::Mul(scale, one_over_d); + x0 = hn::Mul(x0, one_over_d); + x1 = hn::Mul(x1, one_over_d); + x2 = hn::Mul(x2, one_over_d); + x3 = hn::Mul(x3, one_over_d); + x4 = hn::Mul(x4, one_over_d); + x5 = hn::Mul(x5, one_over_d); + x6 = hn::Mul(x6, one_over_d); + x7 = hn::Mul(x7, one_over_d); + MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos, + att_out.Row(0), out_offsets, v.Cols(), p, worker); + position += kHTileSize; + } + while (position <= max_last_pos) { + size_t k_pos = activations.div_seq_len.Remainder(position); + VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector. 
+      VF cap = hn::Set(df, activations.config.att_cap);
+      VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
+      x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
+    }
+    // Past the last position, x0 doesn't count.
+    auto mask = hn::Gt(hn::Set(di, position), lasts);
+    VF causal_offset = hn::MaskedSet(df, RebindMask(df, mask),
+                                     std::numeric_limits::max() / 2.0f);
+    x0 = hn::Sub(x0, causal_offset);
+    VF m = hn::Max(old_m, x0);
+    x0 = hn::Exp(df, x0 - m);
+    VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
+    old_m = m;
+    old_d = hn::Add(scale, x0);
+    VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
+    x0 = hn::Mul(x0, one_over_d);
+    scale = hn::Mul(scale, one_over_d);
+    MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
+                           v.Cols(), p, worker);
+    ++position;
+  }
+}
+
+// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
+// into a single output O[L,D].
+// Conventional attention first computes A[L,L] = Q . KT
+// followed by A = softmax(A) (over individual rows).
+// Then A is multiplied by V to get O[L,D].
+// For each row of O, this takes a read of one row of Q L times, all of K,
+// 3 write/reads of one row of A, read all of V, and read/write the one row of O
+// L times. Ignoring the computation for now, and focusing just on memory,
+// the one row of O takes L(4D+3) reads and L(D+3) writes.
+// For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
+//
+// Flash attention fuses these operations together, and (where possible)
+// computes NF rows of the result using 8 accumulator registers and two more to
+// keep running results. NF is the number of float lanes in a register, being 16
+// for AVX3. The softmax is converted to streaming form using the
+// algorithm from:
+// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
+// Q is transposed to Q_T[D,L] to make the dot product computation efficient.
+// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing +// reads of Q by 8 and reads of K by NF. The streaming softmax is computed +// entirely in registers, and a further NF registers to accumulate the results +// of the product of the softmax and V, reduce the number of reads of V by NF, +// and the reads/writes of O by 8. +// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8, +// which on AVX3 is an overall reduction by about a factor of 10. +// +// A further complication is that real attention is not as simple as documented +// in the paper and above. There are multiple query heads, differing KV, and +// different sequence lengths, so a lot of the work in FlashAttention is making +// sure that a collection of q rows can use the TileFlashAttention path. +void FlashAttention(const size_t num_tokens, const size_t layer_idx, + const LayerWeightsPtrs& layer, + AttentionActivations& activations, QBatch& qbatch, + ThreadingContext& ctx) { + static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention"); + RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, + layer, activations, ctx); + const hwy::Divisor div_qbatch(qbatch.Size()); + const LayerConfig& layer_config = layer.layer_config; + const size_t qkv_dim = layer_config.qkv_dim; + + // A "head group" in the context of GQA refers to a collection of query + // heads that share the same key and value heads. + const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; + + using DF = hn::ScalableTag; + const DF df; + constexpr size_t kVTileSize = hn::MaxLanes(df); + const size_t cache_layer_size = layer_config.CacheLayerSize(); + const size_t seq_len = + static_cast(activations.div_seq_len.GetDivisor()); + const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); + const size_t total_tasks = token_batch * layer_config.heads; + // q has shape [batch, qbatch][head, qkv_dim]. 
+ // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the + // maximum possible number of consecutive columns have the same KV matrices. + // Each thread will process a tile of NF columns of QT so the starting column + // index of QT is just the task index * kVTileSize. + TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx); + const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize); + const hwy::Divisor div_tokens(num_tokens); + // All layers should have the same number of heads. + HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); + + // For each head/token/query, compute fused flash Q.K, softmax and weighted V. + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + PROFILER_ZONE3(ctx.profiler, worker, zone); + // Offsets into original Q for each row in the tile. + uint32_t q_offsets[kVTileSize]; + // Offsets into att_out for each row in the tile. + uint32_t out_offsets[kVTileSize]; + // Start positions for each row in the tile. + size_t start_positions[kVTileSize]; + // Last positions for each row in the tile. Inclusive. + uint32_t last_pos[kVTileSize]; + // min and max last positions across all rows in the tile determines when + // TileFlashAttention switches to single vector mode to handle the + // ragged sequence lengths. + size_t min_last_pos = std::numeric_limits::max(); + size_t max_last_pos = 0; + // Indices into the qbatch.KV for each row in the tile. + size_t qi_indices[kVTileSize]; + // Indices into the kv_cache for each row in the tile. + size_t kv_offsets[kVTileSize]; + // first_task is [qbatch, head, token]. 
+    const size_t first_task = task * kVTileSize;
+    const size_t last_task = first_task + kVTileSize - 1;
+    bool use_tile_attention = last_task < total_tasks;
+    for (size_t offset = 0;
+         offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
+      const size_t batch_idx = div_tokens.Remainder(first_task + offset);
+      const size_t qh = div_tokens.Divide(first_task + offset);
+      const size_t head = activations.div_heads.Remainder(qh);
+      const size_t qi = activations.div_heads.Divide(qh);
+      const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi;
+      qi_indices[offset] = qi;
+
+      // Find the token position in the query and calculate
+      // the range of cache positions to attend to.
+      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      const size_t start_pos = StartPos(pos, activations.config, layer_idx);
+      start_positions[offset] = start_pos;
+      size_t last = pos;
+      const size_t prefix_end = qbatch.PrefixEnd(qi);
+      if (prefix_end > 0 && prefix_end - 1 > last) {
+        // last_pos in QDotK and WeightedSumV is inclusive.
+        last = prefix_end - 1;
+      }
+      last_pos[offset] = last;
+      min_last_pos = HWY_MIN(min_last_pos, last);
+      max_last_pos = HWY_MAX(max_last_pos, last);
+      q_offsets[offset] =
+          activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
+      out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
+                            activations.att_out.Row(0);
+      const size_t kv_index = head / kHeadGroups;
+      const size_t head_offset = kv_index * qkv_dim * 2;
+      kv_offsets[offset] = layer_idx * cache_layer_size + head_offset;
+      // If any of the parameters in this if statement differ within this task,
+      // then we can't use TileFlashAttention. TileFlashAttention requires that
+      // all rows in the tile have the same K and V matrices, and Q starts at
+      // the same position. The end positions do not have to be equal. 
+      if (start_positions[offset] != start_positions[0] ||
+          qi_indices[offset] != qi_indices[0] ||
+          kv_offsets[offset] != kv_offsets[0]) {
+        use_tile_attention = false;
+      }
+    }
+    for (size_t offset = 0;
+         offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
+      auto& kv_cache = qbatch.KV(qi_indices[offset]).kv_cache;
+      MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+      k.SetPtr(kv_cache.Row(0) + kv_offsets[offset], kv_cache.Stride());
+      MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+      v.SetPtr(kv_cache.Row(0) + kv_offsets[offset] + qkv_dim,
+               kv_cache.Stride());
+      if (use_tile_attention) {
+        // To avoid duplicating the code to set up K and V, the call to
+        // TileFlashAttention is inside the loop over tasks, even though it
+        // handles all rows in the task at once.
+        StridedView<float> qT =
+            StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
+                               activations.q_T.Stride());
+        TileFlashAttention(
+            activations.q, q_offsets, qT, k, start_positions[offset], last_pos,
+            min_last_pos, max_last_pos, v, layer_idx, layer, activations,
+            activations.att_out, out_offsets, ctx.profiler, worker);
+        break;
+      } else {
+        SingleFlashAttention(start_positions[offset], last_pos[offset],
+                             activations.q.Row(0) + q_offsets[offset], k, v,
+                             layer_idx, layer, activations,
+                             activations.att_out.Row(0) + out_offsets[offset],
+                             ctx.profiler, worker);
+      }
+    }
+  };
+
+  {
+    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    // Full parallelism is helpful, SmallParallelFor is insufficient. 
+ HierarchicalParallelFor(num_thread_tasks, ctx.pools, func); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h new file mode 100644 index 0000000..b505d6f --- /dev/null +++ b/gemma/flash_attention.h @@ -0,0 +1,61 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_ + +// Declares FlashAttention for all SIMD targets. + +#include + +#include "gemma/gemma.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. 
+#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \ + MatPtrT& q, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, \ + ThreadingContext& ctx); \ + \ + void SingleFlashAttention(size_t start_pos, size_t last_pos, \ + const float* HWY_RESTRICT q, \ + const MatPtrT& k, const MatPtrT& v, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, \ + float* HWY_RESTRICT att_out, hwy::Profiler& p, \ + size_t worker); \ + \ + void FlashAttention(size_t num_tokens, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + AttentionActivations& activations, QBatch& qbatch, \ + ThreadingContext& ctx); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION) + +#undef GEMMA_DECL_FLASH_ATTENTION + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_ diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc new file mode 100644 index 0000000..efb210e --- /dev/null +++ b/gemma/flash_attention_test.cc @@ -0,0 +1,171 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "compression/types.h" +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/kv_cache.h" +#include "gemma/weights.h" +#include "ops/matmul.h" +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include +#include + +#include // std::max +#include // std::abs +#include + +#include "util/mat.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/flash_attention_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "gemma/attention.h" +#include "gemma/configs.h" +#include "gemma/flash_attention.h" +#include "ops/matvec-inl.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +using FloatPtr = hwy::AlignedFreeUniquePtr; + +void SetMat(const size_t offset, MatPtrT& mat) { + const size_t kOuter = mat.Extents().rows; + const size_t kInner = mat.Extents().cols; + const float i_scale = 1.0f / kInner; + const float j_scale = 1.0f / kOuter; + for (size_t i = 0; i < kOuter; ++i) { + float* row = mat.Row(i); + for (size_t j = 0; j < kInner; ++j) { + row[j] = + static_cast((i * kInner * i_scale + (j + offset) * j_scale)); + } + } +} + +std::unique_ptr> MakeCopyOfMat(const MatPtrT& mat, + const Allocator& allocator) { + auto copy = std::make_unique>("TestMat", mat.Extents(), + allocator, MatPadding::kOdd); + CopyMat(mat, *copy); + return copy; +} + +void AssertClose(const MatPtrT& a, const MatPtrT& b) { + // Avoid comparing the padding bytes, which are uninitialized. 
+ for (size_t r = 0; r < a.Rows(); ++r) { + const float* HWY_RESTRICT a_row = a.Row(r); + const float* HWY_RESTRICT b_row = b.Row(r); + for (size_t c = 0; c < a.Cols(); ++c) { + float rel_abs_delta = std::abs(a_row[c] - b_row[c]); + if (rel_abs_delta > 0.0f) { + rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c])); + } + EXPECT_LT(rel_abs_delta, 1e-5) + << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," + << c << "]=" << b_row[c]; + } + } +} + +void TestAttention() { + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + // hwy::ThreadPool& pool = ctx.pools.Pool(); + constexpr size_t kOuter = 1024; + constexpr size_t kInner = 256; + ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); + TensorInfoRegistry tensor_info_registry(config); + const LayerConfig& layer_config = config.layer_configs[0]; + const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry); + InferenceArgs inference_args; + RuntimeConfig runtime_config; + KVCache kv_cache(config, inference_args, ctx.allocator); + MatMulEnv env(ctx); + Activations activations(config, runtime_config.prefill_tbatch_size, + kv_cache.SeqLen(), env.ctx, env.row_ptrs); + std::vector tokens(kOuter); + std::iota(tokens.begin(), tokens.end(), 1); + PromptTokens prompt(tokens); + AllQueries all_queries(hwy::Span(&prompt, 1), + hwy::Span(&kv_cache, 1)); + QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries); + const size_t batch_size = kOuter; + std::vector> row_ptrs; + AttentionActivations attention(config, layer_config, batch_size, kOuter, + ctx.allocator, row_ptrs); + const size_t qkv_dim = layer_config.qkv_dim; + ASSERT_EQ(qkv_dim, kInner); + const hwy::Divisor div_qbatch(qbatch.Size()); + // A "head group" in the context of GQA refers to a collection of query + // heads that share the same key and value heads. 
+ const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; + const size_t seq_len = + static_cast(attention.div_seq_len.GetDivisor()); + auto& kvc = qbatch.KV(0).kv_cache; + for (size_t h = 0; h < layer_config.heads; ++h) { + // Make strided views into the kv cache for + // this query and head. + const size_t head_offset = (h / kHeadGroups) * qkv_dim * 2; + MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); + k.SetPtr(kvc.Row(0) + head_offset, kvc.Stride()); + MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); + v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride()); + SetMat(h + layer_config.heads, k); + SetMat(h + layer_config.heads * 2, v); + } + SetMat(1, attention.q); + DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx); + // Copy the output to saved_att to allow for comparison. + auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); + SetMat(1, attention.q); + FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx); + AssertClose(attention.att_out, *saved_att); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(FlashAttentionTest); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention); +HWY_AFTER_TEST(); + +} // namespace gcpp + +#endif diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0c6bd50..cfd85ae 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -613,6 +613,351 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( }); } +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( + DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, + VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, + VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) { + sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); + sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); + sum2 = hn::MulAdd(common, hn::Set(df, 
split.raw[2]), sum2); + sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); + sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); + sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); + sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); + sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7); + sum8 = hn::MulAdd(common, hn::Set(df, split.raw[8]), sum8); + sum9 = hn::MulAdd(common, hn::Set(df, split.raw[9]), sum9); + sum10 = hn::MulAdd(common, hn::Set(df, split.raw[10]), sum10); + sum11 = hn::MulAdd(common, hn::Set(df, split.raw[11]), sum11); + sum12 = hn::MulAdd(common, hn::Set(df, split.raw[12]), sum12); + sum13 = hn::MulAdd(common, hn::Set(df, split.raw[13]), sum13); + sum14 = hn::MulAdd(common, hn::Set(df, split.raw[14]), sum14); + sum15 = hn::MulAdd(common, hn::Set(df, split.raw[15]), sum15); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, VF& sum3, + VF& sum4, VF& sum5, VF& sum6, + VF& sum7) { + sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); + sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); + sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); + sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); + sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); + sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); + sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); + sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, + VF& sum3) { + sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); + sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); + sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); + sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); +} + +// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 
8 rows +// of V by the corresponding values in c0-c7 and adds them to NF rows of out, +// after first prescaling out by scale. +// The depth (size) must be a multiple of NF. +template > +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( + DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3, + const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT& v, + const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, + const uint32_t* HWY_RESTRICT out_offsets, const size_t size, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); + PROFILER_ZONE3(p, worker, zone); + namespace hn = hwy::HWY_NAMESPACE; + HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df); + + size_t i = 0; + while (i + NF <= size) { + if HWY_LANES_CONSTEXPR (NF == 16) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + VF out8, out9, out10, out11, out12, out13, out14, out15; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out8 = hn::Load(df, out + i + out_offsets[8]); + out9 = hn::Load(df, out + i + out_offsets[9]); + out10 = hn::Load(df, out + i + out_offsets[10]); + out11 = hn::Load(df, out + i + out_offsets[11]); + out12 = hn::Load(df, out + i + out_offsets[12]); + out13 = hn::Load(df, out + i + out_offsets[13]); + out14 = hn::Load(df, out + i + out_offsets[14]); + out15 = hn::Load(df, out + i + out_offsets[15]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = 
hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); + out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); + out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); + out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); + out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); + out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); + out14 = hn::Mul(out14, hn::Set(df, scale.raw[14])); + out15 = hn::Mul(out15, hn::Set(df, scale.raw[15])); + VF x0 = hn::Load(df, v.Row(pos[0]) + i); + MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x1 = hn::Load(df, v.Row(pos[1]) + i); + MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x2 = hn::Load(df, v.Row(pos[2]) + i); + MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x3 = hn::Load(df, v.Row(pos[3]) + i); + MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x4 = hn::Load(df, v.Row(pos[4]) + i); + MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x5 = hn::Load(df, v.Row(pos[5]) + i); + MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x6 = hn::Load(df, v.Row(pos[6]) + i); + MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x7 = hn::Load(df, v.Row(pos[7]) + i); + MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + hn::Store(out0, df, out + i + out_offsets[0]); + 
hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + hn::Store(out8, df, out + i + out_offsets[8]); + hn::Store(out9, df, out + i + out_offsets[9]); + hn::Store(out10, df, out + i + out_offsets[10]); + hn::Store(out11, df, out + i + out_offsets[11]); + hn::Store(out12, df, out + i + out_offsets[12]); + hn::Store(out13, df, out + i + out_offsets[13]); + hn::Store(out14, df, out + i + out_offsets[14]); + hn::Store(out15, df, out + i + out_offsets[15]); + } + if HWY_LANES_CONSTEXPR (NF == 8) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + VF x0 = hn::Load(df, v.Row(pos[0]) + i); + MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); + VF x1 = hn::Load(df, v.Row(pos[1]) + i); + MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7); + VF x2 = hn::Load(df, v.Row(pos[2]) + i); + MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7); + VF x3 = 
hn::Load(df, v.Row(pos[3]) + i); + MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7); + VF x4 = hn::Load(df, v.Row(pos[4]) + i); + MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7); + VF x5 = hn::Load(df, v.Row(pos[5]) + i); + MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7); + VF x6 = hn::Load(df, v.Row(pos[6]) + i); + MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7); + VF x7 = hn::Load(df, v.Row(pos[7]) + i); + MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + } + if HWY_LANES_CONSTEXPR (NF == 4) { + VF out0, out1, out2, out3; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + VF x0 = hn::Load(df, v.Row(pos[0]) + i); + MulAdd4(df, x0, c0, out0, out1, out2, out3); + VF x1 = hn::Load(df, v.Row(pos[1]) + i); + MulAdd4(df, x1, c1, out0, out1, out2, out3); + VF x2 = hn::Load(df, v.Row(pos[2]) + i); + MulAdd4(df, x2, c2, out0, out1, out2, out3); + VF x3 = hn::Load(df, v.Row(pos[3]) + i); + MulAdd4(df, x3, c3, out0, out1, out2, out3); + VF x4 = hn::Load(df, v.Row(pos[4]) + i); + MulAdd4(df, x4, c4, out0, out1, out2, out3); + VF x5 = hn::Load(df, v.Row(pos[5]) + i); + MulAdd4(df, x5, c5, out0, out1, out2, out3); + VF x6 = hn::Load(df, v.Row(pos[6]) 
+ i); + MulAdd4(df, x6, c6, out0, out1, out2, out3); + VF x7 = hn::Load(df, v.Row(pos[7]) + i); + MulAdd4(df, x7, c7, out0, out1, out2, out3); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + } + i += NF; + } + const size_t remaining = size - i; + HWY_DASSERT(remaining == 0); +} + +// Prescales NF rows of out by scale, then multiplies 1 row of V by the +// corresponding values in c0 and adds them to the NF rows of out. +// The depth (size) must be a multiple of NF. +template > +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( + DF df, const VF scale, const VF c0, const MatPtrT& v, + const size_t pos, float* HWY_RESTRICT out, + const uint32_t* HWY_RESTRICT out_offsets, const size_t size, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); + PROFILER_ZONE3(p, worker, zone); + namespace hn = hwy::HWY_NAMESPACE; + const size_t NF = hn::MaxLanes(df); + + size_t i = 0; + while (i + NF <= size) { + if constexpr (NF == 16) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + VF out8, out9, out10, out11, out12, out13, out14, out15; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out8 = hn::Load(df, out + i + out_offsets[8]); + out9 = hn::Load(df, out + i + out_offsets[9]); + out10 = hn::Load(df, out + i + out_offsets[10]); + out11 = hn::Load(df, out + i + out_offsets[11]); + out12 = hn::Load(df, out + i + out_offsets[12]); + out13 = hn::Load(df, out + i + out_offsets[13]); + out14 = hn::Load(df, out + i + out_offsets[14]); 
+ out15 = hn::Load(df, out + i + out_offsets[15]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); + out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); + out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); + out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); + out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); + out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); + out14 = hn::Mul(out14, hn::Set(df, scale.raw[14])); + out15 = hn::Mul(out15, hn::Set(df, scale.raw[15])); + VF x0 = hn::Load(df, v.Row(pos) + i); + MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + hn::Store(out8, df, out + i + out_offsets[8]); + hn::Store(out9, df, out + i + out_offsets[9]); + hn::Store(out10, df, out + i + out_offsets[10]); + hn::Store(out11, df, out + i + out_offsets[11]); + hn::Store(out12, df, out + i + out_offsets[12]); + hn::Store(out13, df, out + i + out_offsets[13]); + hn::Store(out14, df, out + i + out_offsets[14]); + hn::Store(out15, df, out + i + out_offsets[15]); + } else if constexpr (NF == 8) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = 
hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + VF x0 = hn::Load(df, v.Row(pos) + i); + MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + } else if constexpr (NF == 4) { + VF out0, out1, out2, out3; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + VF x0 = hn::Load(df, v.Row(pos) + i); + MulAdd4(df, x0, c0, out0, out1, out2, out3); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + } else 
{ + HWY_DASSERT(false); + } + i += NF; + } + const size_t remaining = size - i; + HWY_DASSERT(remaining == 0); +} + // See below for a specialized version for top-1 sampling. // TODO: support bf16 logits using Decompress2. static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, From 9457258330a6c244045b852abdca1afe9a19658c Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 9 Sep 2025 22:09:09 -0700 Subject: [PATCH 39/65] Refactor MatMul to accept views in the kernel functions Make arg order consistent. Move StridedView into mat.h. Add view support to RowPtrs. PiperOrigin-RevId: 805197381 --- ops/matmul-inl.h | 193 ++++++++++++++++++++++++++--------------------- ops/matmul.h | 61 ++++----------- util/mat.h | 84 +++++++++++++++++---- 3 files changed, 192 insertions(+), 146 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index bf7bd68..3a20690 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -148,21 +148,21 @@ class MMStoreHorizontalSumsIntoC { } } - // Scales the dot-product terms and adds bias (if present) and stores the - // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is - // `MMSetC`, the vectors are written as-is (first call, or small K). - // Otherwise, they are partial sums and are accumulated into C. - template , class Tag, class CRows> - HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag, - const size_t row_c, const size_t col_c, - const MMArgs& args, CRows C_rows) const { - const V4 vscale = hn::Set(d4, args.scale); + // Scales the dot-product terms plus `add` (if non-null) and stores the four + // 4-wide vectors to `C` starting at row 0, column 0. If `tag` is `MMSetC`, + // the vectors are written as-is (first call, or small K). Otherwise, they + // are partial sums and are accumulated into C. 
+ template , class Tag, class CView> + HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, + const float scale, const float* HWY_RESTRICT add, + const size_t imc, Tag tag, CView C_rows) const { + const V4 vscale = hn::Set(d4, scale); HWY_ALIGN static constexpr float kZero[4] = {}; - const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero); - MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c); - MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c); - MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c); - MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c); + const V4 vadd = hn::Load(d4, add ? add : kZero); + MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows); + MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows); + MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows); + MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows); } private: @@ -199,13 +199,13 @@ class MMStoreHorizontalSumsIntoC { } template , - class Tag, typename TC> + class Tag, class CView> static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale, - VF4 vadd, Tag, RowPtrs C_rows, - const size_t row_c, - const size_t col_c) { + VF4 vadd, Tag, const size_t imc, + CView C_view) { if constexpr (kRow < kRowsAC) { - TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c; + using TC = hwy::RemoveCvRef; + TC* HWY_RESTRICT pos = C_view.Row(imc + kRow); const hn::Rebind dc4; if constexpr (hwy::IsSame()) { vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value @@ -234,7 +234,7 @@ class MMDecompress { // Neither A nor B require padding because `LoopKC` handles remainders. 
if constexpr (hwy::IsSame()) { - return View(B, row_b, range_kc.begin(), range_kc.Num()); + return StridedViewBF(B, row_b, range_kc.begin(), range_kc.Num()); } const PackedSpan B_span = B.PaddedSpan(); @@ -264,7 +264,7 @@ class MMDecompress { if constexpr (IsBF16()) { // We can use a view, regardless of columns/padding, because // `MMKernel::LoopKC` supports non-vector multiples. - return View(A, 0, 0, A.Cols()); + return StridedViewBF(A, 0, 0, A.Cols()); } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. We also only @@ -277,15 +277,6 @@ class MMDecompress { } private: - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - template - static StridedView View(const MatPtrT& AB, size_t r, size_t c, - size_t cols) { - HWY_DASSERT(c < AB.Cols()); - HWY_DASSERT(cols <= AB.Cols() - c); - return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); - } - // Decompresses all `M x K` from `A` into padded BF16 `A_view`. static HWY_NOINLINE void DecompressA(const MatPtrT& A, const StridedViewBF A_view, @@ -402,26 +393,26 @@ class MMKernel { kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); public: - // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` - // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. + // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is + // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start + // at row/col 0. `CView` is either `RowPtrs` or `StridedView`. // Called by B3A2C0 and by callers that hoist `A_view`. 
- template + template static HWY_INLINE void A2C0(const StridedViewBF A_view, const StridedViewBF B_view, size_t mr, - const IndexRange& range_mc, const size_t row_b, - size_t kc, Tag tag, const MMArgs& args, - CRows C_rows) { + const IndexRange& range_mc, size_t kc, + const float scale, const float* HWY_RESTRICT add, + Tag tag, CView C_view) { HWY_DASSERT(1 <= mr && mr <= kMaxMR); - const size_t row0 = range_mc.begin(); + const size_t mc = range_mc.Num(); size_t imc = 0; // M == 1, or x86 with 8 SIMD registers: if (HWY_UNLIKELY(mr == 1)) { for (; imc < mc; ++imc) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); } return; } @@ -430,13 +421,11 @@ class MMKernel { if (HWY_UNLIKELY(mr == 2)) { if (HWY_LIKELY(mc >= 2)) { for (; imc <= mc - 2; imc += 2) { - LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view); } } if (HWY_UNLIKELY(imc != mc)) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); } return; } @@ -444,18 +433,17 @@ class MMKernel { HWY_DASSERT(mr == 4); if (HWY_LIKELY(mc >= 4)) { for (; imc <= mc - 4; imc += 4) { - LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view); } } const size_t remainder_mc = mc - imc; HWY_DASSERT(remainder_mc < 4); if (HWY_UNLIKELY(remainder_mc & 2)) { - LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); + LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view); imc += 2; } if (HWY_UNLIKELY(remainder_mc & 1)) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); imc += 1; } HWY_DASSERT(imc == mc); @@ -466,11 +454,11 @@ class MMKernel { // Loop over NC/MC/KC, called from the 
outer loops. The MOMMS B3A2C0 reads // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by // `ForeachKC` and when there is only a single KC task. - template + template static void B3A2C0(const StridedViewBF A, const MatPtrT& B, - const MMArgs& args, const IndexRange& range_mc, - const IndexRange& range_kc, const IndexRange& range_nc, - size_t mr, Tag out_tag, CRows C_rows) { + const IndexRange& range_mc, const IndexRange& range_kc, + const IndexRange& range_nc, const MMArgs& args, + Tag out_tag, RowPtrs C) { HWY_ALIGN BF16 B_storage[B_storage_max]; const size_t kc = range_kc.Num(); @@ -482,24 +470,28 @@ class MMKernel { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - StridedViewBF B_view = + const StridedViewBF B_view = MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view); - A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows); + const RowPtrs C_view = C.View(range_mc.begin(), row_b); + const float* HWY_RESTRICT add = args.add ? args.add + row_b : nullptr; + A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag, + C_view); } } - template + template static void ForeachKC(const StridedViewBF A, const MatPtrT& B, - const MMArgs& args, const IndexRange& range_mc, + const IndexRange& range_mc, const IndexRangePartition& ranges_kc, - const IndexRange& range_nc, size_t mr, CRows C_rows) { + const IndexRange& range_nc, const MMArgs& args, + RowPtrs C) { // Peel off the first iteration of the kc loop: avoid zero-initializing `C` // by writing directly into it, and later accumulating into it. 
ranges_kc.VisitFirst([&](const IndexRange& range_kc) { - B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMSetC(), C_rows); + B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C); }); ranges_kc.VisitRemaining([&](const IndexRange& range_kc) { - B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMAddC(), C_rows); + B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C); }); } @@ -593,19 +585,20 @@ class MMKernel { // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). - // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`. - // `A` and `B` are always BF16, `C` can be F32 or BF16. - template + // Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0. + // `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also + // relative to the C column. + template static HWY_INLINE void LoopKC(const StridedViewBF A_view, - const StridedViewBF B_view, size_t row_ac, - size_t imc, size_t col_c, size_t kc, Tag tag, - const MMArgs& args, CRows C_rows) { + const StridedViewBF B_view, size_t imc, + size_t kc, const float scale, + const float* HWY_RESTRICT add, Tag tag, + CView C_view) { const hn::ScalableTag dbf; using VBF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); HWY_DASSERT(kRowsAC <= kMaxMR); - HWY_DASSERT(col_c % kNR == 0); // Rows are aligned to `kMaxMR`, except for the last tile of A. // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B. 
@@ -784,7 +777,7 @@ class MMKernel { hn::Vec sum0, sum1, sum2, sum3; horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3); - horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows); + horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view); } }; @@ -884,7 +877,7 @@ class MMLoops { // or with the best config. template static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT& B, - RowPtrs C_rows, const MMArgs& args) { + RowPtrs C, const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch"); PROFILER_ZONE3(args.env.ctx.profiler, args.env.ctx.Worker(args.options.cluster_idx), zone); @@ -892,7 +885,7 @@ class MMLoops { DispatchParallelism( args.options.parallelism, [&](const auto& parallel) HWY_ATTR { DispatchOrder(args.order, [&](const auto& order) HWY_ATTR { - Loop(order, parallel, A, B, C_rows, args); + Loop(order, parallel, A, B, C, args); }); }); } @@ -904,11 +897,11 @@ class MMLoops { return HWY_MAX(kNR, line_bytes / sizeof_TC); } - // Single M and K ranges, parallel N. Fills all of C directly. + // Single M and K ranges, parallel N. template static HWY_INLINE void Loop(MMOrderNT, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C_rows, const MMArgs& args) { + RowPtrs C, const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); @@ -932,10 +925,21 @@ class MMLoops { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - StridedViewBF B_view = + const StridedViewBF B_view = MMDecompress::DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(), - args, C_rows); + const RowPtrs C_view = C.View(range_M.begin(), row_b); + const float* HWY_RESTRICT add = + args.add ? 
args.add + row_b : nullptr; + + MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add, + MMSetC(), C_view); + } + + if constexpr (IsBF16()) { + if (args.options.fused) { + StridedViewBF C2(nullptr, 0, 0); + args.options.fused(C, range_M, range_nc, C2, worker); + } } }); } @@ -944,7 +948,7 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C_rows, const MMArgs& args) { + RowPtrs C, const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); const IndexRange& range_mc = args.ranges_mc.Range(0); @@ -955,17 +959,24 @@ class MMLoops { [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, - range_nc, args.mr, C_rows); + MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, + range_nc, args, C); + + if constexpr (IsBF16()) { + if (args.options.fused) { + StridedViewBF C2(nullptr, 0, 0); + args.options.fused(C, range_mc, range_nc, C2, worker); + } + } }); } // Parallel loops over mc/nc blocks of M/range_n, single K. - // Fills `mc x nc` sections of C directly, in parallel. + // Fills `mc x nc` sections of C. 
template static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C_rows, const MMArgs& args) { + RowPtrs C, const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); const IndexRange& range_K = args.ranges_kc.Range(0); @@ -976,17 +987,24 @@ class MMLoops { size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::B3A2C0(A, B, args, range_mc, range_K, range_nc, args.mr, - MMSetC(), C_rows); + MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(), + C); + + if constexpr (IsBF16()) { + if (args.options.fused) { + StridedViewBF C2(nullptr, 0, 0); + args.options.fused(C, range_mc, range_nc, C2, worker); + } + } }); } - // Parallel loops over mc/nc blocks of M/range_np, sequential K. + // Parallel loops over mc/nc blocks of M/range_n, sequential K. // Accumulates into `mc x nc` sections of `C`. template static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C_rows, const MMArgs& args) { + RowPtrs C, const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K"); parallel.ForRangesMC_NC( @@ -995,8 +1013,15 @@ class MMLoops { size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, range_nc, - args.mr, C_rows); + MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args, + C); + + if constexpr (IsBF16()) { + if (args.options.fused) { + StridedViewBF C2(nullptr, 0, 0); + args.options.fused(C, range_mc, range_nc, C2, worker); + } + } }); } }; // MMLoops diff --git a/ops/matmul.h b/ops/matmul.h index a85d192..93e7b04 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -60,54 +60,6 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // TODO: shrink? 
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; -// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. -// Also used to decompress B, hence non-const. -#pragma pack(push, 1) // power of two size -template -class StridedView { - public: - StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride) - : row0_(row0), - cols_(static_cast(cols)), - stride_(static_cast(stride)) { - HWY_DASSERT(stride >= cols); - } - - T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } - size_t Cols() const { return static_cast(cols_); } - - size_t Stride() const { return static_cast(stride_); } - void SetStride(size_t stride) { - HWY_DASSERT(stride >= Cols()); - stride_ = stride; - } - - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - StridedView View(size_t r, size_t c, size_t cols) const { - HWY_DASSERT(c < Cols()); - HWY_DASSERT(cols <= Cols() - c); - return StridedView(Row(r) + c, cols, stride_); - } - - private: - T* HWY_RESTRICT row0_; - uint32_t cols_; - uint32_t stride_; -}; -#pragma pack(pop) - -using StridedViewBF = StridedView; -using StridedViewD = StridedView; - -using MMFused = std::function; - -struct MMOptions { - uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. - ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; - - MMFused fused; -}; - // Policy classes for parallelism, implementing some of `ParallelismStrategy`. struct MMParallelNone { @@ -735,6 +687,19 @@ struct MatMulEnv { std::vector> row_ptrs; }; +// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols) +// that this thread has just filled, a view into a second tile (only for the +// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`). +using MMFused = std::function; + +struct MMOptions { + uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. 
+ ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; + + MMFused fused; // called if non-null and `TC` is BF16. +}; + // Arguments to MatMul() that are independent of the A/B/C types. Reduces // register pressure compared to individual values/references. Also used for // passing through `DispatchOrder`. diff --git a/util/mat.h b/util/mat.h index c8a4617..4360b69 100644 --- a/util/mat.h +++ b/util/mat.h @@ -38,17 +38,27 @@ namespace gcpp { template class RowPtrs { public: - RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {} + RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {} + + RowPtrs View(size_t r, size_t c) { + RowPtrs view(row_ptrs_); + view.r0_ = static_cast(r); + view.c0_ = static_cast(c); + return view; + } T* HWY_RESTRICT Row(size_t row_idx) const { - return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]); + return HWY_RCAST_ALIGNED(T*, row_ptrs_[r0_ + row_idx]) + c0_; } - T* HWY_RESTRICT operator[](size_t row_idx) const { return Row(row_idx); } private: uint8_t** row_ptrs_; + uint32_t r0_; + uint32_t c0_; }; +using RowPtrsBF = RowPtrs; + // Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector // or matrix). Base class of the non-type-erased `MatPtrT`. Use this class // to store hetereogeneous tensor references in a vector. @@ -349,12 +359,12 @@ RowPtrs GetOrSetTempRowPtrs( template decltype(auto) CallUpcasted(const MatPtr* base, const Func& func, Args&&... args) { -#if GEMMA_ENABLE_NUQ - if (base->GetType() == Type::kNUQ) { - const MatPtrT mat(*base); - return func(&mat, std::forward(args)...); + if constexpr (GEMMA_ENABLE_NUQ) { + if (base->GetType() == Type::kNUQ) { + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); + } } -#endif // GEMMA_ENABLE_NUQ if (base->GetType() == Type::kF32) { const MatPtrT mat(*base); @@ -376,13 +386,13 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2, const Func& func, Args&&... 
args) { HWY_DASSERT(base1->GetType() == base2->GetType()); -#if GEMMA_ENABLE_NUQ - if (base1->GetType() == Type::kNUQ) { - const MatPtrT mat1(*base1); - const MatPtrT mat2(*base2); - return func(&mat1, &mat2, std::forward(args)...); + if constexpr (GEMMA_ENABLE_NUQ) { + if (base1->GetType() == Type::kNUQ) { + const MatPtrT mat1(*base1); + const MatPtrT mat2(*base2); + return func(&mat1, &mat2, std::forward(args)...); + } } -#endif // GEMMA_ENABLE_NUQ if (base1->GetType() == Type::kF32) { const MatPtrT mat1(*base1); @@ -508,5 +518,51 @@ class MatFactory { MatPadding padding_; }; +// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. +// Also used to decompress B, hence non-const. +#pragma pack(push, 1) // power of two size +template +class StridedView { + public: + StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride) + : row0_(row0), + cols_(static_cast(cols)), + stride_(static_cast(stride)) { + HWY_DASSERT(stride >= cols); + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. + StridedView(const MatPtrT& mat, size_t r, size_t c, size_t cols) + : StridedView(const_cast(mat.Row(r)) + c, cols, mat.Stride()) { + HWY_DASSERT(c < mat.Cols()); + HWY_DASSERT(cols <= mat.Cols() - c); + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
+ StridedView View(size_t r, size_t c, size_t cols) const { + HWY_DASSERT(c < Cols()); + HWY_DASSERT(cols <= Cols() - c); + return StridedView(Row(r) + c, cols, stride_); + } + + T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } + size_t Cols() const { return static_cast(cols_); } + + size_t Stride() const { return static_cast(stride_); } + void SetStride(size_t stride) { + HWY_DASSERT(stride >= Cols()); + stride_ = stride; + } + + private: + T* HWY_RESTRICT row0_; + uint32_t cols_; + uint32_t stride_; +}; +#pragma pack(pop) + +using StridedViewBF = StridedView; +using StridedViewD = StridedView; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_ From ba6131311a444d23d6aad2c8ac3acf9e5b406ca4 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 10 Sep 2025 05:32:03 -0700 Subject: [PATCH 40/65] Fix gemma_batch_bench for flash attention q_T rows do not change. Also repeat prefill to reflect perf after autotuning. PiperOrigin-RevId: 805319377 --- evals/gemma_batch_bench.cc | 5 +++++ gemma/activations.h | 2 +- util/mat.h | 5 ++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 135c2bb..3ffa858 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -98,6 +98,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) { fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); } + // Run again: prefill will be faster due to autotuning. Fewer decode steps + // because those are already fast. + s_env->SetMaxGeneratedTokens(3); + responses = BatchGemmaReply(inputs); + PROFILER_PRINT_RESULTS(); } } // namespace diff --git a/gemma/activations.h b/gemma/activations.h index 9460d15..f474c84 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -100,7 +100,7 @@ struct AttentionActivations { void SetBatchSize(size_t batch_size) { q.OverrideRows(batch_size); - q_T.OverrideRows(batch_size); + // q_T rows are always qkv_dim! 
pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); diff --git a/util/mat.h b/util/mat.h index 4360b69..c2427e5 100644 --- a/util/mat.h +++ b/util/mat.h @@ -186,7 +186,10 @@ class MatPtr : public IFields { // will return this value. Used to set the actual number of rows for // activations preallocated according to the batch size. void OverrideRows(size_t rows) { - HWY_ASSERT(rows <= private_rows_); + if (HWY_UNLIKELY(rows > private_rows_)) { + HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows, + private_rows_); + } override_rows_ = static_cast(rows); } From 2695aab5d2034e8d705dda9943b91d21acfcaed5 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 10 Sep 2025 07:25:07 -0700 Subject: [PATCH 41/65] Temporarily disable flash pending msan fix PiperOrigin-RevId: 805350234 --- gemma/attention.cc | 5 +---- gemma/configs.h | 5 +++-- gemma/gemma.cc | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/gemma/attention.cc b/gemma/attention.cc index 61d76ef..e894981 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -48,9 +48,6 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -constexpr int kFlagReserved = 1; // LINTER: unused, reserved for future use. -constexpr int kUseOldAttention = 2; - // Computes Q.K scores, which are "logits" (or scores) stored to att. // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. 
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, @@ -357,7 +354,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, (void)layer_config; // only used in HWY_DASSERT ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); - if (flags & kUseOldAttention) { + if (flags & kAttentionUseOld) { DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); } else { diff --git a/gemma/configs.h b/gemma/configs.h index e4a26b8..e02645b 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -32,8 +32,9 @@ namespace gcpp { -static constexpr size_t kMaxConv1DWidth = 4; -static constexpr size_t kMaxQKVDim = 1024; +HWY_INLINE_VAR constexpr int kAttentionUseOld = 2; + +HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024; // Instruction-tuned models require extra 'turn structure' tokens in prompts. enum class PromptWrapping { diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 785bd87..05583b3 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -73,9 +73,9 @@ void Attention(LayerAttentionType type, const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, QBatch& qbatch, MatMulEnv& env) { if (type == LayerAttentionType::kGemma) { + // TODO: remove flag to enable FlashAttention. GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, - env, - /*flags=*/0); + env, kAttentionUseOld); } } From c9b8479f7d1dee327ce03abb829adb40c9512865 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Fri, 12 Sep 2025 07:47:36 -0700 Subject: [PATCH 42/65] Added zero-initialization to att_out. Re-enabled flash attention when HWY_NATIVE_DOT_BF16 is not available. 
PiperOrigin-RevId: 806284756 --- gemma/flash_attention.cc | 5 +++++ gemma/gemma.cc | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 40096d1..ba1de3e 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -256,6 +256,11 @@ void TileFlashAttention( using DI = hn::ScalableTag; const DI di; using VI = hn::Vec; + const int kVTileSize = hn::MaxLanes(df); + for (int i = 0; i < kVTileSize; ++i) { + hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], + v.Cols() * sizeof(att_out.Row(0)[0])); + } VI lasts = hn::LoadU(di, last_pos); VF old_m = hn::Set(df, -std::numeric_limits::max() / 2.0f); VF old_d = hn::Zero(df); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 05583b3..778ecc6 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -75,7 +75,7 @@ void Attention(LayerAttentionType type, const size_t num_tokens, if (type == LayerAttentionType::kGemma) { // TODO: remove flag to enable FlashAttention. GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, - env, kAttentionUseOld); + env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0); } } From 59db30e209ba5281bb92263b691ebbf5645ffb09 Mon Sep 17 00:00:00 2001 From: Charles Zhao Date: Sun, 14 Sep 2025 16:26:55 -0700 Subject: [PATCH 43/65] add const restriction for benchmark_helper.cc, and paligemma_helper.cc to remove a few unnecessary copies. 
PiperOrigin-RevId: 807004597 --- evals/benchmark_helper.cc | 15 ++++++--------- evals/benchmark_helper.h | 4 ++-- paligemma/paligemma_helper.cc | 3 +-- python/gemma_py.cc | 6 +++--- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index e9fdafb..bd53845 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -137,24 +137,21 @@ std::vector GemmaEnv::BatchQueryModel( return res; } -QueryResult GemmaEnv::QueryModel(std::string& input) { +QueryResult GemmaEnv::QueryModel(const std::string& input) { const std::vector prompt = WrapAndTokenize(input); return QueryModel(prompt); } std::vector GemmaEnv::BatchQueryModel( const std::vector& inputs) { - std::vector> prompts; - prompts.reserve(inputs.size()); - for (auto& input : inputs) { - std::string mutable_prompt = input; - prompts.push_back(WrapAndTokenize(mutable_prompt)); - } std::vector prompt_vector; - prompt_vector.reserve(prompts.size()); - for (auto& prompt : prompts) { + prompt_vector.reserve(inputs.size()); + + for (auto& input : inputs) { + std::vector prompt = WrapAndTokenize(input); prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size())); } + QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size()); return BatchQueryModel(prompt_span); } diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 261daa4..81ccde6 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -68,7 +68,7 @@ class GemmaEnv { return tokens; } - std::vector WrapAndTokenize(std::string& input) const { + std::vector WrapAndTokenize(const std::string& input) const { return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(), gemma_.Config().wrapping, 0, input); } @@ -87,7 +87,7 @@ class GemmaEnv { const QueriesPromptTokens& queries_prompt, const hwy::Span& prefix_end = hwy::Span()); // Adds turn structure to input, tokenizes and calls the above overload. 
- QueryResult QueryModel(std::string& input); + QueryResult QueryModel(const std::string& input); std::vector BatchQueryModel( const std::vector& inputs); diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc index 449ee00..c32e925 100644 --- a/paligemma/paligemma_helper.cc +++ b/paligemma/paligemma_helper.cc @@ -43,8 +43,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const { return true; }; - std::string mutable_prompt = prompt_text; - std::vector tokens = env_->WrapAndTokenize(mutable_prompt); + std::vector tokens = env_->WrapAndTokenize(prompt_text); tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); RuntimeConfig runtime_config = {.max_generated_tokens = 512, diff --git a/python/gemma_py.cc b/python/gemma_py.cc index 2e39f68..1bab194 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -52,7 +52,7 @@ class GemmaModel { // Generates a single example, given a prompt and a callback to stream the // generated tokens. - void GenerateEx(std::string prompt, gcpp::StreamFunc stream, + void GenerateEx(const std::string& prompt, gcpp::StreamFunc stream, size_t max_generated_tokens, float temperature, float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) { std::vector prompt_tokens = env_.WrapAndTokenize(prompt); @@ -75,7 +75,7 @@ class GemmaModel { } // Generates a single example, given a prompt, and returns the result. - std::string Generate(std::string prompt, size_t max_generated_tokens, + std::string Generate(const std::string& prompt, size_t max_generated_tokens, float temperature, float /*seed*/, const std::vector& accept, const std::vector& end) { @@ -192,7 +192,7 @@ class GemmaModel { // Generates a response to the given prompt, using the last set image. // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string. 
std::pair> GenerateWithImage( - std::string prompt, size_t max_generated_tokens, float temperature, + const std::string& prompt, size_t max_generated_tokens, float temperature, float /*seed*/, gcpp::AcceptFunc accept, std::vector prompt_tokens) { if (!image_tokens_) throw std::invalid_argument("No image set."); const gcpp::Gemma& model = *env_.GetGemma(); From f3bc1c17dad0754fe6f8023a994f6524374b14f9 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 15 Sep 2025 10:25:59 -0700 Subject: [PATCH 44/65] 1.03x speedup: fused FFN matmul-inl: support CView=StridedView or RowPtrs; rename to C_MC_NC matmul.cc: Allow 1 more rep for MC/NC to allow half-sized tiles, which helps. PiperOrigin-RevId: 807291701 --- evals/gemma_batch_bench.cc | 6 +- gemma/configs.h | 4 + gemma/gemma-inl.h | 61 ++++-- ops/matmul-inl.h | 404 ++++++++++++++++++++++--------------- ops/matmul.cc | 44 ++-- ops/matmul.h | 126 ++++++++---- ops/matmul_static-inl.h | 8 + ops/matmul_static.h | 14 +- ops/matmul_test.cc | 47 ++++- ops/ops-inl.h | 8 + util/mat.h | 13 +- 11 files changed, 488 insertions(+), 247 deletions(-) diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 3ffa858..ff81671 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -37,7 +37,6 @@ class GemmaBatchBench : public ::testing::Test { protected: std::vector BatchGemmaReply( const std::vector& inputs) { - s_env->SetMaxGeneratedTokens(24); s_env->MutableConfig().temperature = 0.0f; // deterministic s_env->MutableConfig().verbosity = 2; std::vector replies; @@ -92,15 +91,18 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) { inputs.push_back(questions[qpos++]); if (qpos == questions.size()) qpos = 0; } + s_env->SetMaxGeneratedTokens(24); std::vector responses = BatchGemmaReply(inputs); for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); ++i) { fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); } + PROFILER_PRINT_RESULTS(); + // Run again: prefill 
will be faster due to autotuning. Fewer decode steps // because those are already fast. - s_env->SetMaxGeneratedTokens(3); + s_env->SetMaxGeneratedTokens(2); responses = BatchGemmaReply(inputs); PROFILER_PRINT_RESULTS(); diff --git a/gemma/configs.h b/gemma/configs.h index e02645b..275f374 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -36,6 +36,10 @@ HWY_INLINE_VAR constexpr int kAttentionUseOld = 2; HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024; +#ifndef GEMMA_FUSED_FFN +#define GEMMA_FUSED_FFN 1 +#endif // !GEMMA_FUSED_FFN + // Instruction-tuned models require extra 'turn structure' tokens in prompts. enum class PromptWrapping { GEMMA_IT, diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index cb7ae6a..a7f1b01 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -43,6 +43,7 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +// For use by Vit even if !GEMMA_FUSED_FFN. template void Activation(ActivationType activation, T1* HWY_RESTRICT c1, const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, @@ -64,7 +65,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1, }); } -// No C2 multiplier. +// No C2 multiplier - used by Vit. template void ActivationBatched( ActivationType activation, Mat& c1, ThreadingContext& ctx, @@ -80,6 +81,34 @@ void ActivationBatched( }); } +#if GEMMA_FUSED_FFN + +// Called during `TwoMatMul`. +static inline void Activation(ActivationType activation, const RowPtrsBF C1, + const IndexRange range_r, + const IndexRange range_c, const StridedViewBF C2, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Gen.ActivationFused"); + PROFILER_ZONE3(p, worker, zone); + + const size_t cols = range_c.Num(); + HWY_DASSERT(C2.Cols() == cols); + + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + // ActivationType::Gelu + // Gated: Gelu(c1) * c2. 
+ for (size_t ir = 0; ir < range_r.Num(); ++ir) { + Decompress1AndCompressInplace( + DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir), + [](DF df, VF v1, VF v2) + HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); + } +} + +#else + template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, @@ -102,6 +131,8 @@ HWY_NOINLINE void ActivationBatched( } } +#endif // GEMMA_FUSED_FFN + template HWY_NOINLINE void ResidualConnection(const MatPtrT& other, MatPtrT& HWY_RESTRICT x, @@ -126,28 +157,32 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive); PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); const LayerConfig& layer_config = layer.layer_config; - const size_t ffh_hidden_dim = layer_config.ff_hidden_dim; - const bool add_bias = layer_config.ff_biases; - const float* bias1 = - add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr; - const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr; - const float* output_bias = - add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr; + HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. +#if GEMMA_FUSED_FFN + const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, + StridedViewBF C2, size_t worker) { + Activation(layer_config.activation, C1, range_r, range_c, C2, + env.ctx.profiler, worker); + }; + MMOptions options; + options.SetFunc(fused); + CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, + layer.gating_einsum_w2, env, activations.C1, options); +#else // Compute the hidden layer activations. 
- CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env, + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env, activations.C1); - CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env, + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env, activations.C2); - // Activation (Gelu) and maybe multiply by gate. Store activations in act. ActivationBatched(layer_config.activation, activations.C1, &activations.C2, env.ctx); +#endif // Hidden layer -> output layer. - CallMatMul(activations.C1, layer.linear_w, output_bias, env, - activations.ffw_out); + CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 3a20690..8957f4c 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -155,14 +155,14 @@ class MMStoreHorizontalSumsIntoC { template , class Tag, class CView> HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, const float scale, const float* HWY_RESTRICT add, - const size_t imc, Tag tag, CView C_rows) const { + const size_t imc, Tag tag, CView C_MC_NR) const { const V4 vscale = hn::Set(d4, scale); HWY_ALIGN static constexpr float kZero[4] = {}; const V4 vadd = hn::Load(d4, add ? 
add : kZero); - MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows); - MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows); - MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows); - MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows); + MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_MC_NR); + MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_MC_NR); + MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_MC_NR); + MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_MC_NR); } private: @@ -202,10 +202,10 @@ class MMStoreHorizontalSumsIntoC { class Tag, class CView> static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale, VF4 vadd, Tag, const size_t imc, - CView C_view) { + CView C_MC_NR) { if constexpr (kRow < kRowsAC) { - using TC = hwy::RemoveCvRef; - TC* HWY_RESTRICT pos = C_view.Row(imc + kRow); + using TC = hwy::RemoveCvRef; + TC* HWY_RESTRICT pos = C_MC_NR.Row(imc + kRow); const hn::Rebind dc4; if constexpr (hwy::IsSame()) { vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value @@ -268,9 +268,9 @@ class MMDecompress { } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. We also only - // have a single MMStorage. + // have a single MMEntireA. HWY_ASSERT(options.cluster_idx == 0); - const StridedViewBF A_view = env.storage.A(A.Extents()); + const StridedViewBF A_view = env.A_BF.A(A.Extents()); AutotuneDecompressA(A, A_view, autotune, env, options); return A_view; } @@ -387,111 +387,52 @@ class MMDecompress { // Stateless, wraps member functions. Contains the innermost 2-4 loops. class MMKernel { - // Compute size of per-worker storage for `kNR` row ranges of B. Stack - // allocation avoids passing a worker index. 
- static constexpr size_t B_stride_max = - kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); - public: - // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. - // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is - // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start - // at row/col 0. `CView` is either `RowPtrs` or `StridedView`. - // Called by B3A2C0 and by callers that hoist `A_view`. - template - static HWY_INLINE void A2C0(const StridedViewBF A_view, - const StridedViewBF B_view, size_t mr, - const IndexRange& range_mc, size_t kc, - const float scale, const float* HWY_RESTRICT add, - Tag tag, CView C_view) { - HWY_DASSERT(1 <= mr && mr <= kMaxMR); - - const size_t mc = range_mc.Num(); - size_t imc = 0; - - // M == 1, or x86 with 8 SIMD registers: - if (HWY_UNLIKELY(mr == 1)) { - for (; imc < mc; ++imc) { - LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - return; - } - - // AVX2 (16 registers) - if (HWY_UNLIKELY(mr == 2)) { - if (HWY_LIKELY(mc >= 2)) { - for (; imc <= mc - 2; imc += 2) { - LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - } - if (HWY_UNLIKELY(imc != mc)) { - LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - return; - } - - HWY_DASSERT(mr == 4); - if (HWY_LIKELY(mc >= 4)) { - for (; imc <= mc - 4; imc += 4) { - LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - } - const size_t remainder_mc = mc - imc; - HWY_DASSERT(remainder_mc < 4); - if (HWY_UNLIKELY(remainder_mc & 2)) { - LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view); - imc += 2; - } - if (HWY_UNLIKELY(remainder_mc & 1)) { - LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); - imc += 1; - } - HWY_DASSERT(imc == mc); - } - - static constexpr size_t B_storage_max = kNR * B_stride_max; - // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads - // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. 
Called by - // `ForeachKC` and when there is only a single KC task. - template + // `mc x kc` of A, `nc x kc` of B, and updates the `mc x nc` `C_MC_NC`. + // `CView` is either `RowPtrs` or `StridedView`. + template static void B3A2C0(const StridedViewBF A, const MatPtrT& B, const IndexRange& range_mc, const IndexRange& range_kc, const IndexRange& range_nc, const MMArgs& args, - Tag out_tag, RowPtrs C) { - HWY_ALIGN BF16 B_storage[B_storage_max]; - + Tag out_tag, CView C_MC_NC) { const size_t kc = range_kc.Num(); const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc); + // Upper bound on per-worker storage for `kNR` row ranges of B. Stack + // allocation avoids passing a worker index. + constexpr size_t B_stride_max = + kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); + HWY_ALIGN BF16 B_storage[kNR * B_stride_max]; const size_t B_stride = Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes); const StridedViewBF B_storage_view(B_storage, kc, B_stride); - for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); - row_b += kNR) { + const float scale = args.scale_A * B.Scale(); + for (size_t inc = 0; inc < range_nc.Num(); inc += kNR) { + // For `add` and `B`, which are global, unlike `C_MC_NC`. + const size_t row_b = range_nc.begin() + inc; const StridedViewBF B_view = MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view); - const RowPtrs C_view = C.View(range_mc.begin(), row_b); + const CView C_MC_NR = C_MC_NC.View(0, inc, kNR); const float* HWY_RESTRICT add = args.add ? 
args.add + row_b : nullptr; - A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag, - C_view); + A2C0(A_view, B_view, args.mr, range_mc, kc, scale, add, out_tag, C_MC_NR); } } - template + template static void ForeachKC(const StridedViewBF A, const MatPtrT& B, const IndexRange& range_mc, const IndexRangePartition& ranges_kc, const IndexRange& range_nc, const MMArgs& args, - RowPtrs C) { + CView C_MC_NC) { // Peel off the first iteration of the kc loop: avoid zero-initializing `C` // by writing directly into it, and later accumulating into it. ranges_kc.VisitFirst([&](const IndexRange& range_kc) { - B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C); + B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C_MC_NC); }); ranges_kc.VisitRemaining([&](const IndexRange& range_kc) { - B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C); + B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C_MC_NC); }); } @@ -585,15 +526,15 @@ class MMKernel { // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). - // Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0. - // `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also + // Updates a `kRowsAC x kNR` tile in `C_MC_NR` starting at row `imc`, column + // 0. `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also // relative to the C column. 
template static HWY_INLINE void LoopKC(const StridedViewBF A_view, const StridedViewBF B_view, size_t imc, size_t kc, const float scale, const float* HWY_RESTRICT add, Tag tag, - CView C_view) { + CView C_MC_NR) { const hn::ScalableTag dbf; using VBF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); @@ -777,7 +718,62 @@ class MMKernel { hn::Vec sum0, sum1, sum2, sum3; horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3); - horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view); + horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_MC_NR); + } + + // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. + // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is + // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start + // at row/col 0. + template + static HWY_INLINE void A2C0(const StridedViewBF A_view, + const StridedViewBF B_view, size_t mr, + const IndexRange& range_mc, size_t kc, + const float scale, const float* HWY_RESTRICT add, + Tag tag, CView C_MC_NR) { + HWY_DASSERT(1 <= mr && mr <= kMaxMR); + + const size_t mc = range_mc.Num(); + size_t imc = 0; + + // M == 1, or x86 with 8 SIMD registers: + if (HWY_UNLIKELY(mr == 1)) { + for (; imc < mc; ++imc) { + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + return; + } + + // AVX2 (16 registers) + if (HWY_UNLIKELY(mr == 2)) { + if (HWY_LIKELY(mc >= 2)) { + for (; imc <= mc - 2; imc += 2) { + LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + } + if (HWY_UNLIKELY(imc != mc)) { + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + return; + } + + HWY_DASSERT(mr == 4); + if (HWY_LIKELY(mc >= 4)) { + for (; imc <= mc - 4; imc += 4) { + LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + } + const size_t remainder_mc = mc - imc; + HWY_DASSERT(remainder_mc < 4); + if 
(HWY_UNLIKELY(remainder_mc & 2)) { + LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + imc += 2; + } + if (HWY_UNLIKELY(remainder_mc & 1)) { + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + imc += 1; + } + HWY_DASSERT(imc == mc); } }; @@ -813,10 +809,10 @@ class MMImpl { } public: - static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, + static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, size_t num_B, size_t vector_bytes, MatMulEnv::PerCluster& per_cluster) { - const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); + const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N, num_B); intptr_t index = IndexOfKey(key, per_cluster.keys); // First time we see this shape/key. if (HWY_UNLIKELY(index < 0)) { @@ -831,17 +827,19 @@ class MMImpl { } static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N, - double t0, MMAutoTune& tuner, + size_t num_B, double t0, + MMAutoTune& tuner, const MMConfig& cfg) { const uint64_t t1 = env.have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / hwy::platform::InvariantTicksPerSecond(); - const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA + const double flops = 2 * M * K * N * num_B / min_elapsed; // * 2 for FMA if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { - fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9, - min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(), - StringFromOrder(cfg.Order()), cfg.InnerTasks()); + fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", + M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(), + cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()), + cfg.InnerTasks()); } if (HWY_UNLIKELY(env.print_best && tuner.Best())) { const auto ratio = [&tuner](uint64_t ticks) -> double { @@ -849,12 +847,13 @@ class MMImpl { static_cast(tuner.BestTicks()); }; const MMConfig& best = *tuner.Best(); - fprintf(stderr, - "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", - M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), - best.KC(), best.NC(), StringFromOrder(best.Order()), - best.InnerTasks(), ratio(tuner.WorstMinTicks()), - ratio(tuner.FirstConfigTicks())); + fprintf( + stderr, + "\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", + M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), + best.KC(), best.NC(), StringFromOrder(best.Order()), + best.InnerTasks(), ratio(tuner.WorstMinTicks()), + ratio(tuner.FirstConfigTicks())); } } @@ -874,10 +873,11 @@ class MMImpl { class MMLoops { public: // Called from `MatMul` from two places: either with the next autotune config, - // or with the best config. + // or with the best config. `B2` is null unless called from `TwoMatMul`. 
template static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch"); PROFILER_ZONE3(args.env.ctx.profiler, args.env.ctx.Worker(args.options.cluster_idx), zone); @@ -885,7 +885,7 @@ class MMLoops { DispatchParallelism( args.options.parallelism, [&](const auto& parallel) HWY_ATTR { DispatchOrder(args.order, [&](const auto& order) HWY_ATTR { - Loop(order, parallel, A, B, C, args); + Loop(order, parallel, A, B, B2, C, args); }); }); } @@ -901,18 +901,14 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); - const IndexRange& range_M = args.ranges_mc.Range(0); - const IndexRange& range_K = args.ranges_kc.Range(0); - const size_t K = range_K.Num(); - const StridedViewBF A_view = A.View(range_M.begin(), 0, K); - const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes); + const IndexRange& range_mc = args.ranges_mc.Range(0); + const IndexRange& range_kc = args.ranges_kc.Range(0); - // Similar to `B3A2C0`, but here we hoisted `A_view`. 
parallel.ForN( args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks, args.options.cluster_idx, @@ -920,26 +916,19 @@ class MMLoops { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max]; // TLS - const StridedViewBF B_storage_view(B_storage, K, B_stride); + MMKernel::B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), + C.View(0, range_nc.begin(), range_nc.Num())); - for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); - row_b += kNR) { - const StridedViewBF B_view = - MMDecompress::DecompressB(B, row_b, range_K, B_storage_view); - const RowPtrs C_view = C.View(range_M.begin(), row_b); - const float* HWY_RESTRICT add = - args.add ? args.add + row_b : nullptr; + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); - MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add, - MMSetC(), C_view); + if (B2 != nullptr) { + MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args, + MMSetC(), C2); } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_M, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker); } }); } @@ -948,7 +937,8 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); const IndexRange& range_mc = args.ranges_mc.Range(0); @@ -959,14 +949,21 @@ class MMLoops { [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, - range_nc, args, C); + MMKernel::ForeachKC( + A, B, 
range_mc, args.ranges_kc, range_nc, args, + C.View(0, range_nc.begin(), range_nc.Num())); + + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); + + if (B2 != nullptr) { + MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, + range_nc, args, C2); + } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_mc, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, + worker); } }); } @@ -976,10 +973,11 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); - const IndexRange& range_K = args.ranges_kc.Range(0); + const IndexRange& range_kc = args.ranges_kc.Range(0); parallel.ForRangesMC_NC( args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, @@ -987,14 +985,19 @@ class MMLoops { size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(), - C); + MMKernel::B3A2C0( + A, B, range_mc, range_kc, range_nc, args, MMSetC(), + C.View(range_mc.begin(), range_nc.begin(), range_nc.Num())); + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); + + if (B2 != nullptr) { + MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args, + MMSetC(), C2); + } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_mc, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker); } }); } @@ -1004,7 +1007,8 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel, const 
StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K"); parallel.ForRangesMC_NC( @@ -1013,14 +1017,20 @@ class MMLoops { size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args, - C); + MMKernel::ForeachKC( + A, B, range_mc, args.ranges_kc, range_nc, args, + C.View(range_mc.begin(), range_nc.begin(), range_nc.Num())); + + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); + + if (B2 != nullptr) { + MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, range_nc, + args, C2); + } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_mc, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker); } }); } @@ -1060,20 +1070,23 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); + const size_t num_B = 1; const CacheInfo& cache = env.ctx.cache_info; - MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(), - env.per_cluster[cluster_idx]); + MMPerKey& per_key = MMImpl::FindOrAddPerKey( + M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); // (Also auto-tunes, hence outside the timed section to prevent interference.) 
const StridedViewBF A_view = MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options); + MatPtrT* B2 = nullptr; // required for type matching + MMAutoTune& tuner = per_key.autotune; if (HWY_LIKELY(tuner.Best())) { - const MMArgs args(env, M, K, N, static_cast(A.Scale()) * B.Scale(), - add, options, tuner, *tuner.Best()); - MMLoops::Dispatch(A_view, B, C_rows, args); + const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner, + *tuner.Best()); + MMLoops::Dispatch(A_view, B, B2, C_rows, args); return &per_key; } @@ -1082,20 +1095,83 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, // Ensure matrix dimensions match each other (off the hot path). HWY_ASSERT(K == B.Cols()); HWY_ASSERT(M <= kMaxBatchSize); - HWY_ASSERT(K <= MMStorage::kMaxK); + HWY_ASSERT(K <= MMEntireA::kMaxK); HWY_ASSERT(N % kNR == 0); MMImpl::EnsureAligned(A, cache.VectorBytes()); tuner.SetCandidates( - MMCandidates(cache, M, K, N, sizeof(TC), env.print_config)); + MMCandidates(cache, M, K, N, num_B, sizeof(TC), env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); - const MMArgs args(env, M, K, N, static_cast(A.Scale()) * B.Scale(), - add, options, tuner, cfg); + const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner, cfg); const uint64_t t0 = hwy::timer::Start(); - MMLoops::Dispatch(A_view, B, C_rows, args); - MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg); + MMLoops::Dispatch(A_view, B, B2, C_rows, args); + MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg); + + return &per_key; +} + +// Performs A*B1 and A*B2 in parallel. This is useful for gated FFNs. +// Differences vs MatMul: The second result matrix is not materialized, it is +// only passed to the `options.func` callback. There is no `add` argument +// because it is not required for this use case. There is no default `options` +// argument because `options.func` must be set by the caller. 
+template +HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT& A, const MatPtrT& B1, + const MatPtrT& B2, MatMulEnv& env, + MatPtrT& C, MMOptions options) { + static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul"); + const size_t cluster_idx = options.cluster_idx; + HWY_DASSERT(cluster_idx < env.row_ptrs.size()); + PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone); + + HWY_DASSERT(options.func != nullptr); // no other way to get access to C2. + + RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); + + const size_t M = A.Rows(); + const size_t K = A.Cols(); + const size_t N = B1.Rows(); + const size_t num_B = 2; + + const CacheInfo& cache = env.ctx.cache_info; + MMPerKey& per_key = MMImpl::FindOrAddPerKey( + M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); + + // (Also auto-tunes, hence outside the timed section to prevent interference.) + const StridedViewBF A_view(A, 0, 0, A.Cols()); + + MMAutoTune& tuner = per_key.autotune; + if (HWY_LIKELY(tuner.Best())) { + // Only A scale - B1/B2 may differ, and are passed separately. + const MMArgs args(env, M, K, N, A.Scale(), + /*add=*/nullptr, options, tuner, *tuner.Best()); + MMLoops::Dispatch(A_view, B1, &B2, C_rows, args); + return &per_key; + } + + // Autotuning, first call: enumerate all feasible configs. + if (HWY_UNLIKELY(!tuner.HasCandidates())) { + // Ensure matrix dimensions match each other (off the hot path). + HWY_ASSERT(K == B1.Cols()); + HWY_ASSERT(K == B2.Cols()); + HWY_ASSERT(M <= kMaxBatchSize); + HWY_ASSERT(K <= MMEntireA::kMaxK); + HWY_ASSERT(N % kNR == 0); + MMImpl::EnsureAligned(A, cache.VectorBytes()); + tuner.SetCandidates( + MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config)); + } + + const MMConfig& cfg = tuner.NextConfig(); + // Only A scale - B1/B2 may differ, and are passed separately. 
+ const MMArgs args(env, M, K, N, A.Scale(), /*add=*/nullptr, options, tuner, + cfg); + + const uint64_t t0 = hwy::timer::Start(); + MMLoops::Dispatch(A_view, B1, &B2, C_rows, args); + MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg); return &per_key; } diff --git a/ops/matmul.cc b/ops/matmul.cc index 66ce0df..6ef1412 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -63,11 +63,12 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, class GenerateCandidates { public: GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N, - size_t sizeof_TC, bool print_config) + size_t num_B, size_t sizeof_TC, bool print_config) : cache_(cache), M_(M), K_(K), N_(N), + num_B_(num_B), sizeof_TC_(sizeof_TC), // These influence kc/nc, but are also stored in `MMConfig` for // `RangesOf*`. Must be a vector multiple. The previous/next cache line @@ -150,7 +151,7 @@ class GenerateCandidates { } } - // The number of A and B columns to read between updating `partial`. + // The number of A and B columns to read between updating `C`. SizeVec KC(size_t mr, MMOrder order) const { // `LoopKC` handles up to `mr` rows of A. const size_t rows_a = HWY_MIN(M_, mr); @@ -164,7 +165,7 @@ class GenerateCandidates { // TB=NUQ due to less amortization of the table loads. Due to the low L1 // latency, the packing is still effectively fused into `LoopKC`. It may // be better to round up and accept a few L2 accesses in exchange for - // fewer loops over K, and thus fewer writes to `partial`. Hence we do not + // fewer loops over K, and thus fewer writes to `C`. Hence we do not // subtract the output and buf, and allow using more than the actual L1 // size. This results in an overestimate, and the loop below will propose // the next few smaller values for the autotuner to evaluate. @@ -179,7 +180,7 @@ class GenerateCandidates { // Avoid proposing kc > K. 
if (K_ > kc_multiple_) { - // Generally it is best to use the full `kc` (fewer writes to `partial`), + // Generally it is best to use the full `kc` (fewer writes to `C`), // but a bit less can be better if it evenly divides `K`, or enables an // `mc` that evenly divides `M`. Try several smaller values. @@ -196,7 +197,7 @@ class GenerateCandidates { } if (print_config_ && all_kc.size() > 1) { - fprintf(stderr, "KC: "); + fprintf(stderr, "num_B %zu: KC: ", num_B_); for (size_t kc : all_kc) { fprintf(stderr, "%zu ", kc); } @@ -214,18 +215,18 @@ class GenerateCandidates { // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the // packed B. We want `mc * kc` elements of A to fit in L2, alongside - // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of - // partial. + // `bytes_b` plus `mc` cache lines because resident-A updates `mc` C rows. const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes(); size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc); - mc_max = HWY_MIN(mc_max, kMaxBatchSize); + mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC)); HWY_DASSERT(mc_max != 0); mc_max = HWY_MIN(mc_max, M_); mc_max = hwy::RoundDownTo(mc_max, mr); SizeVec all_mc(1, mc_max); - // Larger MC is better for non-blocks, otherwise we want more small options. - const size_t reps = !IsBlock(order) ? 2 : 3; + // Larger MC is better for non-blocks, otherwise we want more small options, + // especially for two B. + const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_); size_t prev = mc_max; for (size_t rep = 0; rep < reps; ++rep) { @@ -240,7 +241,7 @@ class GenerateCandidates { } if (print_config_ && all_mc.size() > 1) { - fprintf(stderr, "MC: "); + fprintf(stderr, "num_B %zu: MC: ", num_B_); for (size_t mc : all_mc) { fprintf(stderr, "%zu ", mc); } @@ -252,14 +253,15 @@ class GenerateCandidates { // The number of (possibly L3 resident) B rows per `NT_MT` task. 
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { - size_t nc_max = N_; + size_t nc_max = kMaxNC; // Only if there will be reuse of B: choose the largest `nc_max` (C cols) - // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. - // Otherwise, leave it unbounded. + // such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise, + // leave it unbounded. if (M_ > mr) { const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); - nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), N_); + nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC); } + nc_max = HWY_MIN(nc_max, N_); HWY_DASSERT(nc_max != 0); nc_max = RoundDownWithFloor(nc_max, nc_multiple_); @@ -278,7 +280,7 @@ class GenerateCandidates { if (N_ > nc_multiple_) { // Large L3, but its behavior and characteristics varies across platforms, // hence autotune a wider range of nc than the other dimensions. - size_t reps = 10; + size_t reps = 9 + num_B_; // For small M, we can afford larger NC, hence allow fewer small options. if (M_ <= 2 * mr) reps -= 1; @@ -301,7 +303,7 @@ class GenerateCandidates { } if (print_config_ && all_nc.size() > 1) { - fprintf(stderr, "NC: "); + fprintf(stderr, "num_B %zu: NC: ", num_B_); for (size_t nc : all_nc) { fprintf(stderr, "%zu ", nc); } @@ -329,6 +331,7 @@ class GenerateCandidates { const size_t M_; const size_t K_; const size_t N_; + const size_t num_B_; const size_t sizeof_TC_; const size_t kc_multiple_; @@ -341,12 +344,13 @@ class GenerateCandidates { // Facade to avoid exposing `GenerateCandidates` in the header. 
std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, - size_t N, size_t sizeof_TC, + size_t N, size_t num_B, size_t sizeof_TC, bool print_config) { - return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)(); + return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)(); } -MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) { +MatMulEnv::MatMulEnv(ThreadingContext& ctx) + : ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) { const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); per_cluster.resize(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { diff --git a/ops/matmul.h b/ops/matmul.h index 93e7b04..bedee3d 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -21,7 +21,6 @@ #include #include -#include #include // IWYU pragma: begin_exports @@ -54,7 +53,9 @@ HWY_INLINE_VAR constexpr size_t kNR = 4; // or less on ISAs with fewer registers, or for the last few rows of A. HWY_INLINE_VAR constexpr size_t kMaxMR = 4; -HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // TODO: shrink? +// For `MMTilesC`. +HWY_INLINE_VAR constexpr size_t kMaxMC = 512; +HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // Upper bound for per-worker B storage on the stack. Chosen such that one row // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. 
@@ -108,9 +109,9 @@ struct MMParallelWithinCluster { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t base = ctx.Worker(cluster_idx); - const IndexRangePartition worker_ranges = StaticPartition( + const IndexRangePartition ranges_n = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); - ParallelizeOneRange(worker_ranges, cluster, + ParallelizeOneRange(ranges_n, cluster, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, base + worker); }); @@ -169,20 +170,20 @@ struct MMParallelHierarchical { if (num_clusters == 1) { const size_t cluster_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); - const IndexRangePartition worker_ranges = StaticPartition( + const IndexRangePartition ranges_n = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); return ParallelizeOneRange( - worker_ranges, cluster, + ranges_n, cluster, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, worker); }); } // Assign each cluster a sub-range of `range_n` (typically hundreds). - const IndexRangePartition n_ranges = + const IndexRangePartition ranges_n = StaticPartition(range_n, num_clusters, n_multiple); ParallelizeOneRange( - n_ranges, all_clusters, + ranges_n, all_clusters, [&](const IndexRange& n_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t cluster_base = ctx.Worker(cluster_idx); @@ -274,32 +275,51 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC); // C is BF16/float. void BindC(ThreadingContext& ctx, MatPtr& C); -// For A. -class MMStorage { +// Space for converting A=F32 to BF16 before the matmul. This is faster than +// on-the-fly when native BF16 is available: it only happens once, not per B +// tile row, and the cache footprint is smaller. 
+class MMEntireA { public: // Compile-time bounds on matrix columns to enable pre-allocating storage // and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B. static constexpr size_t kMaxK = 36 * 1024; - MMStorage(const Allocator& allocator) + explicit MMEntireA(const Allocator& allocator) // 288 MiB. Must be padded, see `DoDecompressA`. : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd) {} - // Returns matrix view. Converting A=F32 to BF16 up-front is faster than - // on-the-fly when native BF16 is available: it only happens once, not per B - // tile row, and the cache footprint is smaller. StridedViewBF A(const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxBatchSize); - HWY_DASSERT(extents.cols <= kMaxK); - return StridedViewBF(const_cast(A_.Row(0)), extents.cols, - A_.Stride()); + return StridedViewBF(A_, 0, 0, extents.cols); } private: MatStorageT A_; }; +// One tile of C per *worker* (required for `kNT_MT*`). +class MMTilesC { + public: + explicit MMTilesC(const ThreadingContext& ctx) { + const size_t max_workers = ctx.pools.MaxWorkers(); + C_.reserve(max_workers); + for (size_t worker = 0; worker < max_workers; ++worker) { + C_.push_back(MatStorageT("Ctile", Extents2D(kMaxBatchSize, kMaxNC), + ctx.allocator, MatPadding::kOdd)); + } + } + + StridedViewBF C(const Extents2D& extents, size_t worker) const { + HWY_DASSERT(extents.rows <= kMaxBatchSize); + HWY_DASSERT(worker < C_.size()); + return StridedViewBF(C_[worker], 0, 0, extents.cols); + } + + private: + std::vector> C_; +}; + //------------------------------------------------------------------------------ // Autotuning @@ -471,7 +491,7 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, - size_t N, size_t sizeof_TC, + size_t N, size_t num_B, size_t sizeof_TC, bool print_config); // State machine for choosing the best `TConfig`, which is `MMConfig` for the @@ 
-595,12 +615,14 @@ class MMKeys { static constexpr Key kPadding = 0; // Compresses the dimensions into a single Key for faster comparison. - static Key KeyFromDims(size_t M, size_t K, size_t N) { + static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) { HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller - HWY_DASSERT(K < (Key{1} << 24)); - HWY_DASSERT(N < (Key{1} << 24)); + HWY_DASSERT(K < (Key{1} << 20)); + HWY_DASSERT(N < (Key{1} << 20)); + HWY_DASSERT(num_B == 1 || num_B == 2); const Key key = static_cast(BucketM(M)) | (static_cast(K) << 16) | - (static_cast(N) << 40); + (static_cast(N) << 40) | + (static_cast(num_B) << 60); HWY_DASSERT(key != kPadding); return key; } @@ -643,10 +665,6 @@ class MMKeys { // Per-MatMul-shape state. struct MMPerKey { - // Only profile if enabled and the main autotuner finished. `autotune_par_a` - // might not be active if inputs are all BF16. - bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); } - MMAutoTune autotune; MMAutoTune autotune_par_a; }; @@ -666,12 +684,15 @@ struct MatMulEnv { // Whether to print the best config immediately after autotuning finished. bool print_best = false; - MMStorage storage; + MMEntireA A_BF; + MMTilesC C_tiles; struct PerCluster { MMKeys keys; std::vector per_key; - HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT]; // prevent false sharing + // Prevents false sharing. + HWY_MAYBE_UNUSED uint8_t + padding[HWY_ALIGNMENT - sizeof(MMKeys) - sizeof(per_key)]; }; std::vector per_cluster; @@ -687,31 +708,57 @@ struct MatMulEnv { std::vector> row_ptrs; }; -// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols) -// that this thread has just filled, a view into a second tile (only for the -// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`). -using MMFused = std::function; +// Called via `CallClosure`, which consumes the first (opaque) argument. 
User +// functions are called with the entire C matrix, the sub-ranges of M (rows) +// and N (cols) that this thread has just filled, a view into a second tile +// (only for `TwoMatmul`), and the worker thread index (see `ParallelFor`). +typedef void (*MMFunc)(const void* opaque, RowPtrsBF, IndexRange, IndexRange, + StridedViewBF, size_t); + +class MMOptions { + // Same technique as in `hwy::ThreadPool` and C++23 `std::function_ref`: + // type-erasure without allocation. + template + static void CallClosure(const void* opaque, RowPtrsBF C1, IndexRange range_r, + IndexRange range_c, StridedViewBF C2, size_t worker) { + (*reinterpret_cast(opaque))(C1, range_r, range_c, C2, + worker); + } + + public: + // `closure` must remain alive until the end of (Two)MatMul. + template + void SetFunc(const Closure& closure) { + func = static_cast(&CallClosure); + opaque = &closure; + } + + void MaybeCallFunc(RowPtrsBF C1, IndexRange range_r, IndexRange range_c, + StridedViewBF C2, size_t worker) const { + if (func != nullptr) { + func(opaque, C1, range_r, range_c, C2, worker); + } + } + + MMFunc func = nullptr; // called if non-null and `TC` is BF16. + const void* opaque = nullptr; -struct MMOptions { uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; - - MMFused fused; // called if non-null and `TC` is BF16. }; // Arguments to MatMul() that are independent of the A/B/C types. Reduces // register pressure compared to individual values/references. Also used for // passing through `DispatchOrder`. 
struct MMArgs { - MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale, + MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, float scale_A, const float* HWY_RESTRICT add, MMOptions options, const MMAutoTune& autotune, const MMConfig& config) : env(env), line_bytes(env.ctx.cache_info.LineBytes()), range_n(0, N), - scale(scale), + scale_A(scale_A), add(add), options(options), @@ -728,7 +775,8 @@ struct MMArgs { // MatMul arguments: const IndexRange range_n; // entire N - const double scale; + // There can be two B, so do not yet multiply together the A and B scales. + const float scale_A; const float* HWY_RESTRICT add; const MMOptions options; diff --git a/ops/matmul_static-inl.h b/ops/matmul_static-inl.h index ba09e0c..abb6f43 100644 --- a/ops/matmul_static-inl.h +++ b/ops/matmul_static-inl.h @@ -53,6 +53,14 @@ namespace HWY_NAMESPACE { // included from matmul_static_*.cc. GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB) // NOLINT +HWY_MAYBE_UNUSED void TwoMatMulStatic(const MatPtrT& A, // NOLINT + const MatPtrT& B1, + const MatPtrT& B2, + MatMulEnv& env, MatPtrT& C, + MMOptions options) { + TwoMatMul(A, B1, B2, env, C, options); +} + } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); diff --git a/ops/matmul_static.h b/ops/matmul_static.h index 61dc505..6b93d92 100644 --- a/ops/matmul_static.h +++ b/ops/matmul_static.h @@ -37,13 +37,19 @@ const float* HWY_RESTRICT add, MatMulEnv& env, \ MatPtrT& C, MMOptions options); +#define GEMMA_MATMUL_FOR_B(TB) \ + GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, TB) \ + void TwoMatMulStatic(const MatPtrT& A, const MatPtrT& B1, \ + const MatPtrT& B2, MatMulEnv& env, \ + MatPtrT& C, MMOptions options); + // Passed to HWY_VISIT_TARGETS; declares all overloads for all targets. 
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \ namespace NAMESPACE { \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16) \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float) \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream) \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) \ + GEMMA_MATMUL_FOR_B(BF16) \ + GEMMA_MATMUL_FOR_B(float) \ + GEMMA_MATMUL_FOR_B(NuqStream) \ + GEMMA_MATMUL_FOR_B(SfpStream) \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 373f8aa..2f0fde2 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -29,6 +29,8 @@ #include #include +#include + #include "ops/matmul.h" #include "util/basics.h" #include "util/mat.h" @@ -246,7 +248,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatStorageT C_slow("C_slow", C_extents, env.ctx.allocator, MatPadding::kOdd); MatStorageT C("C", C_extents, env.ctx.allocator, MatPadding::kOdd); + MatStorageT C2("C", C_extents, env.ctx.allocator, MatPadding::kOdd); C.AllocateAndAttachRowPtrs(env.row_ptrs); + C2.AllocateAndAttachRowPtrs(env.row_ptrs); MatStorageT add_storage = add ? GenerateMat(Extents2D(1, cols_bc), env.ctx.allocator, @@ -262,7 +266,48 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, for (size_t rep = 0; rep < 16; ++rep) { MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options); AssertClose(A, BT, C_slow, C, env, line); - if (per_key->autotune.Best()) break; + // Check before TwoMatMulStatic(), which can invalidate per_key. + const bool autotune_done = !!per_key->autotune.Best(); + + // Ensure the tiled view returns the same result as C. + if constexpr (IsBF16() && IsBF16()) { + // The total view area should match the entire C matrix. 
+ std::atomic total_view_area = 0; + + const auto fused = [&](RowPtrsBF C2_rows, IndexRange range_r, + IndexRange range_c, StridedViewBF C2_view, + size_t worker) { + total_view_area.fetch_add(range_r.Num() * range_c.Num()); + HWY_ASSERT(range_c.Num() <= C2_view.Cols()); + HWY_ASSERT(worker < env.ctx.pools.MaxWorkers()); + for (size_t ir = 0; ir < range_r.Num(); ++ir) { + const size_t r = range_r.begin() + ir; + for (size_t ic = 0; ic < range_c.Num(); ++ic) { + const size_t c = range_c.begin() + ic; + const float expected = + hwy::ConvertScalarTo(C2_rows.Row(r)[c]); + const float actual = + hwy::ConvertScalarTo(C2_view.Row(ir)[ic]); + const float L1 = hwy::ScalarAbs(actual - expected); + if (L1 > 1E-6f) { + HWY_ABORT("%zu: ir %zu ic %zu L1 %f expected %f actual %f.", + worker, ir, ic, L1, expected, actual); + } + } + } + }; + options.SetFunc(fused); + TwoMatMulStatic(A, BT, BT, env, C2, options); + HWY_ASSERT_EQ(C.Extents().Area(), total_view_area.load()); + options.func = nullptr; // reset for next call + + // TwoMatMulStatic() does not support adding a bias vector. + if (!add) { + AssertClose(A, BT, C, C2, env, line); + } + } + + if (autotune_done) break; } } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index cfd85ae..1593aa4 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -69,6 +69,14 @@ MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, }); } +static inline void CallTwoMatMul(const MatPtrT& A, const MatPtr& B1, + const MatPtr& B2, MatMulEnv& env, + MatPtrT& C, const MMOptions& options) { + return CallUpcastedSame(&B1, &B2, [&](const auto* B1_t, const auto* B2_t) { + return TwoMatMulStatic(A, *B1_t, *B2_t, env, C, options); + }); +} + HWY_INLINE double PackTokenAndProb(int32_t token, float prob) { // casting prob from float to double just makes some changes to the // exponent bias and pads zeros in the mantissa. 
diff --git a/util/mat.h b/util/mat.h index c2427e5..6f9a243 100644 --- a/util/mat.h +++ b/util/mat.h @@ -40,10 +40,11 @@ class RowPtrs { public: RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {} - RowPtrs View(size_t r, size_t c) { + // Extra argument is for compatibility with `StridedView`. + RowPtrs View(size_t r, size_t c, size_t /*cols*/) { RowPtrs view(row_ptrs_); - view.r0_ = static_cast(r); - view.c0_ = static_cast(c); + view.r0_ = static_cast(r0_ + r); + view.c0_ = static_cast(c0_ + c); return view; } @@ -531,7 +532,11 @@ class StridedView { : row0_(row0), cols_(static_cast(cols)), stride_(static_cast(stride)) { - HWY_DASSERT(stride >= cols); + if constexpr (HWY_IS_DEBUG_BUILD) { + if (stride < cols) { + HWY_ABORT("stride %zu < cols %zu", stride, cols); + } + } } // Returns 2D subrange whose top-left is `r, c` and width is `cols`. From b603425bf30b93dd8a88e4865b6353c471b8af43 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 16 Sep 2025 08:01:21 -0700 Subject: [PATCH 45/65] Fix batch inference: dangling reference Also add more detailed asserts/error messages. 
PiperOrigin-RevId: 807695421 --- evals/benchmark_helper.cc | 26 +++++++++++++++++--------- evals/benchmark_helper.h | 2 +- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index bd53845..abdef50 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -105,8 +105,14 @@ std::vector GemmaEnv::BatchQueryModel( const size_t pos, const int token, float) { HWY_ASSERT(query_index < num_queries); + if (token >= gemma_.Config().vocab_size) { + HWY_ABORT("Token %d >= vocab size %d", token, gemma_.Config().vocab_size); + } std::string token_text; - HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector{token}, &token_text)); + if (!gemma_.Tokenizer().Decode(std::vector{token}, &token_text)) { + HWY_ABORT("Failed to decode token %d, tokenizer bytes %s\n", token, + gemma_.Tokenizer().Serialize().substr(0, 10).c_str()); + } res[query_index].response.append(token_text); HWY_ASSERT(pos == res[query_index].tokens_generated); res[query_index].tokens_generated += 1; @@ -143,17 +149,19 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) { } std::vector GemmaEnv::BatchQueryModel( - const std::vector& inputs) { - std::vector prompt_vector; - prompt_vector.reserve(inputs.size()); + const std::vector& prompt_strings) { + std::vector views; + views.reserve(prompt_strings.size()); - for (auto& input : inputs) { - std::vector prompt = WrapAndTokenize(input); - prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size())); + std::vector> storage; + storage.reserve(prompt_strings.size()); + for (auto& input : prompt_strings) { + storage.push_back(WrapAndTokenize(input)); + views.push_back(PromptTokens(storage.back().data(), storage.back().size())); } - QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size()); - return BatchQueryModel(prompt_span); + QueriesPromptTokens span_of_views(views.data(), views.size()); + return BatchQueryModel(span_of_views); } float 
GemmaEnv::CrossEntropy(const std::string& input) { diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 81ccde6..75cf0d2 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -89,7 +89,7 @@ class GemmaEnv { // Adds turn structure to input, tokenizes and calls the above overload. QueryResult QueryModel(const std::string& input); std::vector BatchQueryModel( - const std::vector& inputs); + const std::vector& prompt_strings); // Runs inference on the given input and calls the callback for each token. void QueryModel(const std::vector& tokens, From 501fdf000ea77544a0ee9b4db27767fcc10ebc02 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 19 Sep 2025 09:02:44 -0700 Subject: [PATCH 46/65] Remove no longer used MatVec PiperOrigin-RevId: 809059409 --- BUILD.bazel | 23 --- CMakeLists.txt | 2 - gemma/flash_attention_test.cc | 1 - ops/gemma_matvec_test.cc | 192 --------------------- ops/matvec-inl.h | 302 ---------------------------------- 5 files changed, 520 deletions(-) delete mode 100644 ops/gemma_matvec_test.cc delete mode 100644 ops/matvec-inl.h diff --git a/BUILD.bazel b/BUILD.bazel index 02c54bd..d5bac73 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -366,7 +366,6 @@ cc_library( "ops/dot-inl.h", "ops/sum-inl.h", "ops/fp_arith-inl.h", - "ops/matvec-inl.h", "ops/ops-inl.h", ], deps = [ @@ -381,7 +380,6 @@ cc_library( "@highway//:bit_set", "@highway//:hwy", "@highway//:math", - "@highway//:matvec", "@highway//:profiler", "@highway//:thread_pool", "@highway//hwy/contrib/sort:vqsort", @@ -442,27 +440,6 @@ cc_test( ], ) -cc_test( - name = "gemma_matvec_test", - size = "small", - timeout = "long", - srcs = ["ops/gemma_matvec_test.cc"], - linkstatic = True, - local_defines = ["HWY_IS_TEST"], - # for test_suite. 
- tags = ["ops_tests"], - deps = [ - ":mat", - ":ops", - ":threading_context", - "@googletest//:gtest_main", # buildcleaner: keep - "//compression:compress", - "@highway//:hwy", - "@highway//:hwy_test_util", - "@highway//:thread_pool", - ], -) - cc_test( name = "matmul_test", size = "small", diff --git a/CMakeLists.txt b/CMakeLists.txt index cb2911f..46242f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,7 +112,6 @@ set(SOURCES ops/matmul-inl.h ops/matmul.cc ops/matmul.h - ops/matvec-inl.h ops/ops-inl.h ops/ops.h ops/sum-inl.h @@ -224,7 +223,6 @@ set(GEMMA_TEST_FILES io/fields_test.cc ops/bench_matmul.cc ops/dot_test.cc - ops/gemma_matvec_test.cc ops/matmul_test.cc ops/ops_test.cc paligemma/image_test.cc diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index efb210e..d4d6380 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -51,7 +51,6 @@ #include "gemma/attention.h" #include "gemma/configs.h" #include "gemma/flash_attention.h" -#include "ops/matvec-inl.h" #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); diff --git a/ops/gemma_matvec_test.cc b/ops/gemma_matvec_test.cc deleted file mode 100644 index e55539d..0000000 --- a/ops/gemma_matvec_test.cc +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2023 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "compression/types.h" -#ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS -#endif // HWY_DISABLED_TARGETS - -#include -#include - -#include // std::max -#include // std::abs -#include - -#include "util/mat.h" -#include "util/threading_context.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "ops/gemma_matvec_test.cc" // NOLINT -// clang-format on -#include "hwy/foreach_target.h" // IWYU pragma: keep -#include "hwy/highway.h" -// After highway.h -#include "compression/compress-inl.h" -#include "ops/matvec-inl.h" -#include "hwy/tests/test_util-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -using FloatPtr = hwy::AlignedFreeUniquePtr; - -FloatPtr SimpleMatVecAdd(const MatStorageT& mat, const FloatPtr& vec, - const FloatPtr& add) { - const size_t num = mat.Rows() * mat.Cols(); - FloatPtr raw_mat = hwy::AllocateAligned(num); - FloatPtr out = hwy::AllocateAligned(mat.Rows()); - HWY_ASSERT(raw_mat && out); - const hn::ScalableTag df; - DecompressAndZeroPad(df, mat.Span(), 0, raw_mat.get(), num); - for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) { - out[idx_row] = 0.0f; - for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) { - out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col]; - } - out[idx_row] *= mat.Scale(); - out[idx_row] += add[idx_row]; - } - return out; -} - -template -std::unique_ptr> GenerateMat(size_t offset, - const Allocator& allocator, - hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - const Extents2D extents(kOuter, kInner); - auto mat = std::make_unique>("TestMat", extents, allocator, - MatPadding::kPacked); - FloatPtr raw_mat = hwy::AllocateAligned(extents.Area()); - HWY_ASSERT(raw_mat); - const float scale = 1.0f / kInner; - pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) { - for 
(size_t j = 0; j < kInner; j++) { - raw_mat[i * kInner + j] = - static_cast((i * kInner + j + offset) * scale); - } - }); - - Compress(raw_mat.get(), extents.Area(), ws, mat->Span(), 0, pool); - mat->SetScale(1.9f); // Arbitrary value, different from 1. - return mat; -} - -template -FloatPtr GenerateVec(size_t offset) { - FloatPtr vec = hwy::AllocateAligned(length); - HWY_ASSERT(vec); - for (size_t idx = 0; idx < length; idx++) { - vec[idx] = static_cast(idx + offset); - } - return vec; -} - -template -void AssertClose(const FloatPtr& a, const FloatPtr& b) { - for (size_t idx = 0; idx < length; idx++) { - const float rel_abs_delta = std::abs(a[idx] - b[idx]) / - std::max(std::abs(a[idx]), std::abs(b[idx])); - EXPECT_LT(rel_abs_delta, 2e-6) - << "a[" << idx << "]=" << a[idx] << ", b[" << idx << "]=" << b[idx]; - } -} - -void TestMatVecAdd() { - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - hwy::ThreadPool& pool = ctx.pools.Pool(); - constexpr size_t kOuter = 128 * 3; - constexpr size_t kInner = 128 * 5; - auto mat = GenerateMat(0, ctx.allocator, pool); - FloatPtr vec = GenerateVec(0); - FloatPtr add = GenerateVec(0); - FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add); - FloatPtr actual_out = hwy::AllocateAligned(kOuter); - HWY_ASSERT(vec && add && expected_out && actual_out); - MatVecAdd(*mat, 0, kOuter, kInner, vec.get(), add.get(), actual_out.get(), - pool); - AssertClose(actual_out, expected_out); -} - -void TestTwoMatVecAdd() { - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - hwy::ThreadPool& pool = ctx.pools.Pool(); - constexpr size_t kOuter = 128 * 3; - constexpr size_t kInner = 128 * 5; - auto mat0 = GenerateMat(0, ctx.allocator, pool); - auto mat1 = GenerateMat(1, ctx.allocator, pool); - FloatPtr vec = GenerateVec(0); - FloatPtr add0 = GenerateVec(0); - FloatPtr add1 = GenerateVec(1); - FloatPtr expected_out0 = SimpleMatVecAdd(*mat0, vec, add0); - FloatPtr expected_out1 = SimpleMatVecAdd(*mat1, vec, 
add1); - FloatPtr actual_out0 = hwy::AllocateAligned(kOuter); - FloatPtr actual_out1 = hwy::AllocateAligned(kOuter); - HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 && - expected_out1 && actual_out1); - TwoMatVecAdd(*mat0, *mat1, 0, kOuter, kInner, vec.get(), add0.get(), - add1.get(), actual_out0.get(), actual_out1.get(), pool); - AssertClose(actual_out0, expected_out0); - AssertClose(actual_out1, expected_out1); -} - -void TestTwoOfsMatVecAddLoop() { - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - hwy::ThreadPool& pool = ctx.pools.Pool(); - - constexpr size_t kOuter = 128 * 3; - constexpr size_t kInner = 128 * 5; - auto mat = GenerateMat(0, ctx.allocator, pool); - FloatPtr vec = GenerateVec(0); - FloatPtr add0 = GenerateVec(0); - FloatPtr add1 = GenerateVec(1); - FloatPtr expected_out0 = SimpleMatVecAdd(*mat, vec, add0); - FloatPtr expected_out1 = SimpleMatVecAdd(*mat, vec, add1); - FloatPtr actual_out0 = hwy::AllocateAligned(kOuter); - FloatPtr actual_out1 = hwy::AllocateAligned(kOuter); - HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 && - expected_out1 && actual_out1); - TwoOfsMatVecAddLoop(*mat, 0, 0, kOuter, kInner, vec.get(), add0.get(), - add1.get(), actual_out0.get(), actual_out1.get()); - AssertClose(actual_out0, expected_out0); - AssertClose(actual_out1, expected_out1); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE - -namespace gcpp { -HWY_BEFORE_TEST(MatVecTest); -HWY_EXPORT_AND_TEST_P(MatVecTest, TestMatVecAdd); -HWY_EXPORT_AND_TEST_P(MatVecTest, TestTwoMatVecAdd); -HWY_EXPORT_AND_TEST_P(MatVecTest, TestTwoOfsMatVecAddLoop); -HWY_AFTER_TEST(); - -} // namespace gcpp - -#endif diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h deleted file mode 100644 index c8feda9..0000000 --- a/ops/matvec-inl.h +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: 
Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Include guard for non-SIMD code. -#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_ -#define THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_ - -#include -#include -#include - -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/profiler.h" - -#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_ - -// Include guard for (potentially) SIMD code. -#if defined(THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_MATVEC_TOGGLE -#endif - -#include "compression/compress-inl.h" -#include "ops/dot-inl.h" -#include "util/mat.h" // MatPtrT -#include "hwy/contrib/math/math-inl.h" -#include "hwy/contrib/matvec/matvec-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; - -// For callers that pass `MatPtrT`, which is not necessarily packed - callers -// should use Stride() to compute `w_ofs`. -template -HWY_INLINE float Dot(const MatPtrT& w, size_t w_ofs, const VT* vec_aligned, - size_t num) { - const hn::ScalableTag d; - return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num); -} - -// ArrayT is MatPtrT. - -// Simple version without tiling nor threading, but two offsets/outputs and -// always with addition. 
-template -HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0, - const size_t mat_ofs1, const size_t outer, - const size_t inner, - const VecT* HWY_RESTRICT vec_aligned, - const AddT* HWY_RESTRICT add0, - const AddT* HWY_RESTRICT add1, - float* HWY_RESTRICT out0, - float* HWY_RESTRICT out1) { - PROFILER_ZONE("TwoOfsMatVecAddLoop"); - - for (size_t idx_row = 0; idx_row < outer; ++idx_row) { - const size_t row_ofs0 = mat_ofs0 + idx_row * mat.Stride(); - const size_t row_ofs1 = mat_ofs1 + idx_row * mat.Stride(); - out0[idx_row] = hwy::ConvertScalarTo(add0[idx_row]) + - Dot(mat, row_ofs0, vec_aligned, inner); - out1[idx_row] = hwy::ConvertScalarTo(add1[idx_row]) + - Dot(mat, row_ofs1, vec_aligned, inner); - } -} - -HWY_INLINE constexpr size_t MaxCols() { - // Vec + mat rows should fit into 32 KiB L1. - return 2048; -} - -template -HWY_INLINE constexpr size_t RowsPerStrip() { - // Aim for 128 work items to reduce pool overhead. Must be at least one - // vector; prefer a power of two for faster division. - constexpr size_t kLanes = hn::ScalableTag().MaxLanes(); - constexpr size_t kRowsPerStrip = - kOuter < 128 ? kLanes - : HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(kOuter / 128)); - return kRowsPerStrip; -} - -HWY_INLINE size_t RowsPerStrip(const size_t outer) { - // Aim for 128 work items to reduce pool overhead. Must be at least one - // vector; prefer a power of two for faster division. - constexpr size_t kLanes = hn::ScalableTag().MaxLanes(); - return outer < 128 ? kLanes - : HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(outer / 128)); -} - -namespace detail { - -// For each i = [0, num_rows), compute partial (length `num_cols`) dot product -// of row i with `vec_aligned` and add into `out[i]`. The upper-left -// coordinate of the tile is r0, c0. 
-template -HWY_INLINE void AccumulatePartialDotProducts( - DF df, const ArrayT& mat, size_t mat_ofs, size_t r0, size_t c0, - size_t num_rows, size_t num_cols, const VecT* HWY_RESTRICT vec_aligned, - float* HWY_RESTRICT out) { - for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) { - const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat.Stride(); - out[idx_row] += Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols); - } -} - -// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial -// dot product + init (if kInit), which avoids having to zero-initialize and -// accumulate. -template -HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat, - size_t mat_ofs, size_t r0, size_t c0, - size_t num_rows, size_t num_cols, - const VecT* HWY_RESTRICT vec_aligned, - const InitT* HWY_RESTRICT init, - float* HWY_RESTRICT out) { - for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) { - const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat.Stride(); - if constexpr (kInit) { - out[idx_row] = hwy::ConvertScalarTo(init[idx_row + r0]) + - Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols); - } else { - out[idx_row] = Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols); - } - } -} - -// Adds together partial dot products for all tiles with the same r0 (a -// horizontal strip of the entire matrix); the result is the full dot product -// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we -// store into in out[r - r0]. -template -HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat, - size_t mat_ofs, size_t r0, - size_t num_rows, size_t num_cols, - const VecT* HWY_RESTRICT vec_aligned, - const AddT* HWY_RESTRICT add, - float* HWY_RESTRICT out) { - HWY_DASSERT(num_cols <= mat.Cols()); - // Tall and skinny: set `out` to the single dot product. 
- if (num_cols < MaxCols()) { - SetFirstPartialDotProducts(df, mat, mat_ofs, r0, 0, num_rows, - num_cols, vec_aligned, add, out); - return; - } - - // We have at least MaxCols, so start by setting `out` to that: - SetFirstPartialDotProducts(df, mat, mat_ofs, r0, 0, num_rows, MaxCols(), - vec_aligned, add, out); - // For further multiples of MaxCols, accumulate. Remainders handled below. - size_t c0 = MaxCols(); - for (; c0 <= num_cols - MaxCols(); c0 += MaxCols()) { - AccumulatePartialDotProducts(df, mat, mat_ofs, r0, c0, num_rows, MaxCols(), - vec_aligned, out); - } - - if (c0 < num_cols) { // Final cols - AccumulatePartialDotProducts(df, mat, mat_ofs, r0, c0, num_rows, - num_cols - c0, vec_aligned, out); - } -} - -} // namespace detail - -// Stores dot products of rows with `vec_aligned` + add the values from `add` -// (if kAdd), then stores them to `out`. -template -HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs, - const size_t outer, const size_t inner, - const VecT* HWY_RESTRICT const vec_aligned, - const AddT* HWY_RESTRICT const add, - float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - PROFILER_ZONE("MatVecAdd"); - - const hn::ScalableTag df; - const size_t rows_per_strip = RowsPerStrip(outer); - const size_t num_strips = outer / rows_per_strip; - - // For each entire strip. 
- pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR { - PROFILER_ZONE("MatVec.lambda"); - const size_t r0 = strip * rows_per_strip; - detail::FullDotProductsForStrip(df, mat, mat_ofs, r0, rows_per_strip, - inner, vec_aligned, add, out + r0); - }); - - // Remaining rows - const size_t r0 = num_strips * rows_per_strip; - if (r0 < outer) { - PROFILER_ZONE("MatVec remainder"); - const size_t num_rows = outer - r0; - detail::FullDotProductsForStrip(df, mat, mat_ofs, r0, num_rows, inner, - vec_aligned, add, out + r0); - } -} - -// With addition -template -HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, - const size_t outer, const size_t inner, - const VecT* HWY_RESTRICT const vec_aligned, - const AddT* HWY_RESTRICT const add, - float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - return MatVecT(mat, mat_ofs, outer, inner, vec_aligned, add, - out, pool); -} - -// Without addition -template -HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, - const size_t outer, const size_t inner, - const VecT* HWY_RESTRICT const vec_aligned, - float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - MatVecT(mat, mat_ofs, outer, inner, vec_aligned, - /*add=*/static_cast(nullptr), out, pool); -} - -// Two matrices, same vector -template -HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1, - const size_t mat_ofs, size_t outer, size_t inner, - const VecT* HWY_RESTRICT vec_aligned, - const AddT* HWY_RESTRICT add0, - const AddT* HWY_RESTRICT add1, - float* HWY_RESTRICT out0, float* HWY_RESTRICT out1, - hwy::ThreadPool& pool) { - PROFILER_ZONE("TwoMatVecAdd"); - - const hn::ScalableTag df; - const size_t rows_per_strip = RowsPerStrip(outer); - const size_t num_strips = outer / rows_per_strip; - - // For each entire strip. 
- pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR { - PROFILER_ZONE("TwoMatVec.lambda"); - const size_t r0 = strip * rows_per_strip; - detail::FullDotProductsForStrip(df, mat0, mat_ofs, r0, rows_per_strip, - inner, vec_aligned, add0, out0 + r0); - detail::FullDotProductsForStrip(df, mat1, mat_ofs, r0, rows_per_strip, - inner, vec_aligned, add1, out1 + r0); - }); - - // Remaining rows - const size_t r0 = num_strips * rows_per_strip; - if (r0 < outer) { - PROFILER_ZONE("TwoMatVec remainder"); - const size_t num_rows = outer - r0; - detail::FullDotProductsForStrip(df, mat0, mat_ofs, r0, num_rows, - inner, vec_aligned, add0, out0 + r0); - detail::FullDotProductsForStrip(df, mat1, mat_ofs, r0, num_rows, - inner, vec_aligned, add1, out1 + r0); - } -} - -// With addition -template -HWY_NOINLINE void TwoMatVecAdd( - const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs, - const size_t outer, const size_t inner, - const VecT* HWY_RESTRICT vec_aligned, const AddT* HWY_RESTRICT add0, - const AddT* HWY_RESTRICT add1, float* HWY_RESTRICT out0, - float* HWY_RESTRICT out1, hwy::ThreadPool& pool) { - return TwoMatVecT(mat0, mat1, mat_ofs, outer, inner, - vec_aligned, add0, add1, out0, out1, pool); -} - -// Without addition -template -HWY_NOINLINE void TwoMatVec(const ArrayT1& mat0, const ArrayT2& mat1, - const size_t mat_ofs, const size_t outer, - const size_t inner, - const VecT* HWY_RESTRICT vec_aligned, - float* HWY_RESTRICT out0, float* HWY_RESTRICT out1, - hwy::ThreadPool& pool) { - TwoMatVecT( - mat0, mat1, mat_ofs, outer, inner, vec_aligned, /*add0=*/nullptr, - /*add1=*/nullptr, out0, out1, pool); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // NOLINT From fac8aac4cbf0b3cdcaf891a0bca94716121afb55 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 22 Sep 2025 05:36:32 -0700 Subject: [PATCH 47/65] Internal change PiperOrigin-RevId: 
809975026 --- BUILD.bazel | 1 + gemma/gemma-inl.h | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index d5bac73..8b0dcde 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -559,6 +559,7 @@ cc_library( "//io", "//io:blob_store", "//paligemma:image", + "@highway//:bit_set", "@highway//:hwy", "@highway//:nanobenchmark", # timer "@highway//:profiler", diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index a7f1b01..669d7e7 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -107,8 +107,10 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1, } } -#else +#endif // GEMMA_FUSED_FFN +// Only used if !GEMMA_FUSED_FFN, but define anyway so that we can check +// using if constexpr rather than #if, which interferes with code folding. template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, @@ -131,8 +133,6 @@ HWY_NOINLINE void ActivationBatched( } } -#endif // GEMMA_FUSED_FFN - template HWY_NOINLINE void ResidualConnection(const MatPtrT& other, MatPtrT& HWY_RESTRICT x, From 4f0c633248f69f0e572c5eee2739641e90f65a61 Mon Sep 17 00:00:00 2001 From: Charles Zhao Date: Tue, 23 Sep 2025 17:01:56 -0700 Subject: [PATCH 48/65] (1) Added QueryResultAndMetrics and BatchQueryModelWithMetrics to also return TimingInfo besides query results. 
PiperOrigin-RevId: 810634261 --- evals/benchmark_helper.cc | 27 +++++++++++++++++++-------- evals/benchmark_helper.h | 31 ++++++++++++++++++++++++------- gemma/gemma.h | 2 ++ 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index abdef50..a495dea 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -78,16 +78,16 @@ QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { << runtime_config_.max_generated_tokens << "\ttemperature: " << runtime_config_.temperature << "\n"; } - gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity }; + gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity}; runtime_config_.batch_stream_token = batch_stream_token; gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_, timing_info); return result; } -void GemmaEnv::QueryModel( - const std::vector& tokens, const StreamFunc& stream_token) { - gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity }; +void GemmaEnv::QueryModel(const std::vector& tokens, + const StreamFunc& stream_token) { + gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity}; const StreamFunc previous_stream_token = runtime_config_.stream_token; runtime_config_.stream_token = stream_token; gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_, @@ -95,7 +95,7 @@ void GemmaEnv::QueryModel( runtime_config_.stream_token = previous_stream_token; } -std::vector GemmaEnv::BatchQueryModel( +QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics( const QueriesPromptTokens& queries_prompt, const hwy::Span& prefix_end) { const size_t num_queries = queries_prompt.size(); @@ -140,7 +140,13 @@ std::vector GemmaEnv::BatchQueryModel( gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end); gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity}; gemma_.GenerateBatch(runtime_config_, all_queries, env_, 
timing_info); - return res; + return {res, timing_info}; +} + +std::vector GemmaEnv::BatchQueryModel( + const QueriesPromptTokens& queries_prompt, + const hwy::Span& prefix_end) { + return BatchQueryModelWithMetrics(queries_prompt, prefix_end).query_results; } QueryResult GemmaEnv::QueryModel(const std::string& input) { @@ -148,7 +154,7 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) { return QueryModel(prompt); } -std::vector GemmaEnv::BatchQueryModel( +QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics( const std::vector& prompt_strings) { std::vector views; views.reserve(prompt_strings.size()); @@ -161,7 +167,12 @@ std::vector GemmaEnv::BatchQueryModel( } QueriesPromptTokens span_of_views(views.data(), views.size()); - return BatchQueryModel(span_of_views); + return BatchQueryModelWithMetrics(span_of_views); +} + +std::vector GemmaEnv::BatchQueryModel( + const std::vector& inputs) { + return BatchQueryModelWithMetrics(inputs).query_results; } float GemmaEnv::CrossEntropy(const std::string& input) { diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 75cf0d2..2380dbf 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -39,6 +39,14 @@ struct QueryResult { size_t response_start_pos = 0; }; +// Return type for batch query model calls with metrics. +struct QueryResultAndMetrics { + // The query results for each query in the batch. + std::vector query_results; + // The timing information for the batch query. + TimingInfo timing_info; +}; + // Convenience class to load a model and run inference. class GemmaEnv { public: @@ -79,21 +87,30 @@ class GemmaEnv { return string; } + // Adds turn structure to input, tokenizes and calls the below overload. + QueryResult QueryModel(const std::string& input); // Runs inference on the given input and returns the top-1 result string and // the number of tokens that were generated. 
QueryResult QueryModel(const std::vector& tokens); + // Runs inference on the given input and calls the callback for each token. + void QueryModel(const std::vector& tokens, + const StreamFunc& stream_token); + + // Similar to the above, but runs inference on a batch of inputs. + std::vector BatchQueryModel( + const std::vector& inputs); // The default prefix_end means "causal attention". std::vector BatchQueryModel( const QueriesPromptTokens& queries_prompt, const hwy::Span& prefix_end = hwy::Span()); - // Adds turn structure to input, tokenizes and calls the above overload. - QueryResult QueryModel(const std::string& input); - std::vector BatchQueryModel( - const std::vector& prompt_strings); - // Runs inference on the given input and calls the callback for each token. - void QueryModel(const std::vector& tokens, - const StreamFunc& stream_token); + // Similar to the above, but returns timing information in addition to the + // query results. + QueryResultAndMetrics BatchQueryModelWithMetrics( + const std::vector& prompt_strings); + QueryResultAndMetrics BatchQueryModelWithMetrics( + const QueriesPromptTokens& queries_prompt, + const hwy::Span& prefix_end = hwy::Span()); // Runs inference on the given input and returns the cross entropy, a measure // of how well the model predicts the correct output. It is the average diff --git a/gemma/gemma.h b/gemma/gemma.h index 491999d..5e40bda 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -178,6 +178,7 @@ struct TimingInfo { // be sure to populate prefill_start and generate_start before calling // NotifyGenerated. 
void NotifyGenerated(size_t batch_size) { + generation_steps += 1; const bool is_first = (tokens_generated == 0); tokens_generated += batch_size; if (HWY_UNLIKELY(is_first)) { @@ -224,6 +225,7 @@ struct TimingInfo { double time_to_first_token = 0; double generate_duration = 0; size_t tokens_generated = 0; + size_t generation_steps = 0; }; // After construction, all methods are const and thread-compatible if using From d15731d2019fef5c77a392027c99f91d486665d7 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 25 Sep 2025 09:41:30 -0700 Subject: [PATCH 49/65] Used hn::BroadcastLane instead of Set(..., x.raw) PiperOrigin-RevId: 811386295 --- ops/ops-inl.h | 215 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 129 insertions(+), 86 deletions(-) diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 1593aa4..ec73f66 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -621,52 +621,116 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( }); } -template > +template , HWY_IF_V_SIZE_GT_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, + VF& sum5, VF& sum6, VF& sum7, VF& sum8, + VF& sum9, VF& sum10, VF& sum11, + VF& sum12, VF& sum13, VF& sum14, + VF& sum15) { + sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); + sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); + sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); + sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); + sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); + sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); + sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); + sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); + sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale)); + sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale)); + sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale)); + sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale)); + sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale)); + sum13 = hn::Mul(sum13, 
hn::BroadcastLane<13>(scale)); + sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale)); + sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale)); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, + VF& sum5, VF& sum6, VF& sum7, VF& sum8, + VF& sum9, VF& sum10, VF& sum11, + VF& sum12, VF& sum13, VF& sum14, + VF& sum15) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, + VF& sum6, VF& sum7) { + sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); + sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); + sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); + sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); + sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); + sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); + sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); + sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, + VF& sum6, VF& sum7) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 63)> HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) { - sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); - sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); - sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); - sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); - sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); - sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); - sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); - sum7 = hn::MulAdd(common, 
hn::Set(df, split.raw[7]), sum7); - sum8 = hn::MulAdd(common, hn::Set(df, split.raw[8]), sum8); - sum9 = hn::MulAdd(common, hn::Set(df, split.raw[9]), sum9); - sum10 = hn::MulAdd(common, hn::Set(df, split.raw[10]), sum10); - sum11 = hn::MulAdd(common, hn::Set(df, split.raw[11]), sum11); - sum12 = hn::MulAdd(common, hn::Set(df, split.raw[12]), sum12); - sum13 = hn::MulAdd(common, hn::Set(df, split.raw[13]), sum13); - sum14 = hn::MulAdd(common, hn::Set(df, split.raw[14]), sum14); - sum15 = hn::MulAdd(common, hn::Set(df, split.raw[15]), sum15); + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); + sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); + sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); + sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); + sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); + sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8); + sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9); + sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10); + sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11); + sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12); + sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13); + sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14); + sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15); } -template > +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( + DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, + VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, + VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const 
VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { - sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); - sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); - sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); - sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); - sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); - sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); - sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); - sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7); + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); + sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); + sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); + sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); + sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); } +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, VF& sum3, + VF& sum4, VF& sum5, VF& sum6, + VF& sum7) {} + template > HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, VF& sum3) { - sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); - sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); - sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); - sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); } // For an 8xNF tile of 
float values in 8xNF-lane registers, multiplies 8 rows @@ -706,22 +770,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( out13 = hn::Load(df, out + i + out_offsets[13]); out14 = hn::Load(df, out + i + out_offsets[14]); out15 = hn::Load(df, out + i + out_offsets[15]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); - out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); - out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); - out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); - out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); - out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); - out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); - out14 = hn::Mul(out14, hn::Set(df, scale.raw[14])); - out15 = hn::Mul(out15, hn::Set(df, scale.raw[15])); + Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); VF x0 = hn::Load(df, v.Row(pos[0]) + i); MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, out9, out10, out11, out12, out13, out14, out15); @@ -773,14 +823,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( out5 = hn::Load(df, out + i + out_offsets[5]); out6 = hn::Load(df, out + i + out_offsets[6]); out7 = hn::Load(df, out + i + out_offsets[7]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = 
hn::Mul(out7, hn::Set(df, scale.raw[7])); + Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7); VF x0 = hn::Load(df, v.Row(pos[0]) + i); MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); VF x1 = hn::Load(df, v.Row(pos[1]) + i); @@ -812,10 +855,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( out1 = hn::Load(df, out + i + out_offsets[1]); out2 = hn::Load(df, out + i + out_offsets[2]); out3 = hn::Load(df, out + i + out_offsets[3]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); VF x0 = hn::Load(df, v.Row(pos[0]) + i); MulAdd4(df, x0, c0, out0, out1, out2, out3); VF x1 = hn::Load(df, v.Row(pos[1]) + i); @@ -878,22 +921,22 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out13 = hn::Load(df, out + i + out_offsets[13]); out14 = hn::Load(df, out + i + out_offsets[14]); out15 = hn::Load(df, out + i + out_offsets[15]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); - out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); - out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); - out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); - out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); - out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); - out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); - out14 = 
hn::Mul(out14, hn::Set(df, scale.raw[14])); - out15 = hn::Mul(out15, hn::Set(df, scale.raw[15])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); + out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale)); + out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); + out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); + out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); + out8 = hn::Mul(out8, hn::BroadcastLane<8>(scale)); + out9 = hn::Mul(out9, hn::BroadcastLane<9>(scale)); + out10 = hn::Mul(out10, hn::BroadcastLane<10>(scale)); + out11 = hn::Mul(out11, hn::BroadcastLane<11>(scale)); + out12 = hn::Mul(out12, hn::BroadcastLane<12>(scale)); + out13 = hn::Mul(out13, hn::BroadcastLane<13>(scale)); + out14 = hn::Mul(out14, hn::BroadcastLane<14>(scale)); + out15 = hn::Mul(out15, hn::BroadcastLane<15>(scale)); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, out9, out10, out11, out12, out13, out14, out15); @@ -923,14 +966,14 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out5 = hn::Load(df, out + i + out_offsets[5]); out6 = hn::Load(df, out + i + out_offsets[6]); out7 = hn::Load(df, out + i + out_offsets[7]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); + out4 = hn::Mul(out4, 
hn::BroadcastLane<4>(scale)); + out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); + out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); + out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); hn::Store(out0, df, out + i + out_offsets[0]); @@ -947,10 +990,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out1 = hn::Load(df, out + i + out_offsets[1]); out2 = hn::Load(df, out + i + out_offsets[2]); out3 = hn::Load(df, out + i + out_offsets[3]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd4(df, x0, c0, out0, out1, out2, out3); hn::Store(out0, df, out + i + out_offsets[0]); From 667a3f117af5e25a4bd16d4c17eef72f3731591b Mon Sep 17 00:00:00 2001 From: Nitin Gangahar Date: Fri, 26 Sep 2025 11:27:56 -0700 Subject: [PATCH 50/65] Utilize multiple cores to read weight batches. PiperOrigin-RevId: 811893059 --- gemma/weights.cc | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/gemma/weights.cc b/gemma/weights.cc index 425a752..fb59297 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -465,18 +465,20 @@ static void ReadBatches(const BlobReader& reader, ThreadingContext& ctx) { static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches"); // >5x speedup from parallel reads when cached. 
- ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) { - PROFILER_ZONE3(ctx.profiler, thread, zone); - const IOBatch& batch = batches[i]; - const std::string& key = reader.Keys()[batch.KeyIdx()]; - const uint64_t bytes_read = batch.Read(reader.file()); - if (bytes_read != batch.TotalBytes()) { - HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(), - static_cast(batch.Offset()), - static_cast(batch.TotalBytes()), - static_cast(bytes_read)); - } - }); + ParallelFor(ParallelismStrategy::kHierarchical, + batches.size(), ctx, /*cluster_idx=*/0, + [&](uint64_t task, size_t thread) { + PROFILER_ZONE3(ctx.profiler, thread, zone); + const IOBatch& batch = batches[task]; + const std::string& key = reader.Keys()[batch.KeyIdx()]; + const uint64_t bytes_read = batch.Read(reader.file()); + if (bytes_read != batch.TotalBytes()) { + HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", + key.c_str(), static_cast(batch.Offset()), + static_cast(batch.TotalBytes()), + static_cast(bytes_read)); + } + }); } // Aborts on error. Updates `mode` to the actual mode used. Returns mapped From 16536996d1037241e5d0a24aca8e7eb0a3153275 Mon Sep 17 00:00:00 2001 From: Nitin Gangahar Date: Mon, 29 Sep 2025 02:28:04 -0700 Subject: [PATCH 51/65] Remove less useful spammy log lines. PiperOrigin-RevId: 812694572 --- compression/compress-inl.h | 16 ---------------- compression/python/compression_clif_aux.cc | 3 --- 2 files changed, 19 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index e2b3bcc..18d8e35 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -55,12 +55,6 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -#ifdef HWY_IS_TEST -static constexpr bool kIsTest = true; -#else -static constexpr bool kIsTest = false; -#endif - // Enables generic code independent of compression type. 
template // primary, must specialize struct CompressTraits {}; @@ -485,9 +479,6 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, } } - const bool want_bench = COMPRESS_STATS || !kIsTest; - const double t0 = want_bench ? hwy::platform::Now() : 0.0; - using Traits = CompressTraits>; constexpr size_t kBatch = 8192; const size_t num_batches = hwy::DivCeil(num, kBatch); @@ -502,13 +493,6 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, packed, packed_ofs + my_pos); }); - if (want_bench) { // Avoids log spam in tests - const double t1 = hwy::platform::Now(); - const double mb = static_cast(num) * sizeof(raw[0]) * 1E-6; - const double mbps = mb / (t1 - t0); - fprintf(stderr, "Compress %.1f MB/s\n", mbps); - } - if constexpr (COMPRESS_STATS) { for (size_t i = 1; i < work.tls.size(); ++i) { work.tls[0].stats.Assimilate(work.tls[i].stats); diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 2de1b67..5e729cc 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -87,9 +87,6 @@ class SbsWriterImpl : public ISbsWriter { return; } - fprintf(stderr, "Compressing %s (%zu x %zu = %zuM) to %s, please wait\n", - name, mat.Rows(), mat.Cols(), weights.size() / (1000 * 1000), - TypeName(TypeEnum())); HWY_ASSERT(weights.size() == mat.Extents().Area()); Compress(weights.data(), weights.size(), working_set_, mat.Span(), /*packed_ofs=*/0, pool); From 4974f248322af872d99102eda1a79964084eaaf8 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Tue, 30 Sep 2025 02:17:18 -0700 Subject: [PATCH 52/65] Fixed bug with softcap in single flash attention PiperOrigin-RevId: 813164938 --- gemma/flash_attention.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index ba1de3e..c65c57f 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -126,6 +126,10 @@ void 
SingleFlashAttention(const size_t start_pos, const size_t last_pos, PROFILER_ZONE3(p, worker, zone); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); float m = Dot(q, k.Row(pos_mod), k.Cols()); + if (float cap = activations.config.att_cap; cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. + m = cap * std::tanh(m / cap); + } float d = 1.0f; // This is just a copy of the first token. MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker); From 2f6cbde8ff843614f6b60996d27d56b5ee04c604 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Tue, 30 Sep 2025 05:48:50 -0700 Subject: [PATCH 53/65] Added a smaller tile size to flash attention for smaller batch sizes PiperOrigin-RevId: 813226193 --- BUILD.bazel | 1 + gemma/attention.cc | 3 +- gemma/flash_attention.cc | 312 ++++++++++++++++++++++++++++------ gemma/flash_attention.h | 4 +- gemma/flash_attention_test.cc | 12 +- ops/ops-inl.h | 202 ++++++++++++++++++---- 6 files changed, 445 insertions(+), 89 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 8b0dcde..74f472f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -548,6 +548,7 @@ cc_library( ":gemma_args", ":kv_cache", ":mat", + ":matmul", ":matmul_env", ":model_store", ":ops", diff --git a/gemma/attention.cc b/gemma/attention.cc index e894981..f404674 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -358,7 +358,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); } else { - FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); + FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer, + activations, qbatch, env.ctx); } SumHeads(layer, activations, env); } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index c65c57f..b93b58f 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -44,6 +44,7 @@ // After highway.h #include 
"compression/compress-inl.h" #include "gemma/attention.h" +#include "ops/matmul-inl.h" #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); @@ -114,6 +115,27 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, } } +// Handles a single v row of flash attention for a single q.k dot product. +void HWY_INLINE SingleFlashAttentionStep( + float x, float cap, float& old_max, float& old_d, + const float* HWY_RESTRICT v, const size_t v_cols, + float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { + if (cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. + x = cap * std::tanh(x / cap); + } + float m = std::max(x, old_max); + x = std::exp(x - m); + float scale = old_d * std::exp(old_max - m); + old_d = x + scale; + old_max = m; + float one_over_d = 1.0f / old_d; + scale *= one_over_d; + x *= one_over_d; + MulByConst(scale, att_out, v_cols, p, worker); + MulByConstAndAdd(x, v, att_out, v_cols, p, worker); +} + // Calculates the complete attention outputs for a single row of q. void SingleFlashAttention(const size_t start_pos, const size_t last_pos, const float* HWY_RESTRICT q, const MatPtrT& k, @@ -136,21 +158,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = activations.div_seq_len.Remainder(pos); float x = Dot(q, k.Row(pos_mod), k.Cols()); - if (activations.config.att_cap > 0.0f) { - // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. 
- x = activations.config.att_cap * - std::tanh(x / activations.config.att_cap); - } - float m_new = std::max(m, x); - float scale = d * std::exp(m - m_new); - x = std::exp(x - m_new); - m = m_new; - d = scale + x; - float one_over_d = 1.0f / d; - x *= one_over_d; - scale *= one_over_d; - MulByConst(scale, att_out, v.Cols(), p, worker); - MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker); + SingleFlashAttentionStep(x, activations.config.att_cap, m, d, + v.Row(pos_mod), v.Cols(), att_out, p, worker); } } @@ -167,7 +176,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, return hn::LoadU(df, results); } -// Returns an 8xNF tile of Q.K dot products, in single precision. +// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single +// precision. // This is the result of NF rows of Q against 8 K timesteps, with positions // given by k_pos[0..7]. Q has been transposed so that the NF rows are read in // consecutive elements, and other columns by adding q_stride. @@ -240,8 +250,9 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, return hn::Add(sum0, sum2); } -// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then -// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos]. +// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to +// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, +// max_last_pos]. 
void TileFlashAttention( const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, const StridedView& qT, const MatPtrT& k, @@ -260,7 +271,7 @@ void TileFlashAttention( using DI = hn::ScalableTag; const DI di; using VI = hn::Vec; - const int kVTileSize = hn::MaxLanes(df); + const int kVTileSize = hn::Lanes(df); for (int i = 0; i < kVTileSize; ++i) { hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], v.Cols() * sizeof(att_out.Row(0)[0])); @@ -348,38 +359,217 @@ void TileFlashAttention( } } +// Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision. +// This is the result of 4 rows of Q against NF K timesteps, with positions +// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read +// in consecutive elements, and other columns by adding q_stride. +template > +void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride, + const MatPtrT& k, const int32_t* HWY_RESTRICT k_offsets, + hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, + VF& sum2, VF& sum3) { + sum0 = hn::Zero(df); + sum1 = hn::Zero(df); + sum2 = hn::Zero(df); + sum3 = hn::Zero(df); + const float* HWY_RESTRICT k_base = k.Row(0); + using DI = hn::ScalableTag; + const DI di; + using VI = hn::Vec; + VI k_offsets_vec = hn::LoadU(di, k_offsets); + for (size_t i = 0; i < k.Cols(); ++i) { + VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); + VF q_0 = hn::Set(df, q[0]); + sum0 = hn::MulAdd(q_0, k_vec, sum0); + VF q_1 = hn::Set(df, q[1]); + sum1 = hn::MulAdd(q_1, k_vec, sum1); + VF q_2 = hn::Set(df, q[2]); + sum2 = hn::MulAdd(q_2, k_vec, sum2); + VF q_3 = hn::Set(df, q[3]); + sum3 = hn::MulAdd(q_3, k_vec, sum3); + q += q_stride; + } +} + +// Handles NF v rows of flash attention for NF q.k dot products from one q row. 
+template > +float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, + float& old_d) { + float m = hn::ReduceMax(df, x); + m = std::max(m, old_max); + x = hn::Exp(df, x - hn::Set(df, m)); + float scale = old_d * std::exp(old_max - m); + old_d = hn::ReduceSum(df, x) + scale; + old_max = m; + float one_over_d = 1.0f / old_d; + scale *= one_over_d; + x = hn::Mul(x, hn::Set(df, one_over_d)); + return scale; +} + +// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to +// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, +// max_last_pos]. +void TileFlashAttention4( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const StridedView& qT, const MatPtrT& k, + const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, + const size_t min_last_pos, const size_t max_last_pos, + const MatPtrT& v, const size_t layer_idx, + const LayerWeightsPtrs& layer, const AttentionActivations& activations, + MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention4"); + PROFILER_ZONE3(p, worker, zone); + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + constexpr size_t kMaxNF = hn::MaxLanes(df); + const size_t kHTileSize = hn::Lanes(df); + HWY_DASSERT(kHTileSize <= kMaxNF); + constexpr size_t kVTileSize = 4; + float scales[kVTileSize]; + for (size_t i = 0; i < kVTileSize; ++i) { + hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], + v.Cols() * sizeof(att_out.Row(0)[0])); + } + float old_m0 = -std::numeric_limits::max() / 2.0f; + float old_m1 = -std::numeric_limits::max() / 2.0f; + float old_m2 = -std::numeric_limits::max() / 2.0f; + float old_m3 = -std::numeric_limits::max() / 2.0f; + float old_d0 = 0.0f; + float old_d1 = 0.0f; + float old_d2 = 0.0f; + float old_d3 = 0.0f; + const float* HWY_RESTRICT qT_row = qT.Row(0); + const size_t qT_stride = qT.Stride(); + size_t 
position = start_pos; + while (position + kHTileSize - 1 <= min_last_pos) { + int32_t k_offsets[kMaxNF]; + size_t v_pos[kMaxNF]; + for (size_t i = 0; i < kHTileSize; ++i) { + v_pos[i] = activations.div_seq_len.Remainder(position + i); + k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); + } + VF x0, x1, x2, x3; + QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. + VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); + x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); + x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); + } + scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0); + scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1); + scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); + scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); + MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), + out_offsets, v.Cols(), p, worker); + position += kHTileSize; + } + while (position <= max_last_pos) { + size_t k_pos = activations.div_seq_len.Remainder(position); + if (position <= last_pos[0]) { + // Past the last position, x0 doesn't count. + float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[0], p, worker); + } + if (position <= last_pos[1]) { + // Past the last position, x1 doesn't count. 
+ float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[1], p, worker); + } + if (position <= last_pos[2]) { + // Past the last position, x2 doesn't count. + float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[2], p, worker); + } + if (position <= last_pos[3]) { + // Past the last position, x3 doesn't count. + float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[3], p, worker); + } + ++position; + } +} + +// Rounds n to a number that can be used as the number of Q rows in a tile +// of flash attention. +static size_t RoundToSuitablePowerOf2(size_t n) { + if (n < 4) return 1; + if (n < 8) return 4; + if (n < 16) return 8; + if (n < 32) return 16; + return 32; +} + // The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D] // into a single output O[L,D]. // Conventional attention first computes A[L,L] = Q . KT // followed by A = softmax(A) (over invididual rows). // Then A is multiplied by V to get O[L,D]. // For each row of O, this takes a read of one row of Q L times, all of K, -// 3 write/reads of one row of A, read all of V, an read.write the one row of O +// 3 write/reads of one row of A, read all of V, and read/write the one row of O // L times. Ignoring the computation for now, and focusing just on memory, // the one row of O takes L(4D+3) reads and L(D+3) writes. // For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes. // -// Flash attention fuses these operations together, and (where possible) -// computes NF rows of the result using 8 accumulator registers and two more to -// keep running results. 
NF is the number of float lanes in a register, being 16 -// for AVX3. The softmax is converted to streaming form using the -// algortihm from: +// Flash attention fuses these operations together, and has 3 operating modes: +// 1. NF rows of the result computed using tiles of registers of shape NFx8. +// 2. 4 rows of the result computed using tiles of registers of shape 4xNF. +// 3. One row (of Q and the result) at a time. +// In all cases the intermediate result (Q.KT) is never stored to memory. +// NF is the number of float lanes in a register, being 16 for AVX3. The softmax +// is converted to streaming form using the algorithm from: // https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf. // Q is transposed to Q_T[D,L] to make the dot product computation efficient. -// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing -// reads of Q by 8 and reads of K by NF. The streaming softmax is computed -// entirely in registers, and a further NF registers to accumulate the results -// of the product of the softmax and V, reduce the number of reads of V by NF, -// and the reads/writes of O by 8. +// +// In mode 1: +// QDotKTileFloat computes NF Q rows x 8 K timesteps of Q.K dot products in one +// go, reducing reads of Q by 8 and reads of K by NF. The streaming softmax is +// computed entirely in registers, and a further NF registers to accumulate the +// results of the product of the softmax and V, reduce the number of reads of V +// by NF, and the reads/writes of O by 8. // The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8, // which on AVX3 is an overall reduction by about a factor of 10. +// Mode 1 can only be accessed if there is a large Qbatch size, or in multi-turn +// prefill, since in other cases, there is either a single K timestep (prefill) +// or a single num_heads set of Q rows (decode). 
+// +// In mode 2, the 4 rows of Q are computed against NF K timesteps in a tile, +// reducing the reads of Q by NF, and the reads of K by 4. The softmax and +// accumulation of the result is done in registers, cutting the reads of V by 4. +// The reads/writes of O are reduced by a factor of NF. +// The overall reduction is limited by the need to use gather to load K. +// Transposing K would be possible, but is complicated by the wraparound. +// Mode 2 can be used in all cases when there are at least 4 attention heads, +// but it may be prefereable to use mode 3 when the batch size is small to +// maximise parallelism. +// +// In mode 3, a single row of Q is computed against a single K timestep at a +// time, using SingleFlashAttention. In this case there is no reduction in the +// reads of Q or K, or V, or O, but the reads/writes of the intermediate A are +// still eliminated. // // A further complication is that real attention is not as simple as documented // in the paper and above. There are multiple query heads, differing KV, and // different sequence lengths, so a lot of the work in FlashAttention is making -// sure that a collection of q rows can use the TileFlashAttention path. -void FlashAttention(const size_t num_tokens, const size_t layer_idx, - const LayerWeightsPtrs& layer, +// sure that a collection of q rows with the same KV and sequence length are +// grouped together so that mode 1 or 2 can be used, and choosing which of the +// 3 modes to use for best efficiency. 
+void FlashAttention(const size_t num_tokens, const size_t target_parallelism, + const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, ThreadingContext& ctx) { static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention"); @@ -392,15 +582,28 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx, // A "head group" in the context of GQA refers to a collection of query // heads that share the same key and value heads. const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; - - using DF = hn::ScalableTag; - const DF df; - constexpr size_t kVTileSize = hn::MaxLanes(df); const size_t cache_layer_size = layer_config.CacheLayerSize(); const size_t seq_len = static_cast(activations.div_seq_len.GetDivisor()); const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); const size_t total_tasks = token_batch * layer_config.heads; + + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); + constexpr size_t kMaxNF = hn::MaxLanes(df); + HWY_DASSERT(kNF <= kMaxNF); + // The vertical tile size is determined by the ability to use tiling and the + // target_parallelism. In practice the possible tile sizes in order of + // preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or + // 16. The final tile size is chosen to be the largest possible that allows + // for target_parallelism parallel tasks. + const size_t kMaxEqualK = RoundToSuitablePowerOf2(kHeadGroups * num_tokens); + const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1; + const size_t kVTileSize = + (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism) + ? kNF + : std::min(kMinTileSize, kMaxEqualK); // q has shape [batch, qbatch][head, qkv_dim]. // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the // maximum possible number of consecutive columns have the same KV matrices. 
@@ -416,26 +619,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx, const auto func = [&](const size_t task, size_t worker) HWY_ATTR { PROFILER_ZONE3(ctx.profiler, worker, zone); // Offsets into original Q for each row in the tile. - uint32_t q_offsets[kVTileSize]; + uint32_t q_offsets[kMaxNF]; // Offsets into att_out for each row in the tile. - uint32_t out_offsets[kVTileSize]; + uint32_t out_offsets[kMaxNF]; // Start positions for each row in the tile. - size_t start_positions[kVTileSize]; + size_t start_positions[kMaxNF]; // Last positions for each row in the tile. Inclusive. - uint32_t last_pos[kVTileSize]; + uint32_t last_pos[kMaxNF]; // min and max last positions across all rows in the tile determines when // TileFlashAttention switches to single vector mode to handle the // ragged sequence lengths. size_t min_last_pos = std::numeric_limits::max(); size_t max_last_pos = 0; // Indices into the qbatch.KV for each row in the tile. - size_t qi_indices[kVTileSize]; + size_t qi_indices[kMaxNF]; // Indices into the kv_cache for each row in the tile. - size_t kv_offsets[kVTileSize]; + size_t kv_offsets[kMaxNF]; // first_task is [qbatch, head, token]. const size_t first_task = task * kVTileSize; const size_t last_task = first_task + kVTileSize - 1; - bool use_tile_attention = last_task < total_tasks; + bool use_tile_attention = kVTileSize > 1 && last_task < total_tasks; for (size_t offset = 0; offset < kVTileSize && first_task + offset < total_tasks; ++offset) { const size_t batch_idx = div_tokens.Remainder(first_task + offset); @@ -486,15 +689,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx, kv_cache.Stride()); if (use_tile_attention) { // To avoid duplicating the code to setup K and V, the call to - // TileFlashAttention is inside the loop over tasks, even thought it + // TileFlashAttention is inside the loop over tasks, even though it // handles all rows in the task at once. 
StridedView qT = StridedView(activations.q_T.Row(0) + first_task, kVTileSize, activations.q_T.Stride()); - TileFlashAttention( - activations.q, q_offsets, qT, k, start_positions[offset], last_pos, - min_last_pos, max_last_pos, v, layer_idx, layer, activations, - activations.att_out, out_offsets, ctx.profiler, worker); + if (kVTileSize == kNF) { + TileFlashAttention(activations.q, q_offsets, qT, k, + start_positions[offset], last_pos, min_last_pos, + max_last_pos, v, layer_idx, layer, activations, + activations.att_out, out_offsets, ctx.profiler, + worker); + } else if (kVTileSize == 4) { + TileFlashAttention4(activations.q, q_offsets, qT, k, + start_positions[offset], last_pos, min_last_pos, + max_last_pos, v, layer_idx, layer, activations, + activations.att_out, out_offsets, ctx.profiler, + worker); + } else { + HWY_UNREACHABLE; + } break; } else { SingleFlashAttention(start_positions[offset], last_pos[offset], diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index b505d6f..75e087a 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -42,8 +42,8 @@ namespace gcpp { float* HWY_RESTRICT att_out, hwy::Profiler& p, \ size_t worker); \ \ - void FlashAttention(size_t num_tokens, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ + void FlashAttention(size_t num_tokens, size_t target_parallelism, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ AttentionActivations& activations, QBatch& qbatch, \ ThreadingContext& ctx); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index d4d6380..7f8f31e 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -98,13 +98,14 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { } } -void TestAttention() { +void TestFlashAttention(size_t target_parallelism) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); // hwy::ThreadPool& pool = ctx.pools.Pool(); constexpr 
size_t kOuter = 1024; constexpr size_t kInner = 256; ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); + config.att_cap = 1024.0f; TensorInfoRegistry tensor_info_registry(config); const LayerConfig& layer_config = config.layer_configs[0]; const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry); @@ -149,10 +150,17 @@ void TestAttention() { // Copy the output to saved_att to allow for comparison. auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); SetMat(1, attention.q); - FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx); + FlashAttention(tokens.size(), target_parallelism, 0, layers, attention, + qbatch, ctx); AssertClose(attention.att_out, *saved_att); } +void TestAttention() { + TestFlashAttention(8192); + TestFlashAttention(2048); + TestFlashAttention(256); +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/ops/ops-inl.h b/ops/ops-inl.h index ec73f66..a52c788 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -747,7 +747,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); size_t i = 0; while (i + NF <= size) { @@ -882,8 +882,162 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( } i += NF; } - const size_t remaining = size - i; - HWY_DASSERT(remaining == 0); + HWY_DASSERT(size == i); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0, + const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, + VF& sum3) { + sum0 = hn::MulAdd(common, c0, sum0); + sum1 = hn::MulAdd(common, c1, sum1); + sum2 = hn::MulAdd(common, c2, sum2); + sum3 = hn::MulAdd(common, c3, sum3); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF 
df, const MatPtrT& v, + const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, + const VF c1, const VF c2, + const VF c3, VF& sum0, VF& sum1, + VF& sum2, VF& sum3) { + // TODO(rays): Check whether a transpose of c0-c3 is applicable and faster. + VF x0 = hn::Load(df, v.Row(pos[0]) + offset); + MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1), + hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2, + sum3); + VF x1 = hn::Load(df, v.Row(pos[1]) + offset); + MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1), + hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2, + sum3); + VF x2 = hn::Load(df, v.Row(pos[2]) + offset); + MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1), + hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2, + sum3); + VF x3 = hn::Load(df, v.Row(pos[3]) + offset); + MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1), + hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2, + sum3); +} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) { + VF x4 = hn::Load(df, v.Row(pos[4]) + offset); + MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1), + hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2, + sum3); + VF x5 = hn::Load(df, v.Row(pos[5]) + offset); + MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1), + hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2, + sum3); + VF x6 = hn::Load(df, v.Row(pos[6]) + offset); + MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1), + hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2, + sum3); + VF x7 = hn::Load(df, v.Row(pos[7]) + offset); + MulAdd4(df, x7, 
hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1), + hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2, + sum3); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) { + VF x8 = hn::Load(df, v.Row(pos[8]) + offset); + MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1), + hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2, + sum3); + VF x9 = hn::Load(df, v.Row(pos[9]) + offset); + MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1), + hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2, + sum3); + VF x10 = hn::Load(df, v.Row(pos[10]) + offset); + MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1), + hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1, + sum2, sum3); + VF x11 = hn::Load(df, v.Row(pos[11]) + offset); + MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1), + hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1, + sum2, sum3); + VF x12 = hn::Load(df, v.Row(pos[12]) + offset); + MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1), + hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1, + sum2, sum3); + VF x13 = hn::Load(df, v.Row(pos[13]) + offset); + MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1), + hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1, + sum2, sum3); + VF x14 = hn::Load(df, v.Row(pos[14]) + offset); + MulAdd4(df, x14, 
hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1), + hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1, + sum2, sum3); + VF x15 = hn::Load(df, v.Row(pos[15]) + offset); + MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1), + hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1, + sum2, sum3); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) {} + +// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows +// of V by the corresponding values in c0-c3 and adds them to NF rows of out, +// after first prescaling out by scale. +// The depth (size) must be a multiple of NF. +template > +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4( + DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1, + const VF c2, const VF c3, const MatPtrT& v, + const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, + const uint32_t* HWY_RESTRICT out_offsets, const size_t size, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstAndAddTile4"); + PROFILER_ZONE3(p, worker, zone); + namespace hn = hwy::HWY_NAMESPACE; + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + + size_t i = 0; + while (i + NF <= size) { + VF out0, out1, out2, out3; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out0 = hn::Mul(out0, hn::Set(df, scales[0])); + out1 = hn::Mul(out1, hn::Set(df, scales[1])); + out2 = hn::Mul(out2, hn::Set(df, scales[2])); + out3 = hn::Mul(out3, hn::Set(df, scales[3])); + MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3); + if HWY_LANES_CONSTEXPR (NF >= 8) { + 
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3); + if HWY_LANES_CONSTEXPR (NF >= 16) { + MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, + out3); + } + } + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + i += NF; + } + HWY_DASSERT(size == i); } // Prescales NF rows of out by scale, then multiplies 1 row of V by the @@ -898,11 +1052,11 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const size_t NF = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); size_t i = 0; while (i + NF <= size) { - if constexpr (NF == 16) { + if HWY_LANES_CONSTEXPR (NF == 16) { VF out0, out1, out2, out3, out4, out5, out6, out7; VF out8, out9, out10, out11, out12, out13, out14, out15; out0 = hn::Load(df, out + i + out_offsets[0]); @@ -921,22 +1075,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out13 = hn::Load(df, out + i + out_offsets[13]); out14 = hn::Load(df, out + i + out_offsets[14]); out15 = hn::Load(df, out + i + out_offsets[15]); - out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); - out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); - out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); - out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); - out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale)); - out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); - out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); - out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); - out8 = hn::Mul(out8, hn::BroadcastLane<8>(scale)); - out9 = hn::Mul(out9, hn::BroadcastLane<9>(scale)); - out10 = hn::Mul(out10, hn::BroadcastLane<10>(scale)); - out11 = hn::Mul(out11, hn::BroadcastLane<11>(scale)); - out12 = hn::Mul(out12, hn::BroadcastLane<12>(scale)); - 
out13 = hn::Mul(out13, hn::BroadcastLane<13>(scale)); - out14 = hn::Mul(out14, hn::BroadcastLane<14>(scale)); - out15 = hn::Mul(out15, hn::BroadcastLane<15>(scale)); + Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, out9, out10, out11, out12, out13, out14, out15); @@ -956,7 +1096,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( hn::Store(out13, df, out + i + out_offsets[13]); hn::Store(out14, df, out + i + out_offsets[14]); hn::Store(out15, df, out + i + out_offsets[15]); - } else if constexpr (NF == 8) { + } + if HWY_LANES_CONSTEXPR (NF == 8) { VF out0, out1, out2, out3, out4, out5, out6, out7; out0 = hn::Load(df, out + i + out_offsets[0]); out1 = hn::Load(df, out + i + out_offsets[1]); @@ -966,14 +1107,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out5 = hn::Load(df, out + i + out_offsets[5]); out6 = hn::Load(df, out + i + out_offsets[6]); out7 = hn::Load(df, out + i + out_offsets[7]); - out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); - out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); - out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); - out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); - out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale)); - out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); - out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); - out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); + Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); hn::Store(out0, df, out + i + out_offsets[0]); @@ -984,7 +1118,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( hn::Store(out5, df, out + i + out_offsets[5]); hn::Store(out6, df, out + i + out_offsets[6]); hn::Store(out7, df, out + i + out_offsets[7]); - } else if 
constexpr (NF == 4) { + } + if HWY_LANES_CONSTEXPR (NF == 4) { VF out0, out1, out2, out3; out0 = hn::Load(df, out + i + out_offsets[0]); out1 = hn::Load(df, out + i + out_offsets[1]); @@ -1000,13 +1135,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( hn::Store(out1, df, out + i + out_offsets[1]); hn::Store(out2, df, out + i + out_offsets[2]); hn::Store(out3, df, out + i + out_offsets[3]); - } else { - HWY_DASSERT(false); } i += NF; } - const size_t remaining = size - i; - HWY_DASSERT(remaining == 0); + HWY_DASSERT(size == i); } // See below for a specialized version for top-1 sampling. From 6098a022b3fca21b54065a70dd6e341134623860 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Wed, 1 Oct 2025 07:10:40 -0700 Subject: [PATCH 54/65] Increased parallelism for RMSNormAndPositionalEncoding PiperOrigin-RevId: 813738994 --- gemma/flash_attention.cc | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index b93b58f..ddf2bcc 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -87,31 +87,32 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, static const auto zone = ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding"); const float query_scale = activations.query_scale; + const hwy::Divisor div_qbatch(qbatch.Size()); const auto func = [&](const size_t task, size_t worker) HWY_ATTR { PROFILER_ZONE3(ctx.profiler, worker, zone); - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - for (size_t h = 0; h < layer.layer_config.heads; ++h) { - const size_t tq_idx = qbatch.Size() * task + qi; - // Find the token position in the query and calculate - // the range of cache positions to attend to. - const size_t pos = qbatch.Pos(qi) + task; - float* HWY_RESTRICT q_row = - q.Row(tq_idx) + h * layer.layer_config.qkv_dim; - // Apply rope and scaling to Q. 
- if (layer.query_norm_scale.HasPtr()) { - CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), q_row, - layer.layer_config.qkv_dim, ctx.profiler, worker); - }); - } - PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, - worker, pos, query_scale); + size_t qi = div_qbatch.Remainder(task); + size_t batch_idx = div_qbatch.Divide(task); + for (size_t h = 0; h < layer.layer_config.heads; ++h) { + const size_t tq_idx = qbatch.Size() * batch_idx + qi; + // Find the token position in the query and calculate + // the range of cache positions to attend to. + const size_t pos = qbatch.Pos(qi) + batch_idx; + float* HWY_RESTRICT q_row = + q.Row(tq_idx) + h * layer.layer_config.qkv_dim; + // Apply rope and scaling to Q. + if (layer.query_norm_scale.HasPtr()) { + CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), q_row, + layer.layer_config.qkv_dim, ctx.profiler, worker); + }); } + PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, + worker, pos, query_scale); } }; { // Full parallelism is helpful, SmallParallelFor is insufficient. - HierarchicalParallelFor(num_tokens, ctx.pools, func); + HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func); } } From fe5a39990edc012f7b1a90b428a292f9ef450a55 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 2 Oct 2025 02:36:29 -0700 Subject: [PATCH 55/65] Improve FlashAttention threading: kFlat for RMSNorm (hierarchical is excessive), profiler zone naming improvements. 
PiperOrigin-RevId: 814144012 --- gemma/attention.cc | 15 +++++++++++---- gemma/flash_attention.cc | 13 ++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/gemma/attention.cc b/gemma/attention.cc index f404674..576c0b7 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -252,7 +252,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, AttentionActivations& activations, const QBatch& qbatch, const int flags, MatMulEnv& env) { - PROFILER_ZONE("Gen.Attention.QKV"); + static const auto zone = env.ctx.profiler.AddZone( + "Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); + const hwy::Divisor div_qbatch(qbatch.Size()); const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor(); const LayerConfig& layer_config = layer.layer_config; @@ -325,7 +328,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, AttentionActivations& activations, MatMulEnv& env) { - PROFILER_ZONE("Gen.Attention.SumHeads"); + static const auto zone = env.ctx.profiler.AddZone( + "Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); const LayerConfig& layer_config = layer.layer_config; (void)layer_config; // For HWY_DASSERT // att_weights and att_out are concatenated heads, each of length @@ -358,8 +363,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); } else { - FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer, - activations, qbatch, env.ctx); + // * 2 does not help on Turin. 
+ FlashAttention(num_tokens, + /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1, + layer_idx, layer, activations, qbatch, env.ctx); } SumHeads(layer, activations, env); } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index ddf2bcc..33ad725 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -73,8 +73,9 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, } }; { - // Full parallelism is helpful, SmallParallelFor is insufficient. - HierarchicalParallelFor(q_t.Rows(), ctx.pools, func); + // Better than kFlat. + ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx, + /*cluster_idx=*/0, func); } } @@ -111,8 +112,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, } }; { - // Full parallelism is helpful, SmallParallelFor is insufficient. - HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func); + // kHierarchical is not worth the extra sync overhead because the tasks are + // very lightweight. + ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx, + /*cluster_idx=*/0, func); } } @@ -722,7 +725,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, }; { - PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); + PROFILER_ZONE("Gen.FlashAttention.ForkJoin"); // Full parallelism is helpful, SmallParallelFor is insufficient. 
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func); } From 14244664c8e962a420dee7d373a217fcbf7bd7f4 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 2 Oct 2025 05:16:03 -0700 Subject: [PATCH 56/65] Avoid transposing Q when it isn't needed PiperOrigin-RevId: 814187984 --- gemma/flash_attention.cc | 83 +++++++++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 33ad725..77a4480 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -51,6 +51,8 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +static constexpr size_t kNFx8HTileSize = 8; + // Transposes q into q_t. // Both are 4D tensors stuffed into a 2-D MatPtrT. // q has shape [batch, qbatch][head, qkv_dim]. @@ -191,7 +193,7 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { - constexpr size_t kHTileSize = 8; + constexpr size_t kHTileSize = kNFx8HTileSize; sum0 = hn::Zero(df); sum1 = hn::Zero(df); sum2 = hn::Zero(df); @@ -268,7 +270,7 @@ void TileFlashAttention( hwy::Profiler& p, const size_t worker) { static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention"); PROFILER_ZONE3(p, worker, zone); - constexpr int kHTileSize = 8; + constexpr int kHTileSize = kNFx8HTileSize; using DF = hn::ScalableTag; const DF df; using VF = hn::Vec; @@ -365,13 +367,12 @@ void TileFlashAttention( // Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision. // This is the result of 4 rows of Q against NF K timesteps, with positions -// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read -// in consecutive elements, and other columns by adding q_stride. +// given by k_offsets[0..NF]. 
template > -void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride, - const MatPtrT& k, const int32_t* HWY_RESTRICT k_offsets, - hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, - VF& sum2, VF& sum3) { +void QDotKTilex4(DF df, const float* HWY_RESTRICT q, + const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k, + const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p, + const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) { sum0 = hn::Zero(df); sum1 = hn::Zero(df); sum2 = hn::Zero(df); @@ -383,15 +384,14 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride, VI k_offsets_vec = hn::LoadU(di, k_offsets); for (size_t i = 0; i < k.Cols(); ++i) { VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); - VF q_0 = hn::Set(df, q[0]); + VF q_0 = hn::Set(df, q[q_offsets[0] + i]); sum0 = hn::MulAdd(q_0, k_vec, sum0); - VF q_1 = hn::Set(df, q[1]); + VF q_1 = hn::Set(df, q[q_offsets[1] + i]); sum1 = hn::MulAdd(q_1, k_vec, sum1); - VF q_2 = hn::Set(df, q[2]); + VF q_2 = hn::Set(df, q[q_offsets[2] + i]); sum2 = hn::MulAdd(q_2, k_vec, sum2); - VF q_3 = hn::Set(df, q[3]); + VF q_3 = hn::Set(df, q[q_offsets[3] + i]); sum3 = hn::MulAdd(q_3, k_vec, sum3); - q += q_stride; } } @@ -416,10 +416,9 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, // max_last_pos]. 
void TileFlashAttention4( const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, - const StridedView& qT, const MatPtrT& k, - const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, - const size_t min_last_pos, const size_t max_last_pos, - const MatPtrT& v, const size_t layer_idx, + const MatPtrT& k, const size_t start_pos, + const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, + const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, hwy::Profiler& p, const size_t worker) { @@ -445,8 +444,6 @@ void TileFlashAttention4( float old_d1 = 0.0f; float old_d2 = 0.0f; float old_d3 = 0.0f; - const float* HWY_RESTRICT qT_row = qT.Row(0); - const size_t qT_stride = qT.Stride(); size_t position = start_pos; while (position + kHTileSize - 1 <= min_last_pos) { int32_t k_offsets[kMaxNF]; @@ -456,7 +453,8 @@ void TileFlashAttention4( k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); } VF x0, x1, x2, x3; - QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3); + QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2, + x3); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. VF cap = hn::Set(df, activations.config.att_cap); @@ -608,12 +606,29 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism) ? kNF : std::min(kMinTileSize, kMaxEqualK); - // q has shape [batch, qbatch][head, qkv_dim]. - // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the - // maximum possible number of consecutive columns have the same KV matrices. - // Each thread will process a tile of NF columns of QT so the starting column - // index of QT is just the task index * kVTileSize. 
- TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx); + // Only transpose Q if we are using tiling. + if (kVTileSize == kNF) { + size_t max_last = 0, min_start = std::numeric_limits::max(); + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + size_t pos = qbatch.Pos(qi); + const size_t start = StartPos(pos, activations.config, layer_idx); + pos += num_tokens - 1; + const size_t end = qbatch.PrefixEnd(qi); + if (end > 0 && end - 1 > pos) { + pos = end - 1; + } + max_last = std::max(max_last, pos); + min_start = std::min(min_start, start); + } + if (max_last - min_start + 1 >= kNFx8HTileSize) { + // q has shape [batch, qbatch][head, qkv_dim]. + // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the + // maximum possible number of consecutive columns have the same KV + // matrices. Each thread will process a tile of NF columns of QT so the + // starting column index of QT is just the task index * kVTileSize. + TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx); + } + } const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize); const hwy::Divisor div_tokens(num_tokens); // All layers should have the same number of heads. @@ -699,17 +714,23 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, StridedView(activations.q_T.Row(0) + first_task, kVTileSize, activations.q_T.Stride()); if (kVTileSize == kNF) { + // We can still use TileFlashAttention even if we didn't transpose Q + // above. The condition used for transposing Q above is more general + // and easier to compute than the condition used within + // TileFlashAttention that min_last_pos - start_positions[offset] < + // kNFx8HTileSize. In this case, qT is never used. Some tasks might + // use qT and some might not, which is why the more general condition + // is used above to catch all cases where qT will be used. 
TileFlashAttention(activations.q, q_offsets, qT, k, start_positions[offset], last_pos, min_last_pos, max_last_pos, v, layer_idx, layer, activations, activations.att_out, out_offsets, ctx.profiler, worker); } else if (kVTileSize == 4) { - TileFlashAttention4(activations.q, q_offsets, qT, k, - start_positions[offset], last_pos, min_last_pos, - max_last_pos, v, layer_idx, layer, activations, - activations.att_out, out_offsets, ctx.profiler, - worker); + TileFlashAttention4( + activations.q, q_offsets, k, start_positions[offset], last_pos, + min_last_pos, max_last_pos, v, layer_idx, layer, activations, + activations.att_out, out_offsets, ctx.profiler, worker); } else { HWY_UNREACHABLE; } From 684a0444e9bb6ddba53c361d87def055235ca387 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 2 Oct 2025 08:14:37 -0700 Subject: [PATCH 57/65] Reduced parallelism for TransposeQ, making each thread read and write within its own cache lines PiperOrigin-RevId: 814241032 --- gemma/flash_attention.cc | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 77a4480..548c1aa 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -61,22 +61,30 @@ static constexpr size_t kNFx8HTileSize = 8; static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, const size_t qbatch_size, ThreadingContext& ctx) { static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ"); + // Group floats by the number of floats in a cache line. 
+ const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); const size_t num_heads = q.Cols() / q_t.Rows(); const size_t batch_size = q.Rows() / qbatch_size; const auto func = [&](const size_t task, size_t worker) HWY_ATTR { PROFILER_ZONE3(ctx.profiler, worker, zone); - float* HWY_RESTRICT qt_row = q_t.Row(task); - for (size_t qi = 0; qi < qbatch_size; ++qi) - for (size_t h = 0; h < num_heads; ++h) { - for (size_t b = 0; b < batch_size; ++b) { - qt_row[(qi * num_heads + h) * batch_size + b] = - q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task]; + for (size_t lane = 0; lane < kNF; ++lane) { + size_t q_row = task * kNF + lane; + if (q_row >= q_t.Rows()) break; + float* HWY_RESTRICT qt_row = q_t.Row(q_row); + for (size_t qi = 0; qi < qbatch_size; ++qi) { + for (size_t h = 0; h < num_heads; ++h) { + for (size_t b = 0; b < batch_size; ++b) { + qt_row[(qi * num_heads + h) * batch_size + b] = + q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]; + } } } + } }; { // Better than kFlat. - ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx, + size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); + ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx, /*cluster_idx=*/0, func); } } From 9dc802c7aac3048a93677e498e346002d25ead68 Mon Sep 17 00:00:00 2001 From: Nitin Gangahar Date: Mon, 6 Oct 2025 10:25:07 -0700 Subject: [PATCH 58/65] Add logging to io.cc on failed write and read. This should provide insights into any failures. PiperOrigin-RevId: 815784482 --- io/io.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/io/io.cc b/io/io.cc index d19ec10..9363b07 100644 --- a/io/io.cc +++ b/io/io.cc @@ -106,7 +106,13 @@ class FilePosix : public File { for (;;) { // pread seems to be faster than lseek + read when parallelized. 
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos); - if (bytes_read <= 0) break; + if (bytes_read <= 0) { + HWY_WARN( + "Read failure at pos %zu within size %zu with offset %zu and " + "errno %d\n", + pos, size, offset, errno); + break; + } pos += bytes_read; HWY_ASSERT(pos <= size); if (pos == size) break; @@ -120,7 +126,13 @@ class FilePosix : public File { for (;;) { const auto bytes_written = pwrite(fd_, bytes + pos, size - pos, offset + pos); - if (bytes_written <= 0) break; + if (bytes_written <= 0) { + HWY_WARN( + "Write failure at pos %zu within size %zu with offset %zu and " + "errno %d\n", + pos, size, offset, errno); + break; + } pos += bytes_written; HWY_ASSERT(pos <= size); if (pos == size) break; From 035273c18423bac6d30a21c79a73e3e78ad59940 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 7 Oct 2025 08:35:44 -0700 Subject: [PATCH 59/65] tune pool kSpin mode in threading_context Previously, this happened concurrently with the matmul autotune, which could lead to incorrect outcomes. threading: de-singleton Pinning (no longer stores affinity); pass PoolWorkerMapping; fix Pool dtor order Also enable SPR target (Zen4 is AMD-only), update Highway version for renamed Thread()->GlobalIdx(). 
PiperOrigin-RevId: 816223017 --- BUILD.bazel | 1 + CMakeLists.txt | 2 +- MODULE.bazel | 2 +- README.md | 2 +- compression/types.h | 7 +- evals/cross_entropy.cc | 2 +- examples/hello_world/CMakeLists.txt | 2 +- examples/simplified_gemma/CMakeLists.txt | 2 +- gemma/api_server.cc | 3 +- gemma/attention.cc | 6 +- gemma/gemma-inl.h | 2 +- gemma/gemma.cc | 2 +- gemma/vit.cc | 3 +- util/threading.cc | 141 ++++++++++------------- util/threading.h | 40 ++++++- util/threading_context.cc | 38 +++--- util/threading_context.h | 2 +- 17 files changed, 143 insertions(+), 114 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 74f472f..f482e56 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -114,6 +114,7 @@ cc_library( "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:profiler", + "@highway//:thread_pool", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 46242f6..5dc4e11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300 EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. This will only happen if diff --git a/MODULE.bazel b/MODULE.bazel index b6b5f78..e0ba1c7 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5") # Require a more recent version. 
git_override( module_name = "highway", - commit = "1d16731233de45a365b43867f27d0a5f73925300", + commit = "9781a1698ee0756ef1eaaf96930113ed7cb6d3ee", remote = "https://github.com/google/highway", ) diff --git a/README.md b/README.md index 2963bf6..722c2a8 100644 --- a/README.md +++ b/README.md @@ -452,7 +452,7 @@ FetchContent_MakeAvailable(sentencepiece) FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_MakeAvailable(gemma) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) FetchContent_MakeAvailable(highway) ``` diff --git a/compression/types.h b/compression/types.h index 661bc42..c3be52a 100644 --- a/compression/types.h +++ b/compression/types.h @@ -45,10 +45,11 @@ namespace gcpp { // as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most. #define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE) #elif HWY_ARCH_X86 -// Skip anything older than Haswell (2013); also use Zen4 for recent CPUs, -// because we do not use anything added by SPR (e.g. FP16) nor AVX 10.2. +// Skip anything older than Haswell (2013); use Zen4/SPR for recent CPUs. +// Although we do not use SPR's F16, Zen4 is only enabled for AMD. We do not +// yet use any AVX 10.2 features. 
#define GEMMA_DISABLED_TARGETS \ - (HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX3_SPR | HWY_AVX10_2) + (HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX10_2) #endif // HWY_ARCH_* #endif // GEMMA_DISABLED_TARGETS diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 49acb50..355f26d 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -84,7 +84,7 @@ namespace gcpp { namespace HWY_NAMESPACE { void CallSoftmax(Logits logits, hwy::Profiler& p) { - Softmax(logits, p, hwy::Profiler::Thread()); + Softmax(logits, p, hwy::Profiler::GlobalIdx()); } } // namespace HWY_NAMESPACE diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 7a63ace..65541d8 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_MakeAvailable(sentencepiece) diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt index da111cc..710f5ee 100644 --- a/examples/simplified_gemma/CMakeLists.txt +++ b/examples/simplified_gemma/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) 
FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) diff --git a/gemma/api_server.cc b/gemma/api_server.cc index ea5377d..f05447b 100644 --- a/gemma/api_server.cc +++ b/gemma/api_server.cc @@ -376,8 +376,7 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& // Ensure all data is sent sink.done(); - - return false; // End streaming + return false; // End streaming } catch (const std::exception& e) { json error_event = {{"error", {{"message", e.what()}}}}; diff --git a/gemma/attention.cc b/gemma/attention.cc index 576c0b7..a77021a 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -254,7 +254,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, MatMulEnv& env) { static const auto zone = env.ctx.profiler.AddZone( "Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); const hwy::Divisor div_qbatch(qbatch.Size()); const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor(); @@ -330,7 +330,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, MatMulEnv& env) { static const auto zone = env.ctx.profiler.AddZone( "Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); const LayerConfig& layer_config = layer.layer_config; (void)layer_config; // For HWY_DASSERT // att_weights and att_out are concatenated heads, each of length @@ -350,7 +350,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, MatMulEnv& env, int flags) { static const auto zone = env.ctx.profiler.AddZone("Gen.Attention", hwy::ProfilerFlags::kInclusive); - 
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); const LayerConfig& layer_config = layer.layer_config; HWY_DASSERT(!layer_config.IsMHA()); // No longer supported. diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 669d7e7..bdf989a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -155,7 +155,7 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, Activations& activations, MatMulEnv& env) { static const auto zone = env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); const LayerConfig& layer_config = layer.layer_config; HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 778ecc6..c3e2bac 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -139,7 +139,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, size_t image_token_position = 0) { static const auto zone = ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone); + PROFILER_ZONE3(ctx.profiler, hwy::Profiler::GlobalIdx(), zone); // Image tokens just need to be copied. if (model_config.wrapping == PromptWrapping::GEMMA_VLM && diff --git a/gemma/vit.cc b/gemma/vit.cc index 1910091..44b1bcb 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -335,7 +335,8 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, // Apply soft embedding norm before input projection. 
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), activations.x.Row(0), - vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread()); + vit_model_dim, env.ctx.profiler, + hwy::Profiler::GlobalIdx()); }); } diff --git a/util/threading.cc b/util/threading.cc index 1001f05..9c4cfe0 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -19,7 +19,6 @@ #include #include // std::sort -#include #include #include #include @@ -29,7 +28,6 @@ #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/topology.h" -#include "hwy/profiler.h" namespace gcpp { @@ -41,85 +39,60 @@ static void SortByDescendingSize(std::vector& groups) { [](const T& a, const T& b) { return a.Size() > b.Size(); }); } -// Singleton, holds the original process affinity and the pinning status. -class Pinning { - static bool InContainer() { - return false; } +static bool InContainer() { + return false; // placeholder for container detection, do not remove +} - public: - void SetPolicy(Tristate pin) { - if (pin == Tristate::kDefault) { - // Pinning is unreliable inside containers because the hypervisor might - // periodically change our affinity mask, or other processes might also - // pin themselves to the same LPs. - pin = InContainer() ? Tristate::kFalse : Tristate::kTrue; - } - want_pin_ = (pin == Tristate::kTrue); - any_error_.clear(); +PinningPolicy::PinningPolicy(Tristate pin) { + if (pin == Tristate::kDefault) { + // Pinning is unreliable inside containers because the hypervisor might + // periodically change our affinity mask, or other processes might also + // pin themselves to the same LPs. + pin = InContainer() ? Tristate::kFalse : Tristate::kTrue; } + want_pin_ = (pin == Tristate::kTrue); +} - // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`, - // and sets `any_error_` if any fails. 
- void MaybePin(const BoundedTopology& topology, size_t pkg_idx, - size_t cluster_idx, const BoundedTopology::Cluster& cluster, - hwy::ThreadPool& pool) { - const std::vector lps = cluster.LPVector(); - HWY_ASSERT(pool.NumWorkers() <= lps.size()); - pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each worker has one task +// If `pinning.Want()`, tries to pin each worker in `pool` to an LP in +// `cluster`, and calls `pinning.NotifyFailed()` if any fails. +void MaybePin(const BoundedTopology& topology, size_t pkg_idx, + size_t cluster_idx, const BoundedTopology::Cluster& cluster, + PinningPolicy& pinning, hwy::ThreadPool& pool) { + const std::vector lps = cluster.LPVector(); + HWY_ASSERT(pool.NumWorkers() <= lps.size()); + pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { + HWY_ASSERT(task == thread); // each worker has one task - char buf[16]; // Linux limitation - const int bytes_written = snprintf( - buf, sizeof(buf), "P%zu X%02zu C%03d", - topology.SkippedPackages() + pkg_idx, - topology.SkippedClusters() + cluster_idx, static_cast(task)); - HWY_ASSERT(bytes_written < static_cast(sizeof(buf))); - hwy::SetThreadName(buf, 0); // does not support varargs + char buf[16]; // Linux limitation + const int bytes_written = snprintf(buf, sizeof(buf), "P%zu X%02zu C%03d", + topology.SkippedPackages() + pkg_idx, + topology.SkippedClusters() + cluster_idx, + static_cast(task)); + HWY_ASSERT(bytes_written < static_cast(sizeof(buf))); + hwy::SetThreadName(buf, 0); // does not support varargs - if (HWY_LIKELY(want_pin_)) { - if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) { - // Apple does not support pinning, hence do not warn there. 
- if (!HWY_OS_APPLE) { - HWY_WARN("Pinning failed for task %d of %zu to %zu (size %zu)\n", - static_cast(task), pool.NumWorkers(), lps[task], - lps.size()); - } - (void)any_error_.test_and_set(); + if (HWY_LIKELY(pinning.Want())) { + if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) { + // Apple does not support pinning, hence do not warn there. + if (!HWY_OS_APPLE) { + HWY_WARN("Pinning failed for task %d of %zu to %zu (size %zu)\n", + static_cast(task), pool.NumWorkers(), lps[task], + lps.size()); } + pinning.NotifyFailed(); } - }); - } - - // Called ONCE after all MaybePin because it invalidates the error status. - bool AllPinned(const char** pin_string) { - // If !want_pin_, MaybePin will return without setting any_error_, but in - // that case we still want to return false to avoid spinning. - // .test() was only added in C++20, so we use .test_and_set() instead. - const bool all_pinned = want_pin_ && !any_error_.test_and_set(); - *pin_string = all_pinned ? "pinned" - : want_pin_ ? "pinning failed" - : "pinning skipped"; - return all_pinned; - } - - private: - std::atomic_flag any_error_ = ATOMIC_FLAG_INIT; - bool want_pin_; // set in SetPolicy -}; // Pinning - -// Singleton saves global affinity across all BoundedTopology instances because -// pinning overwrites it. -static Pinning& GetPinning() { - static Pinning pinning; - return pinning; + } + }); } static PoolPtr MakePool(const Allocator& allocator, size_t num_workers, + hwy::PoolWorkerMapping mapping, std::optional node = std::nullopt) { // `ThreadPool` expects the number of threads to create, which is one less // than the number of workers, but avoid underflow if zero. const size_t num_threads = num_workers == 0 ? 
0 : num_workers - 1; - PoolPtr ptr = allocator.AllocClasses(1, num_threads); + PoolPtr ptr = + allocator.AllocClasses(1, num_threads, mapping); const size_t bytes = hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes()); if (node.has_value() && allocator.ShouldBind()) { @@ -142,10 +115,11 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) { NestedPools::NestedPools(const BoundedTopology& topology, const Allocator& allocator, size_t max_threads, - Tristate pin) { - GetPinning().SetPolicy(pin); + Tristate pin) + : pinning_(pin) { packages_.resize(topology.NumPackages()); - all_packages_ = MakePool(allocator, packages_.size()); + all_packages_ = + MakePool(allocator, packages_.size(), hwy::PoolWorkerMapping()); const size_t max_workers_per_package = DivideMaxAcross(max_threads, packages_.size()); // Each worker in all_packages_, including the main thread, will be the @@ -153,11 +127,11 @@ NestedPools::NestedPools(const BoundedTopology& topology, // `cluster.lps` if `pin`. all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) { HWY_ASSERT(pkg_idx == thread); // each thread has one task - packages_[pkg_idx] = - Package(topology, allocator, pkg_idx, max_workers_per_package); + packages_[pkg_idx] = Package(topology, allocator, pinning_, pkg_idx, + max_workers_per_package); }); - all_pinned_ = GetPinning().AllPinned(&pin_string_); + all_pinned_ = pinning_.AllPinned(&pin_string_); // For mapping package/cluster/thread to noncontiguous TLS indices, in case // cluster/thread counts differ. @@ -172,8 +146,6 @@ NestedPools::NestedPools(const BoundedTopology& topology, HWY_ASSERT(max_clusters_per_package_ <= 64); HWY_ASSERT(max_workers_per_cluster_ >= 1); HWY_ASSERT(max_workers_per_cluster_ <= 256); - - hwy::Profiler::Get().SetMaxThreads(MaxWorkers()); } // `max_or_zero` == 0 means no limit. 
@@ -182,15 +154,22 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { } NestedPools::Package::Package(const BoundedTopology& topology, - const Allocator& allocator, size_t pkg_idx, + const Allocator& allocator, + PinningPolicy& pinning, size_t pkg_idx, size_t max_workers_per_package) { // Pre-allocate because elements are set concurrently. clusters_.resize(topology.NumClusters(pkg_idx)); const size_t max_workers_per_cluster = DivideMaxAcross(max_workers_per_package, clusters_.size()); - all_clusters_ = MakePool(allocator, clusters_.size(), - topology.GetCluster(pkg_idx, 0).Node()); + const BoundedTopology::Cluster& cluster0 = topology.GetCluster(pkg_idx, 0); + // Core 0 of each cluster. The second argument is the cluster size, not + // number of clusters. We ensure that it is the same for all clusters so that + // the `GlobalIdx` computation is consistent within and across clusters. + const hwy::PoolWorkerMapping all_clusters_mapping(hwy::kAllClusters, + cluster0.Size()); + all_clusters_ = MakePool(allocator, clusters_.size(), all_clusters_mapping, + cluster0.Node()); // Parallel so we also pin the calling worker in `all_clusters` to // `cluster.lps`. all_clusters_->Run( @@ -198,12 +177,14 @@ NestedPools::Package::Package(const BoundedTopology& topology, HWY_ASSERT(cluster_idx == thread); // each thread has one task const BoundedTopology::Cluster& cluster = topology.GetCluster(pkg_idx, cluster_idx); + HWY_ASSERT(cluster.Size() == cluster0.Size()); clusters_[cluster_idx] = MakePool( allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster), + hwy::PoolWorkerMapping(cluster_idx, cluster.Size()), cluster.Node()); // Pin workers AND the calling thread from `all_clusters`. 
- GetPinning().MaybePin(topology, pkg_idx, cluster_idx, cluster, - *clusters_[cluster_idx]); + MaybePin(topology, pkg_idx, cluster_idx, cluster, pinning, + *clusters_[cluster_idx]); }); } diff --git a/util/threading.h b/util/threading.h index 5dde114..53795be 100644 --- a/util/threading.h +++ b/util/threading.h @@ -19,6 +19,7 @@ #include #include +#include #include // IWYU pragma: begin_exports @@ -40,6 +41,30 @@ namespace gcpp { // moving because it is a typedef to `std::unique_ptr`. using PoolPtr = AlignedClassPtr; +class PinningPolicy { + public: + explicit PinningPolicy(Tristate pin); + + bool Want() const { return want_pin_; } + void NotifyFailed() { (void)any_error_.test_and_set(); } + + // Called ONCE after all MaybePin because it invalidates the error status. + bool AllPinned(const char** pin_string) { + // If !want_pin_, MaybePin will return without setting any_error_, but in + // that case we still want to return false to avoid spinning. + // .test() was only added in C++20, so we use .test_and_set() instead. + const bool all_pinned = want_pin_ && !any_error_.test_and_set(); + *pin_string = all_pinned ? "pinned" + : want_pin_ ? "pinning failed" + : "pinning skipped"; + return all_pinned; + } + + private: + std::atomic_flag any_error_ = ATOMIC_FLAG_INIT; + bool want_pin_; // set in SetPolicy +}; // PinningPolicy + // Creates a hierarchy of thread pools according to `BoundedTopology`: one with // a thread per enabled package; for each of those, one with a thread per // enabled cluster (CCX/shared L3), and for each of those, the remaining @@ -56,7 +81,12 @@ using PoolPtr = AlignedClassPtr; // Useful when there are tasks which should be parallelized by workers sharing a // cache, or on the same NUMA node. In both cases, individual pools have lower // barrier synchronization latency than one large pool. However, to utilize all -// cores, call sites will have to use nested parallel-for loops. 
+// cores, call sites will have to use nested parallel-for loops as in +// `HierarchicalParallelFor`. To allow switching modes easily, prefer using the +// `ParallelFor` abstraction in threading_context.h). +// +// Note that this was previously intended to use all cores, but we are now +// moving toward also allowing concurrent construction with subsets of cores. class NestedPools { public: // Neither move nor copy. @@ -151,7 +181,8 @@ class NestedPools { public: Package() = default; // for vector Package(const BoundedTopology& topology, const Allocator& allocator, - size_t pkg_idx, size_t max_workers_per_package); + PinningPolicy& pinning, size_t pkg_idx, + size_t max_workers_per_package); size_t NumClusters() const { return clusters_.size(); } size_t MaxWorkersPerCluster() const { @@ -184,8 +215,10 @@ class NestedPools { } private: - std::vector clusters_; + // Must be freed after `clusters_` because it reserves threads which are + // the main threads of `clusters_`. PoolPtr all_clusters_; + std::vector clusters_; }; // Package void SetWaitMode(hwy::PoolWaitMode wait_mode) { @@ -195,6 +228,7 @@ class NestedPools { } } + PinningPolicy pinning_; bool all_pinned_; const char* pin_string_; diff --git a/util/threading_context.cc b/util/threading_context.cc index 90a64d1..8ffd4db 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -21,6 +21,7 @@ #include #include "hwy/aligned_allocator.h" +#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" #include "hwy/tests/test_util.h" // RandomState @@ -28,7 +29,11 @@ namespace gcpp { // Invokes `pool.Run` with varying task counts until auto-tuning completes, or // an upper bound just in case. -static void TunePool(hwy::ThreadPool& pool) { +static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) { + pool.SetWaitMode(wait_mode); + +// TODO(janwas): re-enable after investigating potential deadlock. 
+#if 0 const size_t num_workers = pool.NumWorkers(); // pool.Run would just be a serial loop without auto-tuning, so skip. if (num_workers == 1) return; @@ -69,6 +74,22 @@ static void TunePool(hwy::ThreadPool& pool) { HWY_ASSERT(total == prev_total + expected); prev_total += expected; } +#endif +} + +static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) { + TunePool(wait_mode, pools.AllPackages()); + for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) { + hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx); + TunePool(wait_mode, clusters); + + // Run in parallel because Turin CPUs have 16, and in real usage, we often + // run all at the same time. + clusters.Run(0, clusters.NumWorkers(), + [&](uint64_t cluster_idx, size_t /*thread*/) { + TunePool(wait_mode, pools.Cluster(pkg_idx, cluster_idx)); + }); + } } ThreadingContext::ThreadingContext(const ThreadingArgs& args) @@ -80,18 +101,9 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args) allocator(topology, cache_info, args.bind != Tristate::kFalse), pools(topology, allocator, args.max_threads, args.pin) { PROFILER_ZONE("Startup.ThreadingContext autotune"); - TunePool(pools.AllPackages()); - for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) { - hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx); - TunePool(clusters); - - // Run in parallel because Turin CPUs have 16, and in real usage, we often - // run all at the same time. - clusters.Run(0, clusters.NumWorkers(), - [&](uint64_t cluster_idx, size_t /*thread*/) { - TunePool(pools.Cluster(pkg_idx, cluster_idx)); - }); - } + TunePools(hwy::PoolWaitMode::kSpin, pools); + // kBlock is the default, hence set/tune it last. 
+ TunePools(hwy::PoolWaitMode::kBlock, pools); } } // namespace gcpp diff --git a/util/threading_context.h b/util/threading_context.h index ac42526..ff4ff62 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -41,7 +41,7 @@ class ThreadingArgs : public ArgsBase { // For BoundedTopology: size_t skip_packages; - size_t max_packages = 1; + size_t max_packages = 1; // some users assign 1 to this, hence non-const. size_t skip_clusters; size_t max_clusters; size_t skip_lps; From fb6fa793f46e22249bf3c3c0bda36d11581eeef6 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Tue, 14 Oct 2025 08:30:23 -0700 Subject: [PATCH 60/65] Added a global (to gemma) zones list to enable most call sites to PROFILER_ZONE3 to avoid the sychronization required for the static const initialization of the zone handle. Improved flash_attention to enable profiling using the new zones. PiperOrigin-RevId: 819235421 --- BUILD.bazel | 16 ++++++++ CMakeLists.txt | 2 + gemma/attention.cc | 11 ++++-- gemma/flash_attention.cc | 52 +++++++++++++++----------- gemma/flash_attention.h | 3 ++ gemma/flash_attention_test.cc | 11 +++++- gemma/gemma-inl.h | 7 ++-- gemma/gemma.cc | 10 ++--- gemma/weights.cc | 6 +-- ops/matmul-inl.h | 25 +++++++------ ops/ops-inl.h | 43 ++++++++------------- ops/ops_test.cc | 10 +++++ util/zones.cc | 70 +++++++++++++++++++++++++++++++++++ util/zones.h | 58 +++++++++++++++++++++++++++++ 14 files changed, 247 insertions(+), 77 deletions(-) create mode 100644 util/zones.cc create mode 100644 util/zones.h diff --git a/BUILD.bazel b/BUILD.bazel index f482e56..f5fad45 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -111,6 +111,7 @@ cc_library( ":basics", ":threading", ":topology", + ":zones", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:profiler", @@ -118,6 +119,15 @@ cc_library( ], ) +cc_library( + name = "zones", + srcs = ["util/zones.cc"], + hdrs = ["util/zones.h"], + deps = [ + "@highway//:profiler", + ], +) + cc_test( name = "flash_attention_test", srcs = 
["gemma/flash_attention_test.cc"], @@ -263,6 +273,7 @@ cc_library( ":model_store", ":tensor_info", ":threading_context", + ":zones", "//compression:compress", "//io:blob_store", "@highway//:hwy", @@ -321,6 +332,7 @@ cc_library( ":matmul_env", ":threading", ":threading_context", + ":zones", "//compression:compress", "@highway//:bit_set", "@highway//:hwy", @@ -352,6 +364,7 @@ cc_library( ":matmul", ":matmul_env", ":threading_context", + ":zones", "//compression:compress", "//compression:types", "@highway//:hwy", @@ -376,6 +389,7 @@ cc_library( ":matmul_env", # MMOptions ":matmul_static", ":threading_context", + ":zones", "//compression:compress", "@highway//:algo", "@highway//:bit_set", @@ -431,6 +445,7 @@ cc_test( ":ops", ":test_util", ":threading_context", + ":zones", "@googletest//:gtest_main", # buildcleaner: keep "//compression:test_util", "//compression:types", @@ -556,6 +571,7 @@ cc_library( ":threading", ":threading_context", ":weights", + ":zones", "//compression:compress", "//compression:types", "//io", diff --git a/CMakeLists.txt b/CMakeLists.txt index 5dc4e11..983d643 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -130,6 +130,8 @@ set(SOURCES util/threading.h util/topology.cc util/topology.h + util/zones.cc + util/zones.h ) # Add C API sources only when building DLL diff --git a/gemma/attention.cc b/gemma/attention.cc index a77021a..8950bc2 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -19,6 +19,7 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -55,8 +56,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, const float* HWY_RESTRICT q, const MatPtrT& k, float* HWY_RESTRICT att, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Gen.Attention.QDotK"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, 
GetProfilerZone(Zones::kGenAttentionQDotK)); if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { // Slightly faster: no wraparound. for (size_t pos = start_pos; pos <= last_pos; ++pos) { @@ -175,7 +175,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, ThreadingContext& ctx) { - static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par"); + static const auto root_zone = + ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive", + hwy::ProfilerFlags::kInclusive); + PROFILER_ZONE3(ctx.profiler, 0, root_zone); + const auto zone = + GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar); const hwy::Divisor div_qbatch(qbatch.Size()); const LayerConfig& layer_config = layer.layer_config; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 548c1aa..cfadf28 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -22,6 +22,7 @@ #include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "util/threading_context.h" +#include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -60,7 +61,7 @@ static constexpr size_t kNFx8HTileSize = 8; // possible consecutive elements have the same KV. static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, const size_t qbatch_size, ThreadingContext& ctx) { - static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ"); + const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ); // Group floats by the number of floats in a cache line. 
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); const size_t num_heads = q.Cols() / q_t.Rows(); @@ -95,8 +96,8 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, const LayerWeightsPtrs& layer, const AttentionActivations& activations, ThreadingContext& ctx) { - static const auto zone = - ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding"); + const auto zone = + GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding); const float query_scale = activations.query_scale; const hwy::Divisor div_qbatch(qbatch.Size()); const auto func = [&](const size_t task, size_t worker) HWY_ATTR { @@ -158,8 +159,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, const AttentionActivations& activations, float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Gen.Attention.SingleFlashAttention"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, + GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention)); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); float m = Dot(q, k.Row(pos_mod), k.Cols()); if (float cap = activations.config.att_cap; cap > 0.0f) { @@ -276,8 +277,8 @@ void TileFlashAttention( const LayerWeightsPtrs& layer, const AttentionActivations& activations, MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, + GetProfilerZone(Zones::kFlashAttentionTileFlashAttention)); constexpr int kHTileSize = kNFx8HTileSize; using DF = hn::ScalableTag; const DF df; @@ -430,8 +431,8 @@ void TileFlashAttention4( const LayerWeightsPtrs& layer, const AttentionActivations& activations, MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, hwy::Profiler& p, const size_t worker) { - static const auto zone = 
p.AddZone("Gen.Attention.TileFlashAttention4"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, + GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4)); using DF = hn::ScalableTag; const DF df; using VF = hn::Vec; @@ -524,6 +525,21 @@ static size_t RoundToSuitablePowerOf2(size_t n) { return 32; } +// The vertical tile size is determined by the ability to use tiling and the +// target_parallelism. In practice the possible tile sizes in order of +// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or +// 16. The final tile size is chosen to be the largest possible that allows +// for target_parallelism parallel tasks. +size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, + size_t total_tasks, size_t target_parallelism) { + const size_t kMaxEqualK = + RoundToSuitablePowerOf2(num_head_groups * num_tokens); + const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1; + return (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism) + ? kNF + : std::min(kMinTileSize, kMaxEqualK); +} + // The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D] // into a single output O[L,D]. // Conventional attention first computes A[L,L] = Q . 
KT @@ -582,7 +598,10 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, ThreadingContext& ctx) { - static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention"); + static const auto root_zone = ctx.profiler.AddZone( + "FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive); + PROFILER_ZONE3(ctx.profiler, 0, root_zone); + const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention); RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, layer, activations, ctx); const hwy::Divisor div_qbatch(qbatch.Size()); @@ -603,17 +622,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, const size_t kNF = hn::Lanes(df); constexpr size_t kMaxNF = hn::MaxLanes(df); HWY_DASSERT(kNF <= kMaxNF); - // The vertical tile size is determined by the ability to use tiling and the - // target_parallelism. In practice the possible tile sizes in order of - // preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or - // 16. The final tile size is chosen to be the largest possible that allows - // for target_parallelism parallel tasks. - const size_t kMaxEqualK = RoundToSuitablePowerOf2(kHeadGroups * num_tokens); - const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1; - const size_t kVTileSize = - (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism) - ? kNF - : std::min(kMinTileSize, kMaxEqualK); + const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, + total_tasks, target_parallelism); // Only transpose Q if we are using tiling. 
if (kVTileSize == kNF) { size_t max_last = 0, min_start = std::numeric_limits::max(); diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 75e087a..8aa787b 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -42,6 +42,9 @@ namespace gcpp { float* HWY_RESTRICT att_out, hwy::Profiler& p, \ size_t worker); \ \ + size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ + size_t total_tasks, size_t target_parallelism); \ + \ void FlashAttention(size_t num_tokens, size_t target_parallelism, \ size_t layer_idx, const LayerWeightsPtrs& layer, \ AttentionActivations& activations, QBatch& qbatch, \ diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 7f8f31e..4147e38 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -101,7 +101,6 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { void TestFlashAttention(size_t target_parallelism) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - // hwy::ThreadPool& pool = ctx.pools.Pool(); constexpr size_t kOuter = 1024; constexpr size_t kInner = 256; ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); @@ -150,9 +149,19 @@ void TestFlashAttention(size_t target_parallelism) { // Copy the output to saved_att to allow for comparison. 
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); SetMat(1, attention.q); + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); + const size_t total_tasks = + tokens.size() * div_qbatch.GetDivisor() * layer_config.heads; + const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(), + total_tasks, target_parallelism); + printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n", + target_parallelism, kNF, kVTileSize); FlashAttention(tokens.size(), target_parallelism, 0, layers, attention, qbatch, ctx); AssertClose(attention.att_out, *saved_att); + ctx.profiler.PrintResults(); } void TestAttention() { diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index bdf989a..ecfbe47 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -24,6 +24,7 @@ #include "ops/matmul.h" #include "util/mat.h" #include "util/threading.h" +#include "util/zones.h" #include "hwy/profiler.h" // Include guard (still compiled once per target) @@ -48,8 +49,7 @@ template void Activation(ActivationType activation, T1* HWY_RESTRICT c1, const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Gen.Activation"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -88,8 +88,7 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1, const IndexRange range_r, const IndexRange range_c, const StridedViewBF C2, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Gen.ActivationFused"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused)); const size_t cols = range_c.Num(); HWY_DASSERT(C2.Cols() == cols); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index c3e2bac..78c9cc4 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -19,6 
+19,7 @@ #include "gemma/gemma.h" #include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -466,14 +467,12 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; - static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1"); - static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general"); - // Fast path for top-1 with no accept_token. if (runtime_config.top_k == 1 && !runtime_config.accept_token) { return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE3(ctx.profiler, worker, zone_top1); + PROFILER_ZONE3(ctx.profiler, worker, + GetProfilerZone(Zones::kGenSampleTop1)); return Top1OfSoftmax(logits); }; } @@ -481,7 +480,8 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, // General case: Softmax with top-k sampling. return [&](size_t qi, size_t pos, Logits logits, size_t worker) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE3(ctx.profiler, worker, zone_topK); + PROFILER_ZONE3(ctx.profiler, worker, + GetProfilerZone(Zones::kGenSampleTopK)); // We want a different sequence for each batch element and position. 
const uint64_t stream = (static_cast(qi) << 32) | pos; RngStream gen(engine, stream); diff --git a/gemma/weights.cc b/gemma/weights.cc index fb59297..cd8875b 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -32,6 +32,7 @@ #include "io/blob_store.h" #include "util/mat.h" #include "util/threading_context.h" +#include "util/zones.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -379,8 +380,7 @@ static void DecompressToBF16(MatPtr& mat, static void ReadAllToBF16(const std::vector& tensors, const BlobReader& reader, ThreadingContext& ctx) { - static const auto zone = - ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16"); + const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16); // Especially TSAN is slow enough to warrant hierarchical parallelism. const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD ? ParallelismStrategy::kHierarchical @@ -463,7 +463,7 @@ static std::vector MakeBatches( static void ReadBatches(const BlobReader& reader, const std::vector& batches, ThreadingContext& ctx) { - static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches"); + const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches); // >5x speedup from parallel reads when cached. 
ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx, /*cluster_idx=*/0, diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 8957f4c..d72ac38 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -20,11 +20,12 @@ #include #include "compression/types.h" -#include "ops/matmul.h" // IWYU pragma: export +#include "ops/matmul.h" // IWYU pragma: export #include "util/allocator.h" // CacheInfo #include "util/basics.h" #include "util/mat.h" #include "util/threading_context.h" +#include "util/zones.h" #include "hwy/base.h" #include "hwy/profiler.h" #include "hwy/timer.h" @@ -290,7 +291,7 @@ class MMDecompress { const hn::ScalableTag dbf; const size_t NBF = hn::Lanes(dbf); - static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA"); + const auto zone = GetProfilerZone(Zones::kMMDecompressA); const auto do_range = [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) @@ -878,9 +879,9 @@ class MMLoops { static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch"); PROFILER_ZONE3(args.env.ctx.profiler, - args.env.ctx.Worker(args.options.cluster_idx), zone); + args.env.ctx.Worker(args.options.cluster_idx), + GetProfilerZone(Zones::kMMDispatch)); DispatchParallelism( args.options.parallelism, [&](const auto& parallel) HWY_ATTR { @@ -903,7 +904,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - static const auto zone = args.env.ctx.profiler.AddZone("MM.NT"); + const auto zone = GetProfilerZone(Zones::kMMNT); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); const IndexRange& range_mc = args.ranges_mc.Range(0); @@ -939,7 +940,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - static const auto zone = 
args.env.ctx.profiler.AddZone("MM.NT_K"); + const auto zone = GetProfilerZone(Zones::kMMNT_K); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); const IndexRange& range_mc = args.ranges_mc.Range(0); @@ -975,7 +976,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT"); + const auto zone = GetProfilerZone(Zones::kMMNT_MT); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); const IndexRange& range_kc = args.ranges_kc.Range(0); @@ -1009,7 +1010,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K"); + const auto zone = GetProfilerZone(Zones::kMMNT_MT_K); parallel.ForRangesMC_NC( args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, @@ -1060,10 +1061,10 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C, MMOptions options = MMOptions()) { - static const auto zone = env.ctx.profiler.AddZone("MM.MatMul"); const size_t cluster_idx = options.cluster_idx; HWY_DASSERT(cluster_idx < env.row_ptrs.size()); - PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone); + PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), + GetProfilerZone(Zones::kMMMatMul)); RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); @@ -1121,10 +1122,10 @@ template HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT& A, const MatPtrT& B1, const MatPtrT& B2, MatMulEnv& env, MatPtrT& C, MMOptions options) { - static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul"); const size_t cluster_idx = options.cluster_idx; HWY_DASSERT(cluster_idx < env.row_ptrs.size()); - PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone); + PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), + 
GetProfilerZone(Zones::kMMTwoMatMul)); HWY_DASSERT(options.func != nullptr); // no other way to get access to C2. diff --git a/ops/ops-inl.h b/ops/ops-inl.h index a52c788..162b48a 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -32,6 +32,7 @@ #include "util/basics.h" // TokenAndProb, RngStream #include "util/mat.h" #include "util/threading_context.h" +#include "util/zones.h" #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/contrib/sort/order.h" @@ -206,8 +207,7 @@ namespace detail { template float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.RMSNormMul"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormMul)); const hn::ScalableTag d; const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); @@ -223,8 +223,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.RMSNorm"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNorm)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -248,8 +247,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.RMSNormInplace"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -365,8 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( float* HWY_RESTRICT x, const size_t dim_qkv, const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.Rope"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, 
GetProfilerZone(Zones::kOpsRope)); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -425,8 +422,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( const float mul, float* HWY_RESTRICT x, const size_t dim_qkv, const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.RopeAndMulBy"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRopeAndMulBy)); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -488,8 +484,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.AddFrom"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsAddFrom)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -568,8 +563,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.MulByConst"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -587,8 +581,7 @@ template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.MulByConstTo"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstTo)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -606,8 +599,7 @@ template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const size_t size, hwy::Profiler& p, const size_t 
worker) { - static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -744,8 +736,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile)); namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -1007,8 +998,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4( const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.MulByConstAndAddTile4"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4)); namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -1049,8 +1039,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( const size_t pos, float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets, const size_t size, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector)); namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -1146,8 +1135,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, const size_t worker, float temperature = 1.0f) { - static const auto zone = p.AddZone("Ops.Softmax"); - 
PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsSoftmax)); HWY_DASSERT(logits.size() != 0); namespace hn = hwy::HWY_NAMESPACE; @@ -1280,8 +1268,7 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) { static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits, hwy::Profiler& p, const size_t worker) { - static const auto zone = p.AddZone("Ops.LogitsSoftCap"); - PROFILER_ZONE3(p, worker, zone); + PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsLogitsSoftCap)); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 213fdd0..40f1002 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -14,6 +14,7 @@ // limitations under the License. #include "compression/types.h" +#include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -132,6 +133,7 @@ class TestAddFrom { } SimpleAddFrom(o, e, count); + InitProfilerZones(hwy::Profiler::Get()); AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, @@ -180,6 +182,7 @@ class TestMulByConstAndAdd { T constant = Random(rng); SimpleMulByConstAndAdd(constant, o, e, count); + InitProfilerZones(hwy::Profiler::Get()); MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, @@ -228,6 +231,7 @@ class TestMulByConst { T constant = Random(rng); SimpleMulByConst(constant, e, count); + InitProfilerZones(hwy::Profiler::Get()); MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, @@ -274,6 +278,7 @@ struct TestMulByConstTo { hwy::ConvertScalarTo(constant)); } + InitProfilerZones(hwy::Profiler::Get()); MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(), /*worker=*/0); 
@@ -310,6 +315,7 @@ class TestSoftmax { } SimpleSoftmax(e, count); + InitProfilerZones(hwy::Profiler::Get()); Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0); T sum = 0.0f; @@ -437,6 +443,7 @@ void TestRopeAndMulBy() { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); hwy::Profiler& p = ctx.profiler; + InitProfilerZones(p); const size_t worker = 0; const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, @@ -551,6 +558,7 @@ struct TestRMSNorm { } ScalarRMSNorm(vec, weight, expected, kSize); + InitProfilerZones(hwy::Profiler::Get()); RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); for (size_t i = 0; i < kSize; i++) { @@ -585,6 +593,7 @@ struct TestRMSNormInplace { } ScalarRMSNorm(expected, weight, expected, kSize); + InitProfilerZones(hwy::Profiler::Get()); RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); @@ -707,6 +716,7 @@ void TestAllLayerNorm() { void TestSampleTopK() { hwy::Profiler& p = hwy::Profiler::Get(); + InitProfilerZones(p); const size_t worker = 0; const size_t kSize = 52; std::vector logits_vec(kSize); diff --git a/util/zones.cc b/util/zones.cc new file mode 100644 index 0000000..abc9dc2 --- /dev/null +++ b/util/zones.cc @@ -0,0 +1,70 @@ +#include "util/zones.h" + +#include "hwy/profiler.h" + +namespace gcpp { + +#if PROFILER_ENABLED +static constexpr size_t kNumZones = static_cast(Zones::kNumZones); + +static const char* kProfilerZoneNames[kNumZones] = { + // Keep in sync with Zones enum. 
+ "Ops.RMSNormMul", + "Ops.RMSNorm", + "Ops.RMSNormInplace", + "Ops.Rope", + "Ops.RopeAndMulBy", + "Ops.AddFrom", + "Ops.MulByConst", + "Ops.MulByConstTo", + "Ops.MulByConstAndAdd", + "Ops.MulByConstAndAddTile", + "Ops.MulByConstAndAddTile4", + "Ops.MulByConstAndAddVector", + "Ops.Softmax", + "Ops.LogitsSoftCap", + "FlashAttention.TransposeQ", + "FlashAttention.RMSNormAndPositionalEncoding", + "FlashAttention.SingleFlashAttention", + "FlashAttention.TileFlashAttention", + "FlashAttention.TileFlashAttention4", + "FlashAttention.FlashAttention", + "Gen.Activation", + "Gen.ActivationFused", + "Gen.SampleTop1", + "Gen.SampleTopK", + "Gen.Attention.QDotK", + "Gen.Attention.DotSoftmaxWeightedSum.par", + "Startup.Weights.ReadAllToBF16", + "Startup.Weights.ReadBatches", + "MM.Dispatch", + "MM.MatMul", + "MM.TwoMatMul", + "MM.DecompressA", + "MM.NT", + "MM.NT_K", + "MM.NT_MT", + "MM.NT_MT_K", +}; + +static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones]; +#endif + +void InitProfilerZones(hwy::Profiler& profiler) { +#if PROFILER_ENABLED + // Initialize the zone handles. This is done once at startup. + for (size_t i = 0; i < kNumZones; ++i) { + profiler_zone_handles[i] = profiler.AddZone(kProfilerZoneNames[i]); + } +#endif +} + +hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) { +#if PROFILER_ENABLED + return profiler_zone_handles[static_cast(zone)]; +#else + return hwy::profiler::ZoneHandle(); +#endif +} + +} // namespace gcpp diff --git a/util/zones.h b/util/zones.h new file mode 100644 index 0000000..e78340a --- /dev/null +++ b/util/zones.h @@ -0,0 +1,58 @@ +#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_ +#define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_ + +#include "hwy/profiler.h" + +namespace gcpp { + +// Zones for the profiler. 
+enum class Zones { + kOpsRmsNormMul, + kOpsRmsNorm, + kOpsRmsNormInplace, + kOpsRope, + kOpsRopeAndMulBy, + kOpsAddFrom, + kOpsMulByConst, + kOpsMulByConstTo, + kOpsMulByConstAndAdd, + kOpsMulByConstAndAddTile, + kOpsMulByConstAndAddTile4, + kOpsMulByConstAndAddVector, + kOpsSoftmax, + kOpsLogitsSoftCap, + kFlashAttentionTransposeQ, + kFlashAttentionRmsNormAndPositionalEncoding, + kFlashAttentionSingleFlashAttention, + kFlashAttentionTileFlashAttention, + kFlashAttentionTileFlashAttention4, + kFlashAttentionFlashAttention, + kGenActivation, + kGenActivationFused, + kGenSampleTop1, + kGenSampleTopK, + kGenAttentionQDotK, + kGenAttentionDotSoftmaxWeightedSumPar, + kStartupWeightsReadAllToBF16, + kStartupWeightsReadBatches, + kMMDispatch, + kMMMatMul, + kMMTwoMatMul, + kMMDecompressA, + kMMNT, + kMMNT_K, + kMMNT_MT, + kMMNT_MT_K, + kNumZones +}; + +// Initializes the profiler zones. Must be called before any other profiler +// functions. +void InitProfilerZones(hwy::Profiler& profiler); + +// Returns the zone handle for the given zone enum value. +hwy::profiler::ZoneHandle GetProfilerZone(Zones zone); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_ From e3e8511e794a944fb0edabb6dfd22088d682e365 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Wed, 15 Oct 2025 03:05:30 -0700 Subject: [PATCH 61/65] Initialization of profiler zones. 
PiperOrigin-RevId: 819662587 --- util/threading_context.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/util/threading_context.cc b/util/threading_context.cc index 8ffd4db..e2c4d03 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -20,6 +20,7 @@ #include +#include "util/zones.h" #include "hwy/aligned_allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" @@ -100,6 +101,7 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args) cache_info(topology), allocator(topology, cache_info, args.bind != Tristate::kFalse), pools(topology, allocator, args.max_threads, args.pin) { + InitProfilerZones(profiler); PROFILER_ZONE("Startup.ThreadingContext autotune"); TunePools(hwy::PoolWaitMode::kSpin, pools); // kBlock is the default, hence set/tune it last. From ee18916abffdcbc08fecb7c37a3d4bdc38a4bc80 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Wed, 15 Oct 2025 07:09:32 -0700 Subject: [PATCH 62/65] Removed the PROFILER_ZONE from the most highly called functions to reduce the overhead. 
PiperOrigin-RevId: 819739402 --- gemma/attention.cc | 7 +++--- gemma/flash_attention.cc | 49 ++++++++++++++++++++-------------------- gemma/gemma.cc | 4 +--- gemma/vit.cc | 10 ++++---- ops/ops-inl.h | 27 ++++++++-------------- ops/ops_test.cc | 5 ++-- 6 files changed, 43 insertions(+), 59 deletions(-) diff --git a/gemma/attention.cc b/gemma/attention.cc index 8950bc2..bf39702 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -89,7 +89,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx, // PostQKType::Rope if (post_qk == PostQKType::HalfRope) { Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker); - if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker); + if (mul != 1.0f) MulByConst(mul, qk, qkv_dim); } else { RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker); } @@ -113,7 +113,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos, MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p, worker); for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { - MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker); + MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); } } else { { @@ -122,8 +122,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos, } for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = div_seq_len.Remainder(pos); - MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, - worker); + MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols()); } } } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index cfadf28..df6efd1 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -131,10 +131,11 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, } // Handles a single v row of flash attention for a single q.k dot product. 
-void HWY_INLINE SingleFlashAttentionStep( - float x, float cap, float& old_max, float& old_d, - const float* HWY_RESTRICT v, const size_t v_cols, - float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { +void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max, + float& old_d, + const float* HWY_RESTRICT v, + const size_t v_cols, + float* HWY_RESTRICT att_out) { if (cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. x = cap * std::tanh(x / cap); @@ -147,8 +148,8 @@ void HWY_INLINE SingleFlashAttentionStep( float one_over_d = 1.0f / old_d; scale *= one_over_d; x *= one_over_d; - MulByConst(scale, att_out, v_cols, p, worker); - MulByConstAndAdd(x, v, att_out, v_cols, p, worker); + MulByConst(scale, att_out, v_cols); + MulByConstAndAdd(x, v, att_out, v_cols); } // Calculates the complete attention outputs for a single row of q. @@ -174,7 +175,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, const size_t pos_mod = activations.div_seq_len.Remainder(pos); float x = Dot(q, k.Row(pos_mod), k.Cols()); SingleFlashAttentionStep(x, activations.config.att_cap, m, d, - v.Row(pos_mod), v.Cols(), att_out, p, worker); + v.Row(pos_mod), v.Cols(), att_out); } } @@ -183,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, template > VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, const size_t k_pos, const MatPtrT& q, - const MatPtrT& k, hwy::Profiler& p, const size_t worker) { + const MatPtrT& k) { hn::TFromD results[hn::MaxLanes(df)]; for (size_t i = 0; i < hn::Lanes(df); ++i) { results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols()); @@ -198,9 +199,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, // consecutive elements, and other columns by adding q_stride. 
template > void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, - const MatPtrT& k, const size_t* k_pos, - hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, - VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, + const MatPtrT& k, const size_t* k_pos, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { constexpr size_t kHTileSize = kNFx8HTileSize; sum0 = hn::Zero(df); @@ -303,8 +303,8 @@ void TileFlashAttention( k_pos[i] = activations.div_seq_len.Remainder(position + i); } VF x0, x1, x2, x3, x4, x5, x6, x7; - QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3, - x4, x5, x6, x7); + QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, + x7); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. VF cap = hn::Set(df, activations.config.att_cap); @@ -343,12 +343,12 @@ void TileFlashAttention( x6 = hn::Mul(x6, one_over_d); x7 = hn::Mul(x7, one_over_d); MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos, - att_out.Row(0), out_offsets, v.Cols(), p, worker); + att_out.Row(0), out_offsets, v.Cols()); position += kHTileSize; } while (position <= max_last_pos) { size_t k_pos = activations.div_seq_len.Remainder(position); - VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker); + VF x0 = QDotKVector(df, q_offsets, k_pos, q, k); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector. 
VF cap = hn::Set(df, activations.config.att_cap); @@ -369,7 +369,7 @@ void TileFlashAttention( x0 = hn::Mul(x0, one_over_d); scale = hn::Mul(scale, one_over_d); MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets, - v.Cols(), p, worker); + v.Cols()); ++position; } } @@ -380,8 +380,8 @@ void TileFlashAttention( template > void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k, - const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p, - const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) { + const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1, + VF& sum2, VF& sum3) { sum0 = hn::Zero(df); sum1 = hn::Zero(df); sum2 = hn::Zero(df); @@ -462,8 +462,7 @@ void TileFlashAttention4( k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); } VF x0, x1, x2, x3; - QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2, - x3); + QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. VF cap = hn::Set(df, activations.config.att_cap); @@ -478,7 +477,7 @@ void TileFlashAttention4( scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), - out_offsets, v.Cols(), p, worker); + out_offsets, v.Cols()); position += kHTileSize; } while (position <= max_last_pos) { @@ -488,28 +487,28 @@ void TileFlashAttention4( float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[0], p, worker); + att_out.Row(0) + out_offsets[0]); } if (position <= last_pos[1]) { // Past the last position, x1 doesn't count. 
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[1], p, worker); + att_out.Row(0) + out_offsets[1]); } if (position <= last_pos[2]) { // Past the last position, x2 doesn't count. float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[2], p, worker); + att_out.Row(0) + out_offsets[2]); } if (position <= last_pos[3]) { // Past the last position, x3 doesn't count. float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[3], p, worker); + att_out.Row(0) + out_offsets[3]); } ++position; } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 78c9cc4..80bf9e2 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -160,7 +160,6 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, const size_t model_dim = model_config.model_dim; const float emb_scaling = EmbeddingScaling(model_dim); - const size_t worker = 0; // Not yet parallelized. 
HWY_DASSERT(token >= 0); HWY_DASSERT(token < static_cast(model_config.vocab_size)); @@ -176,8 +175,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, const hn::ScalableTag df; DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row), model_dim); - MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim, - ctx.profiler, worker); + MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim); }); if (model_config.absolute_pe) { diff --git a/gemma/vit.cc b/gemma/vit.cc index 44b1bcb..d21be16 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -95,7 +95,7 @@ class VitAttention { float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; // TODO: shift to MatMul with A.scale once MatMul is confirmed working - MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); + MulByConst(query_scale, q, qkv_dim); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); }); @@ -120,8 +120,7 @@ class VitAttention { for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, - env_.ctx.profiler, worker); + MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); } }); } @@ -144,7 +143,7 @@ class VitAttention { // Compute Q.K scores, which are "logits" stored in head_att. 
float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; - MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); + MulByConst(query_scale, q, qkv_dim); float* HWY_RESTRICT head_att = activations_.attention.att.Row(token) + head * seq_len; for (size_t i = 0; i < seq_len; ++i) { @@ -161,8 +160,7 @@ class VitAttention { for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, - env_.ctx.profiler, worker); + MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); } }); } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 162b48a..c966a68 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -560,10 +560,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, - const size_t size, - hwy::Profiler& p, - const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst)); + const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -596,10 +593,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( // out[i] += x[i] * c. 
template -HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( - const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, - const size_t size, hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd)); +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c, + const XT* HWY_RESTRICT x, + OT* HWY_RESTRICT out, + const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -734,9 +731,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3, const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT& v, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, - const uint32_t* HWY_RESTRICT out_offsets, const size_t size, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile)); + const uint32_t* HWY_RESTRICT out_offsets, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -996,9 +991,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4( DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1, const VF c2, const VF c3, const MatPtrT& v, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, - const uint32_t* HWY_RESTRICT out_offsets, const size_t size, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4)); + const uint32_t* HWY_RESTRICT out_offsets, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -1037,9 +1030,7 @@ template > HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( DF df, const VF scale, const VF c0, const MatPtrT& v, const size_t pos, float* HWY_RESTRICT out, - const uint32_t* HWY_RESTRICT out_offsets, const size_t size, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, 
worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector)); + const uint32_t* HWY_RESTRICT out_offsets, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -1177,7 +1168,7 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, const float sum_exp = Sum(d, logits.data(), logits.size()); // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; - MulByConst(mul, logits.data(), logits.size(), p, worker); + MulByConst(mul, logits.data(), logits.size()); } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 40f1002..dd8e4e8 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -183,7 +183,7 @@ class TestMulByConstAndAdd { SimpleMulByConstAndAdd(constant, o, e, count); InitProfilerZones(hwy::Profiler::Get()); - MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0); + MulByConstAndAdd(constant, o, x, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -232,7 +232,7 @@ class TestMulByConst { SimpleMulByConst(constant, e, count); InitProfilerZones(hwy::Profiler::Get()); - MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0); + MulByConst(constant, x, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -443,7 +443,6 @@ void TestRopeAndMulBy() { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); hwy::Profiler& p = ctx.profiler; - InitProfilerZones(p); const size_t worker = 0; const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, From 503aaddd65619cd9dc5ea24583a7dd00cdf75c26 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Wed, 15 Oct 2025 09:24:38 -0700 Subject: [PATCH 63/65] Add 8-bit integer quantization (I8Stream) to Gemma.cpp. 
PiperOrigin-RevId: 819787856 --- BUILD.bazel | 1 + CMakeLists.txt | 2 + compression/BUILD.bazel | 33 ++ compression/compress-inl.h | 44 +- compression/compress_test.cc | 8 +- compression/int-inl.h | 474 ++++++++++++++++++++ compression/int_test.cc | 494 +++++++++++++++++++++ compression/python/compression_clif_aux.cc | 3 + compression/python/compression_test.py | 12 + compression/types.h | 42 +- gemma/attention.cc | 8 +- gemma/flash_attention.cc | 4 +- gemma/gemma-inl.h | 7 +- gemma/model_store.cc | 2 + gemma/tensor_info.h | 2 +- gemma/vit.cc | 4 +- gemma/weights.cc | 248 ++++++++++- ops/matmul_static.h | 1 + ops/matmul_static_i8.cc | 29 ++ ops/ops-inl.h | 24 +- ops/ops_test.cc | 5 +- python/configs.cc | 6 +- util/basics.h | 19 + util/mat.cc | 8 +- util/mat.h | 12 +- 25 files changed, 1428 insertions(+), 64 deletions(-) create mode 100644 compression/int-inl.h create mode 100644 compression/int_test.cc create mode 100644 ops/matmul_static_i8.cc diff --git a/BUILD.bazel b/BUILD.bazel index f5fad45..ffd5435 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -349,6 +349,7 @@ cc_library( "ops/matmul_static_f32.cc", "ops/matmul_static_nuq.cc", "ops/matmul_static_sfp.cc", + "ops/matmul_static_i8.cc", ], hdrs = [ "ops/matmul_static.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 983d643..3eb2046 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,7 @@ set(SOURCES compression/compress.h compression/nuq-inl.h compression/sfp-inl.h + compression/int-inl.h compression/types.h compression/test_util-inl.h evals/benchmark_helper.cc @@ -109,6 +110,7 @@ set(SOURCES ops/matmul_static_f32.cc ops/matmul_static_nuq.cc ops/matmul_static_sfp.cc + ops/matmul_static_i8.cc ops/matmul-inl.h ops/matmul.cc ops/matmul.h diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index a72db0b..c7232e6 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -80,6 +80,37 @@ cc_library( ], ) +cc_library( + name = "int", + textual_hdrs = ["int-inl.h"], + deps = [ + 
":types", + "//:basics", + "@highway//:hwy", + ], +) + +cc_test( + name = "int_test", + size = "small", + timeout = "long", + srcs = ["int_test.cc"], + features = ["fully_static_link"], + linkstatic = True, + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":distortion", + ":int", + "@googletest//:gtest_main", # buildcleaner: keep + "//:test_util", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", + ], +) + cc_library( name = "test_util", textual_hdrs = [ @@ -144,6 +175,7 @@ cc_library( textual_hdrs = ["compress-inl.h"], deps = [ ":distortion", + ":int", ":nuq", ":sfp", "//:basics", @@ -182,6 +214,7 @@ cc_library( name = "analyze", textual_hdrs = ["analyze.h"], deps = [ + ":int", ":nuq", ":sfp", ":types", diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 18d8e35..35f0433 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -47,6 +47,7 @@ #include "hwy/highway.h" // After highway.h +#include "compression/int-inl.h" #include "compression/nuq-inl.h" #include "compression/sfp-inl.h" @@ -416,6 +417,34 @@ struct CompressTraits { } }; +// Integer quantization. 
+template <> +struct CompressTraits { + using Packed = I8Stream; + + template + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, + size_t num, CompressPerThread& tls, + const PackedSpan& packed, + const size_t packed_ofs) { + IntCodec::Enc(df, raw, num, packed, packed_ofs); + } + + template // Caller checks this is f32 or bf16 + static HWY_INLINE void Load2(D d, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1); + } + + template + static HWY_INLINE void DecompressAndZeroPad( + D d, const PackedSpan& packed, const size_t packed_ofs, + Raw* raw, const size_t num) { + IntCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num); + } +}; + // Nonuniform quantization, 4.5 bits per element, two separate streams. template <> struct CompressTraits { @@ -737,9 +766,10 @@ template HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1, + const size_t p1_ofs, Func&& func) { const auto packed_inout = MakeSpan(inout, num); - const auto packed1 = MakeSpan(p1, num); + const auto packed1 = MakeSpan(p1, p1_ofs + num); using VF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); @@ -749,7 +779,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, VF v0, v1; Decompress2(df, packed_inout, i, v0, v1); VF v10, v11; - Decompress2(df, packed1, i, v10, v11); + Decompress2(df, packed1, p1_ofs + i, v10, v11); const VF out0 = func(df, v0, v10); const VF out1 = func(df, v1, v11); Compress2(df, out0, out1, packed_inout, i); @@ -765,7 +795,7 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, hn::Store(hn::Zero(df), df, buf_inout + NF); hn::Store(hn::Zero(df), df, buf1 + NF); DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining); - DecompressAndZeroPad(df, packed1, i, buf1, remaining); + DecompressAndZeroPad(df, packed1, p1_ofs + i, buf1, remaining); 
const VF v0 = hn::Load(df, buf_inout); const VF v1 = hn::Load(df, buf_inout + NF); const VF v10 = hn::Load(df, buf1); @@ -827,10 +857,10 @@ template HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, const T1* HWY_RESTRICT p1, const T2* HWY_RESTRICT p2, - Func&& func) { + const size_t p2_ofs, Func&& func) { const auto packed_out = MakeSpan(out, num); const auto packed1 = MakeSpan(p1, num); - const auto packed2 = MakeSpan(p2, num); + const auto packed2 = MakeSpan(p2, p2_ofs + num); using VF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); @@ -839,7 +869,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, for (; i <= num - 2 * NF; i += 2 * NF) { VF v10, v11, v20, v21; Decompress2(df, packed1, i, v10, v11); - Decompress2(df, packed2, i, v20, v21); + Decompress2(df, packed2, p2_ofs + i, v20, v21); const VF out0 = func(df, v10, v20); const VF out1 = func(df, v11, v21); Compress2(df, out0, out1, packed_out, i); @@ -856,7 +886,7 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, hn::Store(hn::Zero(df), df, buf1 + NF); hn::Store(hn::Zero(df), df, buf2 + NF); DecompressAndZeroPad(df, packed1, i, buf1, remaining); - DecompressAndZeroPad(df, packed2, i, buf2, remaining); + DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining); const VF v10 = hn::Load(df, buf1); const VF v11 = hn::Load(df, buf1 + NF); const VF v20 = hn::Load(df, buf2); diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 5455b1d..2ee7f63 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -243,7 +243,7 @@ class TestDecompressAndCompress { // Uses `out` so as not to overwrite `p`. 
Decompress1AndCompressInplace( - df, out.get(), num, p1.get(), + df, out.get(), num, p1.get(), /*p1_ofs=*/0, [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); }); HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); @@ -251,9 +251,9 @@ class TestDecompressAndCompress { [](DF, VF v) HWY_ATTR -> VF { return v; }); HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num); - Decompress2AndCompressTo(df, out.get(), num, p.get(), p1.get(), - [](DF, VF v, VF v1) - HWY_ATTR -> VF { return hn::Add(v, v1); }); + Decompress2AndCompressTo( + df, out.get(), num, p.get(), p1.get(), /*p2_ofs=*/0, + [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); }); HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); Decompress3AndCompressTo( diff --git a/compression/int-inl.h b/compression/int-inl.h new file mode 100644 index 0000000..969ec6d --- /dev/null +++ b/compression/int-inl.h @@ -0,0 +1,474 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard. +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ + +#include +#include +#include + +#include +#include + +#include "compression/types.h" +#include "util/basics.h" +#include "hwy/base.h" +#include "hwy/print-inl.h" + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ + +// Actual per-target include guard. 
+#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_TOGGLE +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// Encode/decode functions. +class IntCodec { + using ScaleT = hwy::bfloat16_t; + static constexpr size_t kGroupSize = I8Stream::kGroupSize; + + // Offset (in bytes) of a group's start for packed_ofs (in elements) within a + // set of groups. + static constexpr size_t GroupByteOffset(size_t packed_ofs) { + const size_t kBytesPerGroup = (2 * sizeof(ScaleT)) + kGroupSize; + return (packed_ofs / kGroupSize) * kBytesPerGroup; + } + + public: + template + static HWY_INLINE void DequantizeGroup( + DBF dbf, const PackedSpan& packed, size_t packed_ofs, + hwy::bfloat16_t* HWY_RESTRICT raw, size_t num) { + using T = ScaleT; + const hn::ScalableTag df; + const hn::Rebind di32; + const hn::Rebind di16; + const hn::Rebind di8; + const hn::Twice> dbf16; + + const size_t N = hn::Lanes(di8); + const size_t N16 = hn::Lanes(dbf16); + using VI8 = hn::Vec; + using VF = hn::Vec; + + T inv_scale, zeropoint; + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, + sizeof(T)); + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), + &zeropoint, sizeof(T)); + + float inv_scale_f = hwy::ConvertScalarTo(inv_scale); + float zeropoint_f = hwy::ConvertScalarTo(zeropoint); + + VF inv_scale_vec = hn::Set(df, inv_scale_f); + VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f)); + + // Then iterate over remainder of packed, extracting num / N vectors and + // inserting into raw. 
+ const size_t g_num = HWY_MIN(num, kGroupSize); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + size_t i = 0; + for (i = 0; i + 4 * N <= g_num; i += 4 * N) { + const VI8 val0 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N); + const VI8 val1 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N); + const VI8 val2 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 2 * N); + const VI8 val3 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 3 * N); + + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + const VF val1_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1))); + const VF val2_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val2))); + const VF val3_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val3))); + + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec); + VF dequantized_val2 = hn::MulAdd(inv_scale_vec, val2_f, zeroscale_vec); + VF dequantized_val3 = hn::MulAdd(inv_scale_vec, val3_f, zeroscale_vec); + + hn::StoreU( + hn::OrderedDemote2To(dbf16, dequantized_val0, dequantized_val1), + dbf16, raw + i + 0 * N16); + hn::StoreU( + hn::OrderedDemote2To(dbf16, dequantized_val2, dequantized_val3), + dbf16, raw + i + 1 * N16); + } + for (; i + N <= g_num; i += N) { + const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i); + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + const hn::Rebind dbf_half; + hn::StoreU(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i); + } + if (i < g_num) { + const size_t remaining = g_num - i; + const VI8 val0 = + hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining); + const VF val0_f = + hn::ConvertTo(df, 
hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + const hn::Rebind dbf_half; + hn::StoreN(hn::DemoteTo(dbf_half, dequantized_val0), dbf_half, raw + i, + remaining); + } + } + + // Dequantizes `num` floats from `packed` into `raw`. `packed` points to + // compressed storage and `packed_ofs` indicates the destination offset + // within it, in number of elements. Scaling values are interleaved with int + // values to allow for easier unpacking. + template + static HWY_INLINE void DequantizeGroup( + DF df, const PackedSpan& packed, size_t packed_ofs, + float* HWY_RESTRICT raw, size_t num) { + using T = ScaleT; + const hn::Rebind di32; + const hn::Rebind di16; + const hn::Rebind di8; + const hn::Rebind df8; + + const size_t N = hn::Lanes(di8); + const size_t N32 = hn::Lanes(df); + using VI8 = hn::Vec; + using VF = hn::Vec; + + // HWY_ASSERT(num % 2 * N == 0); + + // Load scale and zero point from the beginning - ensure correct pointer + // offset. + T inv_scale, zeropoint; + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, + sizeof(T)); + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), + &zeropoint, sizeof(T)); + + float inv_scale_f = hwy::ConvertScalarTo(inv_scale); + float zeropoint_f = hwy::ConvertScalarTo(zeropoint); + + VF inv_scale_vec = hn::Set(df, inv_scale_f); + VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f)); + + // Then iterate over remainder of packed, extracting num / N vectors and + // inserting into raw. 
+ const size_t g_num = HWY_MIN(num, kGroupSize); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + + size_t i = 0; + for (; i + 2 * N <= g_num; i += 2 * N) { + const VI8 val0 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 0 * N); + const VI8 val1 = + hn::LoadU(di8, &packed.ptr->i + current_offset + i + 1 * N); + + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + const VF val1_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1))); + + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + VF dequantized_val1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec); + + hn::StoreU(dequantized_val0, df, raw + i + 0 * N32); + hn::StoreU(dequantized_val1, df, raw + i + 1 * N32); + } + for (; i + N <= g_num; i += N) { + const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + i); + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + hn::StoreU(dequantized_val0, df, raw + i); + } + if (i < g_num) { + const size_t remaining = g_num - i; + const VI8 val0 = + hn::LoadN(di8, &packed.ptr->i + current_offset + i, remaining); + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + VF dequantized_val0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + hn::StoreN(dequantized_val0, df, raw + i, remaining); + } + } + + // Quantizes `num` floats from `raw` into `packed`. `packed` points to + // compressed storage and `packed_ofs` indicates the destination offset + // within it, in number of elements. Scaling values are interleaved with + // int values to allow for easier unpacking. 
+ template + static HWY_INLINE void QuantizeGroup(DF df, const float* HWY_RESTRICT raw, + size_t num, + const PackedSpan& packed, + size_t packed_ofs) { + using T = ScaleT; + const hn::Repartition di32; + const hn::Half> di16; + const hn::Half> di8; + + const size_t N = hn::Lanes(df); + using VI8 = hn::Vec; + using VF = hn::Vec; + + HWY_DASSERT(packed_ofs % kGroupSize == 0); + HWY_DASSERT(num % 2 * N == 0); + + // Calculate min/max using SIMD + float min_val = hwy::HighestValue(); + float max_val = hwy::LowestValue(); + VF vmin = hn::Set(df, hwy::HighestValue()); + VF vmax = hn::Set(df, hwy::LowestValue()); + + size_t j = 0; + for (; j + N <= num; j += N) { + const VF xi = hn::LoadU(df, raw + j); + vmin = hn::Min(vmin, xi); + vmax = hn::Max(vmax, xi); + } + + min_val = hn::ReduceMin(df, vmin); + max_val = hn::ReduceMax(df, vmax); + + for (; j < num; ++j) { + min_val = HWY_MIN(min_val, raw[j]); + max_val = HWY_MAX(max_val, raw[j]); + } + + // Calculate range, scale and zeropoint + float x_range = max_val - min_val; + x_range = x_range == 0.0f ? 1.0f : x_range; + const float scale_f = 255.0f / x_range; + const float zeropoint_f = static_cast( + static_cast(-scale_f * min_val - 128.0f)); // Correct casting + + const T scale = hwy::ConvertScalarTo(scale_f); + // inv_scale is used for all dequantization. 
+ const T inv_scale = hwy::ConvertScalarTo(1.0f / scale_f); + const T zeropoint = hwy::ConvertScalarTo(zeropoint_f); + memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, sizeof(T)); + memcpy(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), &zeropoint, + sizeof(T)); + + const size_t g_num = HWY_MIN(num, kGroupSize); + + VF mul = hn::Set(df, hwy::ConvertScalarTo(scale)); + VF add = hn::Set(df, hwy::ConvertScalarTo(zeropoint)); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + + size_t i = 0; + for (; i + 2 * N <= g_num; i += 2 * N) { + const VI8 val0 = hn::DemoteTo( + di8, + hn::DemoteTo(di16, NearestInt(hn::MulAdd( + mul, hn::LoadU(df, raw + i + 0 * N), add)))); + const VI8 val1 = hn::DemoteTo( + di8, + hn::DemoteTo(di16, NearestInt(hn::MulAdd( + mul, hn::LoadU(df, raw + i + 1 * N), add)))); + + hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i + 0 * N); + hn::StoreU(val1, di8, &packed.ptr->i + current_offset + i + 1 * N); + } + + size_t remaining = g_num - i; + + HWY_DASSERT(remaining < 2 * N); + if (HWY_UNLIKELY(remaining == 0)) return; + + if (remaining > N) { + const VI8 val0 = hn::DemoteTo( + di8, hn::DemoteTo(di16, NearestInt(hn::MulAdd( + mul, hn::LoadU(df, raw + i), add)))); + hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i); + + const size_t remaining1 = remaining - N; + const VI8 val1 = hn::DemoteTo( + di8, + hn::DemoteTo(di16, + NearestInt(hn::MulAdd( + mul, hn::LoadN(df, raw + i + N, remaining1), add)))); + hn::StoreN(val1, di8, &packed.ptr->i + current_offset + i + N, + remaining1); + } else { // remaining <= N + const VI8 val0 = hn::DemoteTo( + di8, hn::DemoteTo(di16, + NearestInt(hn::MulAdd( + mul, hn::LoadN(df, raw + i, remaining), add)))); + hn::StoreN(val0, di8, &packed.ptr->i + current_offset + i, remaining); + } + } + + // Encodes `num` floats from `raw` into `packed`. 
`packed` points to + // compressed storage and `packed_ofs` indicates the destination offset + // within it, in number of elements. Scaling values are interleaved with + // int + // values to allow for easier unpacking. + template + static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT raw, + const size_t num, + const PackedSpan& packed, + size_t packed_ofs) { + HWY_ASSERT(packed_ofs % kGroupSize == 0); + + const size_t num_groups = hwy::DivCeil(num, kGroupSize); + + size_t current_offset = packed_ofs; + for (size_t g = 0; g < num_groups; ++g) { + const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize); + const float* HWY_RESTRICT g_in = raw + g * kGroupSize; + + QuantizeGroup(df, g_in, g_num, packed, current_offset); + current_offset += g_num; + } + } + + // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two + // vectors so that we only have to load one group's table. + template + static HWY_INLINE void Dec2(DBF dbf, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition df; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + HWY_ASSERT(packed_ofs % 2 * NF == 0); + + VF raw0_f, raw1_f, raw2_f, raw3_f; + Dec2(df, packed, packed_ofs + 0 * 2 * NF, raw0_f, raw1_f); + Dec2(df, packed, packed_ofs + 1 * 2 * NF, raw2_f, raw3_f); + + raw0 = hn::OrderedDemote2To(dbf, raw0_f, raw1_f); + raw1 = hn::OrderedDemote2To(dbf, raw2_f, raw3_f); + } + + // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two + // vectors so that we only have to load one group's table. 
+ template + static HWY_INLINE void Dec2(DF df, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + using T = ScaleT; + const hn::Rebind di32; + const hn::Rebind di16; + const hn::Rebind di8; + const hn::Rebind df8; + + const size_t N = hn::Lanes(di8); + using VI8 = hn::Vec; + using VF = hn::Vec; + + T inv_scale, zeropoint; + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs), &inv_scale, + sizeof(T)); + hwy::CopyBytes(&packed.ptr->i + GroupByteOffset(packed_ofs) + sizeof(T), + &zeropoint, sizeof(T)); + + float inv_scale_f = hwy::ConvertScalarTo(inv_scale); + float zeropoint_f = hwy::ConvertScalarTo(zeropoint); + + VF inv_scale_vec = hn::Set(df, inv_scale_f); + VF zeroscale_vec = hn::Set(df, -zeropoint_f * (inv_scale_f)); + + const size_t current_offset = GroupByteOffset(packed_ofs) + + (2 * sizeof(T)) + (packed_ofs % kGroupSize); + + const VI8 val0 = hn::LoadU(di8, &packed.ptr->i + current_offset + 0 * N); + const VI8 val1 = hn::LoadU(di8, &packed.ptr->i + current_offset + 1 * N); + + const VF val0_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val0))); + const VF val1_f = + hn::ConvertTo(df, hn::PromoteTo(di32, hn::PromoteTo(di16, val1))); + + raw0 = hn::MulAdd(inv_scale_vec, val0_f, zeroscale_vec); + raw1 = hn::MulAdd(inv_scale_vec, val1_f, zeroscale_vec); + } + + template > + static HWY_INLINE void DecompressAndZeroPad( + D d, const PackedSpan& packed, size_t packed_ofs, + Raw* HWY_RESTRICT raw, size_t num) { + if (num == 0) return; + + const size_t N = hn::Lanes(d); + const size_t padded_num = hwy::RoundUpTo(num, N); + if (padded_num > num) { + hwy::ZeroBytes(raw + num, (padded_num - num) * sizeof(Raw)); + } + + size_t current_packed_ofs = packed_ofs; + Raw* HWY_RESTRICT current_raw = raw; + size_t num_to_decompress = num; + + if (size_t within_group = current_packed_ofs % kGroupSize; + within_group != 0) { + const size_t remaining_in_group = kGroupSize - within_group; + const size_t 
num_in_first_group = + HWY_MIN(num_to_decompress, remaining_in_group); + DequantizeGroup(d, packed, current_packed_ofs, current_raw, + num_in_first_group); + current_packed_ofs += num_in_first_group; + current_raw += num_in_first_group; + num_to_decompress -= num_in_first_group; + } + + if (num_to_decompress == 0) return; + + HWY_DASSERT(current_packed_ofs % kGroupSize == 0); + + const size_t num_full_groups = num_to_decompress / kGroupSize; + for (size_t g = 0; g < num_full_groups; ++g) { + DequantizeGroup(d, packed, current_packed_ofs, current_raw, kGroupSize); + current_packed_ofs += kGroupSize; + current_raw += kGroupSize; + } + + const size_t remaining = num_to_decompress % kGroupSize; + if (remaining != 0) { + DequantizeGroup(d, packed, current_packed_ofs, current_raw, remaining); + } + } +}; // IntCodec + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_INT_INL_H_ diff --git a/compression/int_test.cc b/compression/int_test.cc new file mode 100644 index 0000000..f427384 --- /dev/null +++ b/compression/int_test.cc @@ -0,0 +1,494 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests. 
+#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) +#endif + +#include +#include +#include + +#include "util/test_util.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" +#include "hwy/tests/test_util.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "compression/int_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/int-inl.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +static constexpr size_t kGroupSize = I8Stream::kGroupSize; +static constexpr float kTolerance = 50000.0f; + +// Can encode and decode sub-regions. +// Quantizes and de-quantizes a single (potentially partial) group to check +// that the quantizer is working correctly. +struct TestQuantize { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const size_t total = kGroupSize / 2; // already padded + const hn::ScalableTag df; + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto dec2 = hwy::AllocateAligned(total); + auto dec3 = hwy::AllocateAligned(total); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec1 && dec2 && dec3 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + IntCodec::QuantizeGroup(df, in.get(), total, int_span, 0); + + IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec1.get(), total); + + const float epsilon = + hwy::ConvertScalarTo(hwy::Epsilon()); + const float tolerance = kTolerance * epsilon; + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = 
hwy::ConvertScalarTo(dec1[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec1[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + } + } + + // Check that ::Enc works correctly as well. + IntCodec::Enc(df, in.get(), total, int_span, 0); + + IntCodec::DequantizeGroup(d, MakeConst(int_span), 0, dec2.get(), total); + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = hwy::ConvertScalarTo(dec2[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + } + } + + // Check that ::DecompressAndZeroPad works correctly for one group as well. + IntCodec::Enc(df, in.get(), total, int_span, 0); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec3.get(), + total); + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = hwy::ConvertScalarTo(dec3[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec3[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + HWY_ASSERT(false); + } + } + } +}; + +void TestQuantizeBF16() { hn::ForGEVectors<128, TestQuantize>()(BF16()); } +void TestQuantizeF32() { hn::ForGEVectors<128, TestQuantize>()(float()); } + +// Can encode and decode sub-regions. +// Quantizes and de-quantizes multiple (potentially partial) groups to check +// that DecompressAndZeroPad is working correctly. 
+struct TestMultiGroup { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t total = kGroupSize * 2 + kGroupSize / 4; // already padded + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto dec2 = hwy::AllocateAligned(total); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec1 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + const float epsilon = + hwy::ConvertScalarTo(hwy::Epsilon()); + const float tolerance = kTolerance * epsilon; + + // Check that ::DecompressAndZeroPad works correctly for one group as well. + IntCodec::Enc(df, in.get(), total, int_span, 0); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec2.get(), + total); + + for (size_t i = 0; i < total; ++i) { + const float expected_value = static_cast(in[i]); + const float actual_value = hwy::ConvertScalarTo(dec2[i]); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, + "in[%zu] = %f, dec2[%zu] = %f, tolerance = %f, epsilon = %f\n", + i, expected_value, i, actual_value, tolerance, epsilon); + HWY_ASSERT(false); + } + } + } +}; + +void TestMultiGroupBF16() { hn::ForGEVectors<128, TestMultiGroup>()(BF16()); } +void TestMultiGroupF32() { hn::ForGEVectors<128, TestMultiGroup>()(float()); } + +// Can encode and decode sub-regions. 
+struct TestOffset { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t total = 10 * kGroupSize; // already padded + const size_t kMidLen = 2 * kGroupSize; // length of middle piece + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto dec2 = hwy::AllocateAligned(kMidLen); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec1 && dec2 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + // Encode + decode everything + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(), + total); + + MaybeCheckInitialized(dec1.get(), total * sizeof(T)); + + // Overwrite middle with first inputs + const size_t offset = 5 * kGroupSize; + (void)IntCodec::Enc(df, in.get(), kMidLen, int_span, offset); + + // Decoded middle now matches previously decoded first + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, dec2.get(), + kMidLen); + MaybeCheckInitialized(dec2.get(), kMidLen * sizeof(T)); + + for (size_t i = 0; i < kMidLen; ++i) { + HWY_ASSERT(dec1[i] == dec2[i]); + } + } +}; + +void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); } +void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); } + +// Can encode and decode sub-regions. 
+struct TestUnalignedOffset { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + const size_t total = 10 * kGroupSize; // already padded + + const int num_unaligned_offsets = 4; + const std::array unaligned_offsets = { + 4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100}; + const std::array num = {4, 16, 32, 64}; + + for (int i = 0; i < num_unaligned_offsets; ++i) { + const size_t unaligned_offset = unaligned_offsets[i]; + const size_t num_decompressed = num[i]; + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec1 = hwy::AllocateAligned(total); + auto i8_stream = + hwy::AllocateAligned(I8Stream::PackedEnd(total)); + auto dec2 = hwy::AllocateAligned(num_decompressed); + HWY_ASSERT(in && dec1 && dec2 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + // // Encode + decode everything + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec1.get(), + total); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), unaligned_offset, + dec2.get(), num_decompressed); + + for (size_t i = 0; i < num_decompressed; ++i) { + T expected = hwy::ConvertScalarTo(dec1[unaligned_offset + i]); + T actual = hwy::ConvertScalarTo(dec2[i]); + + HWY_ASSERT_EQ(expected, actual); + } + } + } +}; + +void TestUnalignedOffsetBF16() { + hn::ForGEVectors<128, TestUnalignedOffset>()(BF16()); +} +void TestUnalignedOffsetF32() { + hn::ForGEVectors<128, TestUnalignedOffset>()(float()); +} + +// Can encode and decode sub-regions. +// Uses Dec2 to decode all elements in the packed buffer, then +// compares against DecompressAndZeroPad. +struct TestDec2 { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition df; + // incl. 
partial group to test partial group handling + const size_t total = kGroupSize * 10 + kGroupSize / 2; + + auto in = hwy::AllocateAligned(total); // Enc() requires f32 + auto dec0 = hwy::AllocateAligned(total); + auto dec1 = hwy::AllocateAligned(total); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(total)); + HWY_ASSERT(in && dec0 && dec1 && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), total); + + hwy::RandomState rng; + for (size_t i = 0; i < total; ++i) { + in[i] = static_cast(RandomGaussian(rng)); + } + + // Non-interleaved encode + decode for comparison + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, dec0.get(), + total); + + // Encode + decode everything + (void)IntCodec::Enc(df, in.get(), total, int_span, 0); + + using V = hn::Vec; + const size_t N = Lanes(d); + + for (size_t i = 0; i < total; i += 2 * N) { + V f0, f1; + IntCodec::Dec2(d, MakeConst(int_span), i, f0, f1); + + hn::StoreU(f0, d, dec1.get() + i + 0 * N); + hn::StoreU(f1, d, dec1.get() + i + 1 * N); + } + + for (size_t i = 0; i < total; ++i) { + if (dec0[i] != dec1[i]) { + fprintf(stderr, "dec0[%zu] = %g, dec1[%zu] = %g\n", i, + hwy::ConvertScalarTo(dec0[i]), i, + hwy::ConvertScalarTo(dec1[i])); + } + + HWY_ASSERT(dec0[i] == dec1[i]); + } + } +}; + +void TestDec2BF16() { hn::ForGEVectors<128, TestDec2>()(BF16()); } +void TestDec2F32() { hn::ForGEVectors<128, TestDec2>()(float()); } + +// Tests that DecompressAndZeroPad fully populates the output array. +// This is intended to catch uninitialized value errors. 
+struct TestDequantizeAndZeroPad { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::ScalableTag df; + constexpr size_t kSize = 4096; + auto in = hwy::AllocateAligned(kSize); + auto actual_dec = hwy::AllocateAligned(kSize); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(kSize)); + HWY_ASSERT(in && actual_dec && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), kSize); + + // Fill with a known pattern. + for (size_t i = 0; i < kSize; ++i) { + in[i] = static_cast(i) - 128.0f; + } + + IntCodec::Enc(df, in.get(), kSize, int_span, 0); + + // Initialize with a sentinel value to detect if it's overwritten. + const T sentinel = hwy::ConvertScalarTo(-999.0f); + std::fill(actual_dec.get(), actual_dec.get() + kSize, sentinel); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), 0, actual_dec.get(), + kSize); + + MaybeCheckInitialized(actual_dec.get(), kSize * sizeof(T)); + + // Check that all sentinels were overwritten. + for (size_t i = 0; i < kSize; ++i) { + EXPECT_NE(hwy::ConvertScalarTo(actual_dec[i]), + hwy::ConvertScalarTo(sentinel)) + << " at index " << i; + } + } +}; + +void TestAllDequantizeAndZeroPad() { + hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(BF16()); + hn::ForGEVectors<128, TestDequantizeAndZeroPad>()(float()); +} + +// Tests that DecompressAndZeroPad works correctly for small and unaligned +// inputs. This is intended to catch uninitialized value errors in remainder +// handling. 
+struct TestSmallDequantize { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::ScalableTag df; + constexpr size_t kGroupSize = I8Stream::kGroupSize; + constexpr size_t kMaxNum = kGroupSize * 3; + auto in = hwy::AllocateAligned(kMaxNum); + auto actual_dec = hwy::AllocateAligned(kMaxNum); + auto i8_stream = + hwy::AllocateAligned(I8Stream::PackedEnd(kMaxNum)); + HWY_ASSERT(in && actual_dec && i8_stream); + const auto int_span = + MakeSpan(i8_stream.get(), I8Stream::PackedEnd(kMaxNum)); + + // Fill with a known pattern. + for (size_t i = 0; i < kMaxNum; ++i) { + in[i] = static_cast(i) - 128.0f; + } + + IntCodec::Enc(df, in.get(), kMaxNum, int_span, 0); + + for (size_t num = 1; num < kGroupSize * 2; ++num) { + for (size_t offset = 0; offset < kGroupSize; offset += 16) { + const T sentinel = hwy::ConvertScalarTo(-999.0f); + std::fill(actual_dec.get(), actual_dec.get() + num, sentinel); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, + actual_dec.get(), num); + + MaybeCheckInitialized(actual_dec.get(), num); + + // Check that all sentinels were overwritten. + for (size_t i = 0; i < num; ++i) { + EXPECT_NE(hwy::ConvertScalarTo(actual_dec[i]), + hwy::ConvertScalarTo(sentinel)) + << " at index " << i << " for num=" << num + << " offset=" << offset; + } + } + } + } +}; + +void TestAllSmallDequantize() { + hn::ForGEVectors<128, TestSmallDequantize>()(BF16()); + hn::ForGEVectors<128, TestSmallDequantize>()(float()); +} + +// Tests that DecompressAndZeroPad works correctly for a specific failing input. 
+struct TestSpecificDequantize { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::ScalableTag df; + constexpr size_t kSize = 737280; + auto in = hwy::AllocateAligned(kSize); + auto actual_dec = hwy::AllocateAligned(kSize); + auto i8_stream = hwy::AllocateAligned(I8Stream::PackedEnd(kSize)); + HWY_ASSERT(in && actual_dec && i8_stream); + const auto int_span = MakeSpan(i8_stream.get(), kSize); + + // Fill with a known pattern. + for (size_t i = 0; i < kSize; ++i) { + in[i] = static_cast(i) - 128.0f; + } + + IntCodec::Enc(df, in.get(), kSize, int_span, 0); + + const size_t num = 64; + const size_t offset = 392704; + const T sentinel = hwy::ConvertScalarTo(-999.0f); + std::fill(actual_dec.get(), actual_dec.get() + num, sentinel); + + IntCodec::DecompressAndZeroPad(d, MakeConst(int_span), offset, + actual_dec.get(), num); + + // Check that all sentinels were overwritten. + for (size_t i = 0; i < num; ++i) { + EXPECT_NE(hwy::ConvertScalarTo(actual_dec[i]), + hwy::ConvertScalarTo(sentinel)) + << " at index " << i << " for num=" << num << " offset=" << offset; + } + } +}; + +void TestAllSpecificDequantize() { + hn::ForGEVectors<128, TestSpecificDequantize>()(BF16()); + hn::ForGEVectors<128, TestSpecificDequantize>()(float()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { +HWY_BEFORE_TEST(IntTest); +HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetF32); +HWY_EXPORT_AND_TEST_P(IntTest, TestOffsetBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeF32); +HWY_EXPORT_AND_TEST_P(IntTest, TestQuantizeBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestDec2BF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestDec2F32); +HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupF32); +HWY_EXPORT_AND_TEST_P(IntTest, TestMultiGroupBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetBF16); +HWY_EXPORT_AND_TEST_P(IntTest, TestUnalignedOffsetF32); 
+HWY_EXPORT_AND_TEST_P(IntTest, TestAllDequantizeAndZeroPad); +HWY_EXPORT_AND_TEST_P(IntTest, TestAllSmallDequantize); +HWY_EXPORT_AND_TEST_P(IntTest, TestAllSpecificDequantize); +HWY_AFTER_TEST(); +} // namespace gcpp +#endif // HWY_ONCE diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 5e729cc..5f227ac 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -113,6 +113,9 @@ class SbsWriterImpl : public ISbsWriter { case Type::kF32: InsertT(name, weights, tensor_info); break; + case Type::kI8: + InsertT(name, weights, tensor_info); + break; default: HWY_ABORT("Unsupported destination (compressed) type %s", TypeName(type)); diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index 957f0ec..16e6bf9 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -90,6 +90,13 @@ class CompressionTest(absltest.TestCase): info_256, ) + writer.insert( + "tensor_i8", + np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32), + configs.Type.kI8, + info_256, + ) + config = configs.ModelConfig( configs.Model.GEMMA2_2B, configs.Type.kSFP, @@ -140,6 +147,11 @@ class CompressionTest(absltest.TestCase): self.assertEqual(mat.type, configs.Type.kF32) self.assertAlmostEqual(mat.scale, 1.0) + mat = reader.find_mat("tensor_i8") + self.assertEqual(mat.cols, 256) + self.assertEqual(mat.rows, 1) + self.assertEqual(mat.type, configs.Type.kI8) + self.assertAlmostEqual(mat.scale, 1.0) if __name__ == "__main__": absltest.main() diff --git a/compression/types.h b/compression/types.h index c3be52a..8f11591 100644 --- a/compression/types.h +++ b/compression/types.h @@ -89,6 +89,26 @@ struct SfpStream { }; #pragma pack(pop) +#pragma pack(push, 1) +struct I8Stream { + static constexpr size_t kGroupSize = 128; + using ScaleT = hwy::bfloat16_t; + + // Returns number of I8Stream to allocate for the stream, 
which matches its + // size in bytes. + // TODO: should support other types beyond hwy::float32_t for scale and + // zero-point. + static constexpr size_t PackedEnd(size_t capacity) { + const size_t num_groups = hwy::DivCeil(capacity, kGroupSize); + return (sizeof(ScaleT) * num_groups) + // scale + (sizeof(ScaleT) * num_groups) + // zero-point + capacity; // 1 value per byte + } + + int8_t i; +}; +#pragma pack(pop) + // Non-uniform quantization: a compressed representation of f32 inputs that // supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or // two vectors (for `Decompress2`), and decoding to bf16/f32. @@ -187,18 +207,23 @@ constexpr bool IsNuqStream() { return hwy::IsSame, NuqStream>(); } +template +constexpr bool IsI8Stream() { + return hwy::IsSame, I8Stream>(); +} + template constexpr bool SupportsPointerArithmetic() { - return !IsNuqStream(); + return !IsNuqStream() && !IsI8Stream(); } // Tensor types for loading weights. Not all of these are supported weight // types, some are only used for `Activations`. -enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 }; +enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64, kI8 }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. 
-static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", - "nuq", "f64", "u32", "u64"}; +static constexpr const char* kTypeStrings[] = { + "unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = { @@ -210,6 +235,7 @@ static constexpr size_t kTypeBits[] = { 8 * sizeof(double), 8 * sizeof(uint32_t), 8 * sizeof(uint64_t), + 8 * sizeof(I8Stream), }; static inline bool EnumValid(Type type) { @@ -234,6 +260,8 @@ Type TypeEnum() { return Type::kU32; } else if constexpr (hwy::IsSame()) { return Type::kU64; + } else if constexpr (hwy::IsSame()) { + return Type::kI8; } else { HWY_DASSERT(false); return Type::kUnknown; @@ -254,7 +282,9 @@ const char* TypeName() { template constexpr bool IsCompressed() { - return hwy::IsSameEither, SfpStream, NuqStream>(); + return hwy::IsSame, SfpStream>() || + hwy::IsSame, NuqStream>() || + hwy::IsSame, I8Stream>(); } // Returns the number of `MatT` elements required to store `capacity` values, @@ -265,6 +295,8 @@ template constexpr size_t CompressedArrayElements(size_t capacity) { if constexpr (hwy::IsSame, NuqStream>()) { return NuqStream::PackedEnd(capacity); + } else if constexpr (hwy::IsSame, I8Stream>()) { + return I8Stream::PackedEnd(capacity); } else { return capacity; } diff --git a/gemma/attention.cc b/gemma/attention.cc index bf39702..1269e53 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -143,8 +143,8 @@ void SingleDotSoftmaxWeightedSum( // Apply rope and scaling to Q. 
if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), q, layer.layer_config.qkv_dim, - p, worker); + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, + layer.layer_config.qkv_dim, p, worker); }); } @@ -315,8 +315,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Apply further processing to K. if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim, - env.ctx.profiler, worker); + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32, + qkv_dim, env.ctx.profiler, worker); }); } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index df6efd1..c6a2fba 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -114,8 +114,8 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, // Apply rope and scaling to Q. if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), q_row, - layer.layer_config.qkv_dim, ctx.profiler, worker); + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row, + layer.layer_config.qkv_dim, ctx.profiler, worker); }); } PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index ecfbe47..0034f3f 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -59,7 +59,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1, return; }; // Has multiplier, Gelu(c1) * c2. 
- Decompress1AndCompressInplace(DF(), c1, count, c2, + Decompress1AndCompressInplace(DF(), c1, count, c2, /*p1_ofs=*/0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); @@ -101,8 +101,9 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1, for (size_t ir = 0; ir < range_r.Num(); ++ir) { Decompress1AndCompressInplace( DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir), - [](DF df, VF v1, VF v2) - HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); + /*p1_ofs*/ 0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF { + return hn::Mul(v2, Gelu(df, v1)); + }); } } diff --git a/gemma/model_store.cc b/gemma/model_store.cc index a20caf2..2f3e1ec 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -112,6 +112,8 @@ class TypePrefix { return Type::kSFP; case '2': return Type::kNUQ; + case 'I': + return Type::kI8; default: // The other types were not written to pre-2025 files, hence no need to // encode and check for them here. diff --git a/gemma/tensor_info.h b/gemma/tensor_info.h index d2b25d9..6becb29 100644 --- a/gemma/tensor_info.h +++ b/gemma/tensor_info.h @@ -46,7 +46,7 @@ struct TensorInfo { // The highest permissible compression for this tensor. The default is // kNUQ, which provides maximum compression. Other values such as kBF16 // or kF32 can be used to limit the compression to a specific type. - Type min_size = Type::kNUQ; + Type min_size = Type::kI8; // Whether to apply scaled softplus to the data. bool scaled_softplus = false; // Whether the columns or the rows take any extra dimensions. diff --git a/gemma/vit.cc b/gemma/vit.cc index d21be16..abe0a37 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -332,8 +332,8 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, // Apply soft embedding norm before input projection. 
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), activations.x.Row(0), - vit_model_dim, env.ctx.profiler, + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, + activations.x.Row(0), vit_model_dim, env.ctx.profiler, hwy::Profiler::GlobalIdx()); }); } diff --git a/gemma/weights.cc b/gemma/weights.cc index cd8875b..d871c6f 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -147,15 +147,222 @@ void LayerWeightsPtrs::SplitAttW1() { qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); } +static void HWY_MAYBE_UNUSED InitAttWeightsI8( + const LayerConfig& layer_config, MatPtrT& attn_vec_einsum_w, + MatPtrT& att_weights, std::vector& mat_owners, + const Allocator& allocator) { + if (!attn_vec_einsum_w.HasPtr()) return; + HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8); + + att_weights.SetType(Type::kI8); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked); + } + + const size_t model_dim = layer_config.model_dim; + const size_t heads = layer_config.heads; + const size_t qkv_dim = layer_config.qkv_dim; + + // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. 
+ hwy::AlignedFreeUniquePtr attn_vec_einsum_w_tmp = + hwy::AllocateAligned(model_dim * heads * qkv_dim); + hwy::AlignedFreeUniquePtr att_weights_tmp = + hwy::AllocateAligned(model_dim * heads * qkv_dim); + + const hwy::HWY_NAMESPACE::ScalableTag df; + HWY_NAMESPACE::DecompressAndZeroPad(df, attn_vec_einsum_w.Span(), 0, + attn_vec_einsum_w_tmp.get(), + model_dim * heads * qkv_dim); + + for (size_t m = 0; m < model_dim; ++m) { + float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim; + for (size_t h = 0; h < heads; ++h) { + hwy::CopyBytes( + attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim, + out_row + h * qkv_dim, qkv_dim * sizeof(float)); + } + } + + CompressWorkingSet work; + hwy::ThreadPool pool(0); + HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim, + work, att_weights.Span(), + /*packed_ofs=*/0, pool); + + att_weights.SetScale(attn_vec_einsum_w.Scale()); +} + +static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config, + MatPtrT& gating_einsum_w, + MatPtrT& gating_einsum_w1, + MatPtrT& gating_einsum_w2, + std::vector& mat_owners, + const Allocator& allocator) { + // Files have both or neither of w1 and w2. + HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); + // w is mutually exclusive with w1 and w2 in the file. + HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr()); + // Done if we already read split tensors. + if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return; + // Nothing to do if w is not present. 
+ if (!gating_einsum_w.HasPtr()) return; + + HWY_ASSERT(gating_einsum_w.GetType() == Type::kI8); + + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + const size_t model_dim = gating_einsum_w.Cols(); + HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim); + HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim); + HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim); + HWY_ASSERT(gating_einsum_w1.Cols() == model_dim); + HWY_ASSERT(gating_einsum_w2.Cols() == model_dim); + + gating_einsum_w1.SetType(Type::kI8); + gating_einsum_w2.SetType(Type::kI8); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(gating_einsum_w1, allocator, + MatPadding::kPacked); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(gating_einsum_w2, allocator, + MatPadding::kPacked); + } + + const size_t total_size = gating_einsum_w.Rows() * gating_einsum_w.Cols(); + hwy::AlignedFreeUniquePtr w_tmp = + hwy::AllocateAligned(total_size); + + const hwy::HWY_NAMESPACE::ScalableTag df; + HWY_NAMESPACE::DecompressAndZeroPad(df, gating_einsum_w.Span(), 0, + w_tmp.get(), total_size); + + const size_t split_size = ff_hidden_dim * model_dim; + float* w1_tmp = w_tmp.get(); + float* w2_tmp = w_tmp.get() + split_size; + + CompressWorkingSet work; + hwy::ThreadPool pool(0); + HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0, + pool); + HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0, + pool); + + gating_einsum_w1.SetScale(1.0f); + gating_einsum_w2.SetScale(1.0f); + + gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols()); +} + +static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config, + MatPtrT& qkv_einsum_w, + MatPtrT& qkv_einsum_w1, + MatPtrT& qkv_einsum_w2, + std::vector& mat_owners, + const Allocator& allocator) { + // w is mutually exclusive with w1 in the file. 
+ HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); + // Done if we already read split tensors. + if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return; + // Nothing to do if w is not present. + if (!qkv_einsum_w.HasPtr()) return; + + HWY_ASSERT(qkv_einsum_w.GetType() == Type::kI8); + + const size_t model_dim = qkv_einsum_w.Cols(); + const size_t w1_rows = layer_config.heads * layer_config.qkv_dim; + const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim; + HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows); + HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows); + HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows); + HWY_ASSERT(qkv_einsum_w1.Cols() == model_dim); + HWY_ASSERT(qkv_einsum_w2.Cols() == model_dim); + + qkv_einsum_w1.SetType(Type::kI8); + qkv_einsum_w2.SetType(Type::kI8); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(qkv_einsum_w1, allocator, + MatPadding::kPacked); + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(qkv_einsum_w2, allocator, + MatPadding::kPacked); + } + + const size_t total_size = qkv_einsum_w.Rows() * qkv_einsum_w.Cols(); + hwy::AlignedFreeUniquePtr w_tmp = + hwy::AllocateAligned(total_size); + + const hwy::HWY_NAMESPACE::ScalableTag df; + HWY_NAMESPACE::DecompressAndZeroPad(df, qkv_einsum_w.Span(), 0, w_tmp.get(), + total_size); + + const size_t w1_size = w1_rows * model_dim; + const size_t w2_size = w2_rows * model_dim; + float* w1_tmp = w_tmp.get(); + float* w2_tmp = w_tmp.get() + w1_size; + + CompressWorkingSet work; + hwy::ThreadPool pool(0); + HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool); + HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool); + + qkv_einsum_w1.SetScale(1.0f); + qkv_einsum_w2.SetScale(1.0f); + + qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); +} + // Must be called after reading weights via `ForEachTensor`. 
// TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. void LayerWeightsPtrs::Fixup(std::vector& mat_owners, const Allocator& allocator) { - // TODO(janwas): handle NUQ - InitAttWeights(mat_owners, allocator); - SplitW1(); - SplitAttW1(); + if (attn_vec_einsum_w.GetType() == Type::kI8) { + MatPtrT attn_vec_einsum_w_i8(attn_vec_einsum_w); + MatPtrT att_weights_i8(att_weights); + InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8, + mat_owners, allocator); + attn_vec_einsum_w = attn_vec_einsum_w_i8; + att_weights = att_weights_i8; + } else { + InitAttWeights(mat_owners, allocator); + } + + if (gating_einsum_w.GetType() == Type::kI8) { + MatPtrT gating_einsum_w_i8(gating_einsum_w); + MatPtrT gating_einsum_w1_i8(gating_einsum_w1); + MatPtrT gating_einsum_w2_i8(gating_einsum_w2); + SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8, + gating_einsum_w2_i8, mat_owners, allocator); + gating_einsum_w = gating_einsum_w_i8; + gating_einsum_w1 = gating_einsum_w1_i8; + gating_einsum_w2 = gating_einsum_w2_i8; + } else { + SplitW1(); + } + + if (qkv_einsum_w.GetType() == Type::kI8) { + MatPtrT qkv_einsum_w_i8(qkv_einsum_w); + MatPtrT qkv_einsum_w1_i8(qkv_einsum_w1); + MatPtrT qkv_einsum_w2_i8(qkv_einsum_w2); + SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8, + qkv_einsum_w2_i8, mat_owners, allocator); + qkv_einsum_w = qkv_einsum_w_i8; + qkv_einsum_w1 = qkv_einsum_w1_i8; + qkv_einsum_w2 = qkv_einsum_w2_i8; + } else { + SplitAttW1(); + } } static void HWY_MAYBE_UNUSED InitAttWeightsNUQ( @@ -427,8 +634,6 @@ static void ReadAllToBF16(const std::vector& tensors, static std::vector MakeBatches( const std::vector& tensors, const uint64_t file_bytes) { PROFILER_ZONE("Startup.Weights.MakeBatches"); - // Batches must be contiguous but blobs are padded, hence at least one - // batch per tensor, and more when tensor rows exceed the batch size. 
std::vector batches; batches.reserve(tensors.size()); @@ -439,17 +644,28 @@ static std::vector MakeBatches( HWY_ASSERT(range.End() <= file_bytes); batches.emplace_back(offset, range.key_idx); - const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes(); - const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes(); - uint8_t* row_bytes = mat.RowBytes(0); - for (size_t r = 0; r < mat.Rows(); ++r) { - if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch. - batches.emplace_back(offset, range.key_idx); - // Adding to an empty batch is always successful. - HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row)); + if (mat.IsPacked()) { + HWY_ASSERT(range.bytes == mat.PackedBytes()); + if (!batches.back().Add(mat.Packed(), range.bytes)) { + // This should not happen if tensors are < 2GB. + // If it does, we need to chunk. For now, let's assume it doesn't. + HWY_ABORT("Packed tensor too large for a single IO batch."); + } + offset += range.bytes; + } else { + const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes(); + const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes(); + uint8_t* row_bytes = mat.RowBytes(0); + for (size_t r = 0; r < mat.Rows(); ++r) { + if (!batches.back().Add(row_bytes, + file_bytes_per_row)) { // Full batch. + batches.emplace_back(offset, range.key_idx); + // Adding to an empty batch is always successful. 
+ HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row)); + } + offset += file_bytes_per_row; + row_bytes += mem_stride_bytes; } - offset += file_bytes_per_row; - row_bytes += mem_stride_bytes; } HWY_ASSERT(offset == range.End()); } diff --git a/ops/matmul_static.h b/ops/matmul_static.h index 6b93d92..d2ab677 100644 --- a/ops/matmul_static.h +++ b/ops/matmul_static.h @@ -50,6 +50,7 @@ GEMMA_MATMUL_FOR_B(float) \ GEMMA_MATMUL_FOR_B(NuqStream) \ GEMMA_MATMUL_FOR_B(SfpStream) \ + GEMMA_MATMUL_FOR_B(I8Stream) \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/ops/matmul_static_i8.cc b/ops/matmul_static_i8.cc new file mode 100644 index 0000000..b21bc27 --- /dev/null +++ b/ops/matmul_static_i8.cc @@ -0,0 +1,29 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. 
+// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "ops/matmul_static_i8.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_MATMUL_TB I8Stream +#include "ops/matmul_static-inl.h" \ No newline at end of file diff --git a/ops/ops-inl.h b/ops/ops-inl.h index c966a68..4ff2c7d 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -220,6 +220,7 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, template HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, + const size_t w_ofs, OT* HWY_RESTRICT out, const size_t size, hwy::Profiler& p, const size_t worker) { @@ -232,7 +233,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker)); const VF* HWY_RESTRICT pmul = &mul; - Decompress2AndCompressTo(DF(), out, size, x, weight, + Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs, [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF { const VF m = hn::Mul(*pmul, vx); // (1+weight) * m = m + weight*m = one FMA. @@ -242,13 +243,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer. 
template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight, - XT* HWY_RESTRICT inout, - const size_t size, - hwy::Profiler& p, - const size_t worker) { +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( + const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout, + const size_t size, hwy::Profiler& p, const size_t worker) { PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace)); - namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -256,7 +254,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight, const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker)); const VF* HWY_RESTRICT pmul = &mul; - Decompress1AndCompressInplace(DF(), inout, size, weight, + Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs, [pmul](DF /*df*/, VF vx, VF vw) HWY_ATTR -> VF { const VF m = hn::Mul(*pmul, vx); // (1+weight) * m = m + weight*m = one FMA. @@ -489,7 +487,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; - Decompress1AndCompressInplace(DF(), out, size, x, + Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0, [&](DF /*df*/, VF out, VF x) HWY_ATTR -> VF { return hn::Add(x, out); }); } @@ -507,8 +505,8 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, cluster_idx, [&](uint64_t token_idx, size_t worker) { RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), - out.Row(token_idx), activations.Cols(), ctx.profiler, - worker); + /*w_ofs=*/0, out.Row(token_idx), activations.Cols(), + ctx.profiler, worker); }); }); } @@ -522,7 +520,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, CallUpcasted(&weights, [&](const auto* weights_t) { ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, 
cluster_idx, [&](uint64_t token_idx, size_t worker) { - RMSNormInplace(weights_t->PackedScale1(), + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, inout.Row(token_idx), inout.Cols(), ctx.profiler, worker); }); @@ -604,7 +602,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c, const VF vc = hn::Set(DF(), c); const VF* HWY_RESTRICT pc = &vc; - Decompress1AndCompressInplace(DF(), out, size, x, + Decompress1AndCompressInplace(DF(), out, size, x, /*p1_ofs=*/0, [&](DF /*df*/, VF out, VF x) HWY_ATTR -> VF { return hn::MulAdd(x, *pc, out); }); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index dd8e4e8..d46bb5c 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -558,7 +558,8 @@ struct TestRMSNorm { ScalarRMSNorm(vec, weight, expected, kSize); InitProfilerZones(hwy::Profiler::Get()); - RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); + RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(), + /*worker=*/0); for (size_t i = 0; i < kSize; i++) { const float e = hwy::ConvertScalarTo(expected[i]); @@ -593,7 +594,7 @@ struct TestRMSNormInplace { ScalarRMSNorm(expected, weight, expected, kSize); InitProfilerZones(hwy::Profiler::Get()); - RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(), + RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); for (size_t i = 0; i < kSize; i++) { diff --git a/python/configs.cc b/python/configs.cc index 086c691..e544bb0 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -53,7 +53,11 @@ PYBIND11_MODULE(configs, py_module) { .value("kF32", Type::kF32) .value("kBF16", Type::kBF16) .value("kSFP", Type::kSFP) - .value("kNUQ", Type::kNUQ); + .value("kNUQ", Type::kNUQ) + .value("kF64", Type::kF64) + .value("kU32", Type::kU32) + .value("kU64", Type::kU64) + .value("kI8", Type::kI8); enum_(py_module, "LayerAttentionType") .value("kGemma", LayerAttentionType::kGemma) diff --git a/util/basics.h b/util/basics.h index 0211a0e..5a7f0d5 100644 
--- a/util/basics.h +++ b/util/basics.h @@ -59,6 +59,25 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) { #endif } +static inline void MaybePrintInitialized(const void* ptr, size_t size) { +#if HWY_IS_MSAN + __msan_print_shadow(ptr, size); +#else + (void)ptr; + (void)size; +#endif +} + +static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) { +#if HWY_IS_MSAN + return __msan_test_shadow(ptr, size); +#else + (void)ptr; + (void)size; + return 0; +#endif +} + // Shared between gemma.h and ops-inl.h. #pragma pack(push, 1) struct TokenAndProb { diff --git a/util/mat.cc b/util/mat.cc index f81767d..6d9c9bf 100644 --- a/util/mat.cc +++ b/util/mat.cc @@ -80,11 +80,13 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes, void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator, MatPadding padding) { - const bool is_nuq = mat.GetType() == Type::kNUQ; - if (is_nuq) padding = MatPadding::kPacked; + const bool is_compressed_and_packed = + mat.GetType() == Type::kNUQ || mat.GetType() == Type::kI8; + if (is_compressed_and_packed) padding = MatPadding::kPacked; const size_t stride = Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes()); - const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride; + const size_t num = + is_compressed_and_packed ? mat.PackedBytes() : mat.Rows() * stride; // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding` // might not be enough, hence add extra. `MatT` is at least one byte, which // is half of BF16, hence adding `VectorBytes` *elements* is enough. diff --git a/util/mat.h b/util/mat.h index 6f9a243..59eceaa 100644 --- a/util/mat.h +++ b/util/mat.h @@ -240,6 +240,8 @@ class MatPtr : public IFields { // `CompressedArrayElements` is a wrapper function that has the same // effect, but that requires a template argument, not `type`. 
num_elements = NuqStream::PackedEnd(num_elements); + } else if (type == Type::kI8) { + num_elements = I8Stream::PackedEnd(num_elements); } return num_elements; } @@ -324,7 +326,8 @@ class MatPtrT : public MatPtr { } PackedSpan PaddedSpan() const { - return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride()); + const size_t num = IsPacked() ? num_elements_ : Rows() * Stride(); + return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num); } // For `compress-inl.h` functions, which assume contiguous streams and thus @@ -379,6 +382,9 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func, } else if (base->GetType() == Type::kSFP) { const MatPtrT mat(*base); return func(&mat, std::forward(args)...); + } else if (base->GetType() == Type::kI8) { + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); } else { HWY_ABORT("Unhandled type %s.", TypeName(base->GetType())); } @@ -410,6 +416,10 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2, const MatPtrT mat1(*base1); const MatPtrT mat2(*base2); return func(&mat1, &mat2, std::forward(args)...); + } else if (base1->GetType() == Type::kI8) { + const MatPtrT mat1(*base1); + const MatPtrT mat2(*base2); + return func(&mat1, &mat2, std::forward(args)...); } else { HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType())); } From 9b6ed1a58f631c85693117f38b1fea36c7e82a2f Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 15 Oct 2025 15:45:27 -0700 Subject: [PATCH 64/65] gemma_batch_bench: generate more unique prompts PiperOrigin-RevId: 819944137 --- evals/gemma_batch_bench.cc | 99 +++++++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 34 deletions(-) diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index ff81671..4a6f5ea 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -15,6 +15,7 @@ #include +#include #include #include @@ -48,48 +49,78 @@ class GemmaBatchBench : public ::testing::Test { }; 
TEST_F(GemmaBatchBench, RandomQuestionsBatched) { - const std::vector questions = { - {"Write me a poem about Australia?"}, - {"What's the history of Denmark?"}, - {"Write me a comedy story about the USA."}, - {"Teach me about GPU programming."}, - {"Write me a story about the moon."}, - {"Write me a story about the universe."}, - {"Write a poem about planet earth."}, - {"Tell me more about olympic sports."}, - {"How would you describe Washington State?"}, - {"Write me a story about Silicon Valley."}, - {"Write me about your best friend."}, - {"How would you describe a unicorn?"}, - {"Tell me about world war history."}, - {"Tell me about Google."}, + std::vector prompts = { + {"Describe dynamic programming."}, + {"Explain how electric cars work."}, {"Explain to me how to use Google Maps."}, - {"Explain to me how AI works."}, - {"Write me a poem about France."}, - {"What's the history of Great Britain?"}, - {"Write me a comedy story about Florida."}, - {"Teach me about dynamic programming."}, - {"Write me a story about Jupiter."}, - {"Write me a story about space ships."}, - {"Write a poem about some random planet."}, - {"Tell me more about team sports."}, - {"How would you describe Michigan State?"}, - {"Write me a story about Europe."}, - {"Write me about your best colleague."}, - {"How would you describe a horse?"}, - {"Tell me about World War 2."}, + {"How does AI work?"}, + {"How would you describe a unicorn?"}, {"Please share some good cooking tips."}, - {"Tell me about space travel."}, - {"Explain to me how electric cars work."}, + {"Teach me about GPU programming."}, + {"Tell me a fact about World War 2."}, + {"Tell me about Google."}, + {"Tell me more about olympic sports."}, + {"Tell me something about space travel."}, + {"What is a horse?"}, + {"What is Michigan State?"}, + {"What's the history of Denmark?"}, + {"Write a poem about planet earth."}, + {"Write a story about Jupiter."}, + {"Write about the moon."}, + {"Write me a comedy story about 
Florida."}, + {"Write me a poem about France."}, }; + const std::vector start = { + {"What is"}, {"When did"}, {"Where did"}, {"How did"}, {"Why did"}}; + const std::vector concepts = {"Socrates", + "Einstein", + "Leonardo", + "Cleopatra", + "Adele", + "Mars", + "Turing", + "Mozart", + "democracy", + "gravity", + "AI", + "evolution", + "physics", + "the internet", + "steam engine", + "inflation", + "electricity", + "the Sahara", + "NASA", + "Rome", + "the UN", + "Google", + "the Renaissance", + "Hamlet", + "poetry", + "Stoicism", + "geometry", + "DNA", + "Star Wars", + "1984"}; + const std::vector end = {"exist?", "work?", "happen?", + "lead to?", "believe?", "result in?"}; + for (const std::string& s : start) { + for (const std::string& c : concepts) { + for (const std::string& e : end) { + prompts.push_back(s + " " + c + " " + e); + } + } + } + AesCtrEngine engine(true); + std::shuffle(prompts.begin(), prompts.end(), RngStream(engine, 123)); - // Fills prompts round robin from `questions` until the desired batch size. + // Fills `inputs` by repeating from `prompts` until the desired batch size. 
std::vector inputs; inputs.reserve(s_env->MutableConfig().decode_qbatch_size); size_t qpos = 0; for (size_t i = 0; i < inputs.capacity(); ++i) { - inputs.push_back(questions[qpos++]); - if (qpos == questions.size()) qpos = 0; + inputs.push_back(prompts[qpos++]); + if (qpos == prompts.size()) qpos = 0; } s_env->SetMaxGeneratedTokens(24); std::vector responses = BatchGemmaReply(inputs); From f59eb2ed72ece1f2494506d9164cf514fec6f31a Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 16 Oct 2025 04:00:06 -0700 Subject: [PATCH 65/65] Remove multi-package support from topology Also no longer assume equal-sized clusters PiperOrigin-RevId: 820164125 --- io/blob_compare.cc | 65 +++++---- ops/dot_test.cc | 5 +- ops/matmul.cc | 6 +- ops/matmul.h | 23 ++-- ops/matmul_test.cc | 37 +++-- util/allocator.cc | 2 +- util/allocator.h | 2 +- util/threading.cc | 127 +++++++---------- util/threading.h | 130 +++++------------- util/threading_context.cc | 19 ++- util/threading_context.h | 15 +-- util/threading_test.cc | 39 +++--- util/topology.cc | 277 ++++++++++++++++---------------------- util/topology.h | 51 ++----- 14 files changed, 305 insertions(+), 493 deletions(-) diff --git a/io/blob_compare.cc b/io/blob_compare.cc index bb25843..998036e 100644 --- a/io/blob_compare.cc +++ b/io/blob_compare.cc @@ -28,7 +28,6 @@ #include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" namespace gcpp { @@ -104,27 +103,31 @@ BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) { // Reads one set of blobs in parallel (helpful if in disk cache). // Aborts on error. 
void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs, - hwy::ThreadPool& pool) { + ThreadingContext& ctx, size_t cluster_idx) { HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size()); - pool.Run(0, blobs.size(), [&](size_t i, size_t /*thread*/) { - HWY_ASSERT(ranges[i].bytes == blobs[i].size()); - reader.file().Read(ranges[i].offset, ranges[i].bytes, blobs[i].data()); - }); + ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx, + cluster_idx, [&](size_t i, size_t /*thread*/) { + HWY_ASSERT(ranges[i].bytes == blobs[i].size()); + reader.file().Read(ranges[i].offset, ranges[i].bytes, + blobs[i].data()); + }); } // Parallelizes ReadBlobs across (two) packages, if available. void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, const RangeVec& ranges1, const RangeVec& ranges2, size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2, - NestedPools& pools) { + ThreadingContext& ctx) { const double t0 = hwy::platform::Now(); - HWY_WARN("Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30, - pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers()); - pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) { - ReadBlobs(task ? reader2 : reader1, task ? ranges2 : ranges1, - task ? blobs2 : blobs1, pools.Pool(pkg_idx)); - }); + HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30, + ctx.pools.NumClusters()); + ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, + [&](const size_t task, size_t cluster_idx) { + ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2, + task ? 
blobs1 : blobs2, ctx, cluster_idx); + }); + const double t1 = hwy::platform::Now(); HWY_WARN("%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9); } @@ -181,29 +184,23 @@ size_t BlobDifferences(const ByteSpan data1, const ByteSpan data2, } void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2, - size_t total_bytes, NestedPools& pools) { + size_t total_bytes, ThreadingContext& ctx) { HWY_WARN("Comparing %zu blobs in parallel: ", keys.size()); const double t0 = hwy::platform::Now(); std::atomic blobs_equal{}; std::atomic blobs_diff{}; - const IndexRangePartition ranges = StaticPartition( - IndexRange(0, keys.size()), pools.AllPackages().NumWorkers(), 1); - ParallelizeOneRange( - ranges, pools.AllPackages(), - [&](const IndexRange& range, size_t pkg_idx) { - pools.Pool(pkg_idx).Run( - range.begin(), range.end(), [&](size_t i, size_t /*thread*/) { - const size_t mismatches = - BlobDifferences(blobs1[i], blobs2[i], keys[i]); - if (mismatches != 0) { - HWY_WARN("key %s has %zu mismatches in %zu bytes!\n", - keys[i].c_str(), mismatches, blobs1[i].size()); - blobs_diff.fetch_add(1); - } else { - blobs_equal.fetch_add(1); - } - }); - }); + ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0, + [&](size_t i, size_t /*thread*/) { + const size_t mismatches = + BlobDifferences(blobs1[i], blobs2[i], keys[i]); + if (mismatches != 0) { + HWY_WARN("key %s has %zu mismatches in %zu bytes!\n", + keys[i].c_str(), mismatches, blobs1[i].size()); + blobs_diff.fetch_add(1); + } else { + blobs_equal.fetch_add(1); + } + }); const double t1 = hwy::platform::Now(); HWY_WARN("%.1f GB/s; total blob matches=%zu, mismatches=%zu\n", total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(), @@ -230,9 +227,9 @@ void ReadAndCompareBlobs(const Path& path1, const Path& path2) { ThreadingArgs args; ThreadingContext ctx(args); ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2, - ctx.pools); + ctx); - CompareBlobs(reader1.Keys(), blobs1, blobs2, 
total_bytes, ctx.pools); + CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx); } } // namespace gcpp diff --git a/ops/dot_test.cc b/ops/dot_test.cc index d93b210..ed09429 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1124,8 +1124,9 @@ void TestAllDot() { MatPadding::kOdd); std::array all_stats; - ctx.pools.Cluster(0, 0).Run( - 0, kReps, [&](const uint32_t rep, size_t thread) { + ParallelFor( + ParallelismStrategy::kWithinCluster, kReps, ctx, 0, + [&](size_t rep, size_t thread) { float* HWY_RESTRICT pa = a.Row(thread); float* HWY_RESTRICT pb = b.Row(thread); double* HWY_RESTRICT buf = bufs.Row(thread); diff --git a/ops/matmul.cc b/ops/matmul.cc index 6ef1412..ebeff9b 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -351,7 +351,7 @@ std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) { - const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); + const size_t num_clusters = ctx.pools.NumClusters(); per_cluster.resize(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C @@ -368,7 +368,7 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) { PROFILER_ZONE("Startup.BindB"); - const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node(); + const size_t node = ctx.topology.GetCluster(0).Node(); uintptr_t begin = reinterpret_cast(B.RowBytes(0)); uintptr_t end = begin + B.Rows() * B.Stride() * B.ElementBytes(); // B row padding is less than the page size, so only bind the subset that @@ -394,7 +394,7 @@ void BindC(ThreadingContext& ctx, MatPtr& C) { const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), allocator.BasePageBytes()); - const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node(); + const size_t node = ctx.topology.GetCluster(0).Node(); bool ok = true; for (size_t im = 0; im 
< C.Rows(); ++im) { ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node); diff --git a/ops/matmul.h b/ops/matmul.h index bedee3d..e16c0f2 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -105,8 +105,7 @@ struct MMParallelWithinCluster { size_t inner_tasks, size_t cluster_idx, const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); - const size_t pkg_idx = 0; - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const size_t base = ctx.Worker(cluster_idx); const IndexRangePartition ranges_n = StaticPartition( @@ -122,8 +121,7 @@ struct MMParallelWithinCluster { const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_nc, size_t cluster_idx, const Func& func) const { - const size_t pkg_idx = 0; - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const size_t base = ctx.Worker(cluster_idx); // Low-batch: avoid Divide/Remainder. @@ -143,8 +141,7 @@ struct MMParallelWithinCluster { template void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, size_t cluster_idx, const Func& func) const { - const size_t pkg_idx = 0; - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const size_t base = ctx.Worker(cluster_idx); cluster.Run( @@ -164,12 +161,11 @@ struct MMParallelHierarchical { HWY_DASSERT(caller_cluster_idx == 0); // Single cluster: parallel-for over static partition of `range_n`. 
- const size_t pkg_idx = 0; - hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); const size_t num_clusters = all_clusters.NumWorkers(); if (num_clusters == 1) { const size_t cluster_idx = 0; - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const IndexRangePartition ranges_n = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); return ParallelizeOneRange( @@ -185,7 +181,7 @@ struct MMParallelHierarchical { ParallelizeOneRange( ranges_n, all_clusters, [&](const IndexRange& n_range, const size_t cluster_idx) { - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const size_t cluster_base = ctx.Worker(cluster_idx); // Parallel-for over sub-ranges of `cluster_range` within the cluster. const IndexRangePartition worker_ranges = StaticPartition( @@ -206,17 +202,16 @@ struct MMParallelHierarchical { const IndexRangePartition& ranges_nc, HWY_MAYBE_UNUSED size_t caller_cluster_idx, const Func& func) const { - const size_t pkg_idx = 0; HWY_DASSERT(caller_cluster_idx == 0); - hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); // `all_clusters` is a pool with one worker per cluster in a package. const size_t num_clusters = all_clusters.NumWorkers(); // Single (big) cluster: collapse two range indices into one parallel-for // to reduce the number of fork-joins. if (num_clusters == 1) { const size_t cluster_idx = 0; - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); // Low-batch: avoid Divide/Remainder. 
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { return ParallelizeOneRange( @@ -237,7 +232,7 @@ struct MMParallelHierarchical { ranges_nc, all_clusters, [&](const IndexRange range_nc, size_t cluster_idx) { const size_t cluster_base = ctx.Worker(cluster_idx); - hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); ParallelizeOneRange(ranges_mc, cluster, [&](const IndexRange& range_mc, size_t worker) { func(range_mc, range_nc, cluster_base + worker); diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 2f0fde2..101707f 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -191,29 +191,22 @@ HWY_INLINE void MatMulSlow(const MatPtrT A, const MatPtrT B, const IndexRange all_cols_c(0, C.Cols()); NestedPools& pools = env.ctx.pools; - hwy::ThreadPool& all_packages = pools.AllPackages(); - const IndexRangePartition get_row_c = - StaticPartition(all_rows_c, all_packages.NumWorkers(), 1); + hwy::ThreadPool& all_clusters = pools.AllClusters(); + const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB); + const IndexRangePartition get_col_c = + StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); ParallelizeOneRange( - get_row_c, all_packages, - [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR { - hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx); - const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB); - const IndexRangePartition get_col_c = - StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); - ParallelizeOneRange( - get_col_c, all_clusters, - [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { - for (size_t r : rows_c) { - TC* HWY_RESTRICT C_row = C.Row(r); - for (size_t c : cols_c) { - const float add = add_row ? 
add_row[c] : 0.0f; - const float dot = - Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols()); - C_row[c] = hwy::ConvertScalarTo(add + scale * dot); - } - } - }); + get_col_c, all_clusters, + [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { + for (size_t r : all_rows_c) { + TC* HWY_RESTRICT C_row = C.Row(r); + for (size_t c : cols_c) { + const float add = add_row ? add_row[c] : 0.0f; + const float dot = + Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols()); + C_row[c] = hwy::ConvertScalarTo(add + scale * dot); + } + } }); } diff --git a/util/allocator.cc b/util/allocator.cc index f99586e..612bbb9 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -139,7 +139,7 @@ CacheInfo::CacheInfo(const BoundedTopology& topology) { step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); - const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0); + const BoundedTopology::Cluster& cluster = topology.GetCluster(0); if (const hwy::Cache* caches = hwy::DataCaches()) { l1_bytes_ = caches[1].size_kib << 10; l2_bytes_ = caches[2].size_kib << 10; diff --git a/util/allocator.h b/util/allocator.h index 086b6e9..d508d5c 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -169,7 +169,7 @@ class Allocator { bool ShouldBind() const { return should_bind_; } // Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is - // typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. + // typically `BoundedTopology::GetCluster(cluster_idx).node`. // Writes zeros to SOME of the memory. Only call if `ShouldBind()`. // `p` and `bytes` must be multiples of `QuantumBytes()`. 
bool BindMemory(void* p, size_t bytes, size_t node) const; diff --git a/util/threading.cc b/util/threading.cc index 9c4cfe0..6d4a603 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -18,7 +18,6 @@ #include -#include // std::sort #include #include #include @@ -31,14 +30,6 @@ namespace gcpp { -// Sort T := packages/clusters by descending 'size' so that users who only use -// one Group get the largest. -template -static void SortByDescendingSize(std::vector& groups) { - std::sort(groups.begin(), groups.end(), - [](const T& a, const T& b) { return a.Size() > b.Size(); }); -} - static bool InContainer() { return false; // placeholder for container detection, do not remove } @@ -55,19 +46,18 @@ PinningPolicy::PinningPolicy(Tristate pin) { // If `pinning.Want()`, tries to pin each worker in `pool` to an LP in // `cluster`, and calls `pinning.NotifyFailed()` if any fails. -void MaybePin(const BoundedTopology& topology, size_t pkg_idx, - size_t cluster_idx, const BoundedTopology::Cluster& cluster, - PinningPolicy& pinning, hwy::ThreadPool& pool) { +static void MaybePin(const BoundedTopology& topology, size_t cluster_idx, + const BoundedTopology::Cluster& cluster, + PinningPolicy& pinning, hwy::ThreadPool& pool) { const std::vector lps = cluster.LPVector(); HWY_ASSERT(pool.NumWorkers() <= lps.size()); pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { HWY_ASSERT(task == thread); // each worker has one task char buf[16]; // Linux limitation - const int bytes_written = snprintf(buf, sizeof(buf), "P%zu X%02zu C%03d", - topology.SkippedPackages() + pkg_idx, - topology.SkippedClusters() + cluster_idx, - static_cast(task)); + const int bytes_written = snprintf( + buf, sizeof(buf), "P%zu X%02zu C%03d", topology.SkippedPackages(), + topology.SkippedClusters() + cluster_idx, static_cast(task)); HWY_ASSERT(bytes_written < static_cast(sizeof(buf))); hwy::SetThreadName(buf, 0); // does not support varargs @@ -113,79 +103,56 @@ static size_t DivideMaxAcross(const 
size_t max, const size_t instances) { return max; } -NestedPools::NestedPools(const BoundedTopology& topology, - const Allocator& allocator, size_t max_threads, - Tristate pin) - : pinning_(pin) { - packages_.resize(topology.NumPackages()); - all_packages_ = - MakePool(allocator, packages_.size(), hwy::PoolWorkerMapping()); - const size_t max_workers_per_package = - DivideMaxAcross(max_threads, packages_.size()); - // Each worker in all_packages_, including the main thread, will be the - // calling thread of an all_clusters->Run, and hence pinned to one of the - // `cluster.lps` if `pin`. - all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) { - HWY_ASSERT(pkg_idx == thread); // each thread has one task - packages_[pkg_idx] = Package(topology, allocator, pinning_, pkg_idx, - max_workers_per_package); - }); - - all_pinned_ = pinning_.AllPinned(&pin_string_); - - // For mapping package/cluster/thread to noncontiguous TLS indices, in case - // cluster/thread counts differ. - HWY_ASSERT(!packages_.empty() && packages_.size() <= 16); - for (const Package& p : packages_) { - max_clusters_per_package_ = - HWY_MAX(max_clusters_per_package_, p.NumClusters()); - max_workers_per_cluster_ = - HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster()); - } - HWY_ASSERT(max_clusters_per_package_ >= 1); - HWY_ASSERT(max_clusters_per_package_ <= 64); - HWY_ASSERT(max_workers_per_cluster_ >= 1); - HWY_ASSERT(max_workers_per_cluster_ <= 256); -} - // `max_or_zero` == 0 means no limit. static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero); } -NestedPools::Package::Package(const BoundedTopology& topology, - const Allocator& allocator, - PinningPolicy& pinning, size_t pkg_idx, - size_t max_workers_per_package) { - // Pre-allocate because elements are set concurrently. 
- clusters_.resize(topology.NumClusters(pkg_idx)); - const size_t max_workers_per_cluster = - DivideMaxAcross(max_workers_per_package, clusters_.size()); +NestedPools::NestedPools(const BoundedTopology& topology, + const Allocator& allocator, size_t max_threads, + Tristate pin) + : pinning_(pin) { + const size_t num_clusters = topology.NumClusters(); + const size_t cluster_workers_cap = DivideMaxAcross(max_threads, num_clusters); + + // Precompute cluster sizes to ensure we pass the same values to `MakePool`. + // The max is also used for `all_clusters_mapping`, see below. + size_t workers_per_cluster[hwy::kMaxClusters] = {}; + size_t all_clusters_node = 0; + for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { + const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx); + workers_per_cluster[cluster_idx] = + CapIfNonZero(tcluster.NumWorkers(), cluster_workers_cap); + // Cluster sizes can vary because individual LPs may be disabled. Use the + // max so that `GlobalIdx` is consistent within and across clusters. It is + // OK to have holes or gaps in the worker index space. + max_workers_per_cluster_ = + HWY_MAX(max_workers_per_cluster_, workers_per_cluster[cluster_idx]); + all_clusters_node = tcluster.Node(); // arbitrarily use the last node seen + } - const BoundedTopology::Cluster& cluster0 = topology.GetCluster(pkg_idx, 0); - // Core 0 of each cluster. The second argument is the cluster size, not - // number of clusters. We ensure that it is the same for all clusters so that - // the `GlobalIdx` computation is consistent within and across clusters. const hwy::PoolWorkerMapping all_clusters_mapping(hwy::kAllClusters, - cluster0.Size()); - all_clusters_ = MakePool(allocator, clusters_.size(), all_clusters_mapping, - cluster0.Node()); + max_workers_per_cluster_); + all_clusters_ = MakePool(allocator, num_clusters, all_clusters_mapping, + all_clusters_node); + + // Pre-allocate because elements are set concurrently. 
+ clusters_.resize(num_clusters); + // Parallel so we also pin the calling worker in `all_clusters` to // `cluster.lps`. - all_clusters_->Run( - 0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) { - HWY_ASSERT(cluster_idx == thread); // each thread has one task - const BoundedTopology::Cluster& cluster = - topology.GetCluster(pkg_idx, cluster_idx); - HWY_ASSERT(cluster.Size() == cluster0.Size()); - clusters_[cluster_idx] = MakePool( - allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster), - hwy::PoolWorkerMapping(cluster_idx, cluster.Size()), - cluster.Node()); - // Pin workers AND the calling thread from `all_clusters`. - MaybePin(topology, pkg_idx, cluster_idx, cluster, pinning, - *clusters_[cluster_idx]); - }); + all_clusters_->Run(0, num_clusters, [&](size_t cluster_idx, size_t thread) { + HWY_ASSERT(cluster_idx == thread); // each thread has one task + const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx); + clusters_[cluster_idx] = + MakePool(allocator, workers_per_cluster[cluster_idx], + hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_), + tcluster.Node()); + // Pin workers AND the calling thread from `all_clusters_`. + MaybePin(topology, cluster_idx, tcluster, pinning_, + *clusters_[cluster_idx]); + }); + all_pinned_ = pinning_.AllPinned(&pin_string_); } } // namespace gcpp diff --git a/util/threading.h b/util/threading.h index 53795be..35c6e22 100644 --- a/util/threading.h +++ b/util/threading.h @@ -66,17 +66,14 @@ class PinningPolicy { }; // PinningPolicy // Creates a hierarchy of thread pools according to `BoundedTopology`: one with -// a thread per enabled package; for each of those, one with a thread per -// enabled cluster (CCX/shared L3), and for each of those, the remaining -// enabled cores in that cluster. +// a thread per enabled cluster (CCX/shared L3), and for each of those, the +// remaining enabled cores in that cluster. 
// // Note that we support spin waits, thus it is important for each thread to be // responsive, hence we do not create more than one thread per enabled core. -// For example, when there are two packages with four clusters of 8 cores, -// `AllPackages` has the main thread plus one extra thread, each `AllClusters` -// has one of the `AllPackages` threads plus three extras, each `Cluster` runs -// on one `AllClusters` thread plus seven extra workers, for a total of -// 1 + 2*3 + 2*(4*7) = 63 extras plus the main thread. +// For example, when there are four clusters of 8 cores, `AllClusters` has the +// main thread plus three extras, each `Cluster` runs on one of `AllClusters` +// plus seven extras, for a total of 3 + (4*7) = 31 extras plus the main thread. // // Useful when there are tasks which should be parallelized by workers sharing a // cache, or on the same NUMA node. In both cases, individual pools have lower @@ -96,6 +93,10 @@ class NestedPools { NestedPools(NestedPools&&) = delete; NestedPools& operator=(NestedPools&&) = delete; + // Because cross-package latency is high, this interface assumes only one + // package is used. The `skip_packages` argument to `BoundedTopology` selects + // which package that is for this `NestedPools` instance. + // // `max_threads` is the maximum number of threads to divide among all // clusters. This is more intuitive than a per-cluster limit for users who // may not be aware of the CPU topology. This should be zero (meaning no @@ -104,8 +105,8 @@ class NestedPools { // // To ensure we do not create more threads than there are HW cores, which // would cause huge slowdowns when spinning, the `BoundedSlice` arguments - // only impose upper bounds on the number of detected packages and clusters - // rather than defining the actual number of threads. + // only impose upper bounds on the number of detected clusters rather than + // defining the actual number of threads. 
NestedPools(const BoundedTopology& topology, const Allocator& allocator, size_t max_threads = 0, Tristate pin = Tristate::kDefault); @@ -133,98 +134,37 @@ class NestedPools { } } - size_t NumPackages() const { return packages_.size(); } - hwy::ThreadPool& AllPackages() { return *all_packages_; } - hwy::ThreadPool& AllClusters(size_t pkg_idx) { - HWY_DASSERT(pkg_idx < NumPackages()); - return packages_[pkg_idx].AllClusters(); - } - hwy::ThreadPool& Cluster(size_t pkg_idx, size_t cluster_idx) { - HWY_DASSERT(pkg_idx < NumPackages()); - return packages_[pkg_idx].Cluster(cluster_idx); + size_t NumClusters() const { return clusters_.size(); } + hwy::ThreadPool& AllClusters() { return *all_clusters_; } + hwy::ThreadPool& Cluster(size_t cluster_idx) { + HWY_DASSERT(cluster_idx < clusters_.size()); + return *clusters_[cluster_idx]; } // Reasonably tight upper bounds for allocating thread-local storage (TLS). size_t MaxWorkersPerCluster() const { return max_workers_per_cluster_; } - size_t MaxWorkersPerPackage() const { - return max_clusters_per_package_ * MaxWorkersPerCluster(); - } - size_t MaxWorkers() const { return NumPackages() * MaxWorkersPerPackage(); } - - // Actual number of workers. - size_t TotalWorkers() const { - size_t total_workers = 0; - for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) { - total_workers += packages_[pkg_idx].TotalWorkers(); - } - return total_workers; - } + size_t MaxWorkers() const { return NumClusters() * MaxWorkersPerCluster(); } // For ShowConfig const char* PinString() const { return pin_string_; } // Returns a single pool on the given package: either one thread per cluster // if there is more than one, which maximizes available memory bandwidth, or - // the first cluster, which is typically the whole package. For use by callers - // that only have a single parallel-for. + // the first cluster, which is typically the whole package. For use by + // callers that only have a single parallel-for. 
+ // DEPRECATED: use ParallelFor instead. hwy::ThreadPool& Pool(size_t pkg_idx = 0) { // Only one cluster: use its pool, typically a whole socket. - if (AllClusters(pkg_idx).NumWorkers() == 1) { - return Cluster(pkg_idx, 0); - } + if (NumClusters() == 1) return Cluster(0); // One worker per cluster to maximize bandwidth availability. - return AllClusters(pkg_idx); + return AllClusters(); } private: - class Package { - public: - Package() = default; // for vector - Package(const BoundedTopology& topology, const Allocator& allocator, - PinningPolicy& pinning, size_t pkg_idx, - size_t max_workers_per_package); - - size_t NumClusters() const { return clusters_.size(); } - size_t MaxWorkersPerCluster() const { - size_t max_workers_per_cluster = 0; - for (const PoolPtr& cluster : clusters_) { - max_workers_per_cluster = - HWY_MAX(max_workers_per_cluster, cluster->NumWorkers()); - } - return max_workers_per_cluster; - } - size_t TotalWorkers() const { - size_t total_workers = 0; - for (const PoolPtr& cluster : clusters_) { - total_workers += cluster->NumWorkers(); - } - return total_workers; - } - - hwy::ThreadPool& AllClusters() { return *all_clusters_; } - hwy::ThreadPool& Cluster(size_t cluster_idx) { - HWY_DASSERT(cluster_idx < clusters_.size()); - return *clusters_[cluster_idx]; - } - - void SetWaitMode(hwy::PoolWaitMode wait_mode) { - all_clusters_->SetWaitMode(wait_mode); - for (PoolPtr& cluster : clusters_) { - cluster->SetWaitMode(wait_mode); - } - } - - private: - // Must be freed after `clusters_` because it reserves threads which are - // the main threads of `clusters_`. 
- PoolPtr all_clusters_; - std::vector clusters_; - }; // Package - void SetWaitMode(hwy::PoolWaitMode wait_mode) { - all_packages_->SetWaitMode(wait_mode); - for (Package& package : packages_) { - package.SetWaitMode(wait_mode); + all_clusters_->SetWaitMode(wait_mode); + for (PoolPtr& cluster : clusters_) { + cluster->SetWaitMode(wait_mode); } } @@ -232,12 +172,13 @@ class NestedPools { bool all_pinned_; const char* pin_string_; - std::vector packages_; - PoolPtr all_packages_; + // Must be freed after `clusters_` because it reserves threads which are + // the main threads of `clusters_`. + PoolPtr all_clusters_; + std::vector clusters_; - // For TLS indices. One might think this belongs in BoundedTopology, but it - // depends on max_threads, which is passed to the NestedPools constructor. - size_t max_clusters_per_package_ = 0; + // Used by `PoolWorkerMapping`. This depends on the `max_threads` argument, + // hence we can only compute this here, not in `BoundedTopology`. size_t max_workers_per_cluster_ = 0; }; @@ -362,14 +303,11 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1, template void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { - // Even if there are multiple packages, we only use the first. - const size_t pkg_idx = 0; - // If few tasks, run on a single cluster. Also avoids a bit of overhead if // there is only one cluster. 
- hwy::ThreadPool& all_clusters = pools.AllClusters(pkg_idx); + hwy::ThreadPool& all_clusters = pools.AllClusters(); const size_t num_clusters = all_clusters.NumWorkers(); - hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, 0); + hwy::ThreadPool& cluster = pools.Cluster(0); if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) { return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); @@ -382,7 +320,7 @@ void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools, ParallelizeOneRange( ranges, all_clusters, [&](const IndexRange& range, const size_t cluster_idx) { - hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = pools.Cluster(cluster_idx); const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster(); cluster.Run(range.begin(), range.end(), [&](uint64_t task, size_t thread) { diff --git a/util/threading_context.cc b/util/threading_context.cc index e2c4d03..0a349fc 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -79,18 +79,15 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) { } static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) { - TunePool(wait_mode, pools.AllPackages()); - for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) { - hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx); - TunePool(wait_mode, clusters); + hwy::ThreadPool& clusters = pools.AllClusters(); + TunePool(wait_mode, clusters); - // Run in parallel because Turin CPUs have 16, and in real usage, we often - // run all at the same time. - clusters.Run(0, clusters.NumWorkers(), - [&](uint64_t cluster_idx, size_t /*thread*/) { - TunePool(wait_mode, pools.Cluster(pkg_idx, cluster_idx)); - }); - } + // Run in parallel because Turin CPUs have 16, and in real usage, we often + // run all at the same time. 
+ clusters.Run(0, clusters.NumWorkers(), + [&](uint64_t cluster_idx, size_t /*thread*/) { + TunePool(wait_mode, pools.Cluster(cluster_idx)); + }); } ThreadingContext::ThreadingContext(const ThreadingArgs& args) diff --git a/util/threading_context.h b/util/threading_context.h index ff4ff62..5c55fc4 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -153,10 +153,7 @@ enum class ParallelismStrategy : uint8_t { template void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, ThreadingContext& ctx, size_t cluster_idx, const Func& func) { - HWY_DASSERT(ctx.topology.NumPackages() == 1); - const size_t pkg_idx = 0; - - HWY_DASSERT(cluster_idx < ctx.topology.NumClusters(pkg_idx)); + HWY_DASSERT(cluster_idx < ctx.topology.NumClusters()); if (cluster_idx != 0) { // If already running across clusters, only use within-cluster modes. HWY_DASSERT(parallelism == ParallelismStrategy::kNone || @@ -173,7 +170,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, } case ParallelismStrategy::kAcrossClusters: - return ctx.pools.AllClusters(pkg_idx).Run( + return ctx.pools.AllClusters().Run( 0, num_tasks, [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); @@ -181,7 +178,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, // Ensure the worker argument is unique across clusters, because it is // used for TLS indexing for example in profiler.h. const size_t base = ctx.Worker(cluster_idx); - return ctx.pools.Cluster(pkg_idx, cluster_idx) + return ctx.pools.Cluster(cluster_idx) .Run(0, num_tasks, [&](uint64_t task, size_t worker) { func(task, base + worker); }); @@ -190,15 +187,15 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, case ParallelismStrategy::kFlat: { // Check for single cluster; if not, we must compute `cluster_base` for // consistent and non-overlapping worker indices. 
- hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); const size_t num_clusters = all_clusters.NumWorkers(); if (num_clusters == 1) { - return ctx.pools.Cluster(pkg_idx, cluster_idx) + return ctx.pools.Cluster(cluster_idx) .Run(0, num_tasks, [&](uint64_t task, size_t worker) { func(task, worker); }); } - return ctx.pools.AllClusters(pkg_idx).Run( + return ctx.pools.AllClusters().Run( 0, num_tasks, [&](uint64_t task, size_t cluster_idx) { const size_t worker = ctx.Worker(cluster_idx); func(task, worker); diff --git a/util/threading_test.cc b/util/threading_test.cc index ac2746b..4cd8554 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -99,23 +99,16 @@ TEST(ThreadingTest, TestBoundedTopology) { const BoundedSlice all; const BoundedSlice one(0, 1); // All - { - BoundedTopology topology(all, all, all); - fprintf(stderr, "%s\n", topology.TopologyString()); - } - - // Max one package { BoundedTopology topology(one, all, all); fprintf(stderr, "%s\n", topology.TopologyString()); - ASSERT_EQ(1, topology.NumPackages()); } // Max one cluster { - BoundedTopology topology(all, one, all); + BoundedTopology topology(one, one, all); fprintf(stderr, "%s\n", topology.TopologyString()); - ASSERT_EQ(1, topology.NumClusters(0)); + ASSERT_EQ(1, topology.NumClusters()); } } @@ -380,24 +373,32 @@ TEST(ThreadingTest, BenchJoin) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); NestedPools& pools = ctx.pools; - // Use last package because the main thread has been pinned to it. 
- const size_t pkg_idx = pools.NumPackages() - 1; - measure(pools.AllPackages(), false, "block packages"); - if (pools.AllClusters(pkg_idx).NumWorkers() > 1) { - measure(pools.AllClusters(pkg_idx), false, "block clusters"); + if (pools.NumClusters() > 1) { + measure(pools.AllClusters(), false, "block clusters"); } - measure(pools.Cluster(pkg_idx, 0), false, "block in_cluster"); + measure(pools.Cluster(0), false, "block in_cluster"); if (pools.AllPinned()) { const bool kSpin = true; - measure(pools.AllPackages(), kSpin, "spin packages"); - if (pools.AllClusters(pkg_idx).NumWorkers() > 1) { - measure(pools.AllClusters(pkg_idx), kSpin, "spin clusters"); + if (pools.NumClusters() > 1) { + measure(pools.AllClusters(), kSpin, "spin clusters"); } - measure(pools.Cluster(pkg_idx, 0), kSpin, "spin in_cluster"); + measure(pools.Cluster(0), kSpin, "spin in_cluster"); } } +TEST(ThreadingTest, TestUnequalClusters) { + ThreadingArgs threading_args; + threading_args.max_lps = 13; + ThreadingContext ctx(threading_args); + const size_t last_workers = + ctx.pools.Cluster(ctx.topology.NumClusters() - 1).NumWorkers(); + const size_t max_workers = ctx.pools.MaxWorkersPerCluster(); + fprintf(stderr, "%zu clusters, last with %zu (max %zu)\n", + ctx.topology.NumClusters(), last_workers, max_workers); + HWY_ASSERT(last_workers <= max_workers); +} + } // namespace } // namespace gcpp diff --git a/util/topology.cc b/util/topology.cc index 0d32f22..f20b7f9 100644 --- a/util/topology.cc +++ b/util/topology.cc @@ -18,21 +18,12 @@ #include #include // std::sort -#include // std::move #include #include "hwy/base.h" namespace gcpp { -// Sort T := packages/clusters by descending 'size' so that users who only use -// one Group get the largest. -template -static void SortByDescendingSize(std::vector& groups) { - std::sort(groups.begin(), groups.end(), - [](const T& a, const T& b) { return a.Size() > b.Size(); }); -} - // Returns set of LPs available for use. 
static LPS EnabledLPs(const BoundedSlice& lp_slice) { LPS enabled_lps; @@ -88,21 +79,23 @@ BoundedTopology::BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice, BoundedSlice lp_slice) : package_slice_(package_slice), cluster_slice_(cluster_slice) { + HWY_ASSERT(package_slice_.Max() == 1); const LPS enabled_lps = EnabledLPs(lp_slice); + bool topology_ok = false; #if !GEMMA_DISABLE_TOPOLOGY if (HWY_LIKELY(!topology_.packages.empty())) { - InitFromTopology(enabled_lps); + topology_ok = InitFromTopology(enabled_lps); } #endif // Topology unknown or no packages with enabled LPs: create a single // package with one cluster, and one node. - if (HWY_UNLIKELY(NumPackages() == 0)) { + if (HWY_UNLIKELY(!topology_ok)) { InitFromLPs(enabled_lps); } - HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0); + HWY_ASSERT(NumClusters() != 0 && NumNodes() != 0); } // Topology is unknown, take the given set of LPs. @@ -161,9 +154,113 @@ constexpr bool kSplitLargeClusters = false; constexpr size_t kMaxClusters = 8; constexpr size_t kMaxLPsPerCluster = 6; -// Topology is unknown, use only the given LPs which derive from OS affinity -// and `lp_slice`. -BoundedTopology::Package::Package(const LPS& enabled_lps) { +#if !GEMMA_DISABLE_TOPOLOGY + +static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) { + LPS cores; + lps.Foreach([&](size_t lp) { + if (topology.lps[lp].smt == 0) cores.Set(lp); + }); + return cores.Count(); +} + +// tcluster is a modifiable copy of the first cluster in the package. +void BoundedTopology::SplitLargeCluster(const LPS& enabled_lps, + hwy::Topology::Cluster tcluster) { + const LPS lps = clusters_[0].LPSet(); // copy so we can clear + clusters_.clear(); + + // Split `lps` into several clusters. 
+ LPS clusters_lps[kMaxClusters]; + const size_t num_clusters = + HWY_MIN(kMaxClusters, hwy::DivCeil(lps.Count(), kMaxLPsPerCluster)); + size_t num_lps = 0; + lps.Foreach( + [&](size_t lp) { clusters_lps[num_lps++ % num_clusters].Set(lp); }); + HWY_DASSERT(num_lps == lps.Count()); + + // Create new clusters, just inserting the new LPS. + for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { + tcluster.lps = clusters_lps[cluster_idx]; + // Keep same `private_kib` and `shared_kib`. + clusters_.push_back(Cluster(enabled_lps, topology_.lps, tcluster)); + } +} + +// Main part of ctor, called when topology is known. +bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) { + const size_t tpkg_idx = package_slice_.Begin(); + HWY_ASSERT(tpkg_idx < topology_.packages.size()); + const hwy::Topology::Package& tpackage = topology_.packages[tpkg_idx]; + const std::vector& tclusters = tpackage.clusters; + if (HWY_UNLIKELY(tclusters.empty())) { + HWY_WARN("Topology: no clusters found in package %zu.", tpkg_idx); + return false; + } + + size_t max_tcluster_cores = 0; + size_t max_tcluster_lps = 0; + for (const hwy::Topology::Cluster& tcluster : tclusters) { + const size_t cores = CoresFromLPs(tcluster.lps, topology_); + const size_t lps = tcluster.lps.Count(); + max_tcluster_cores = HWY_MAX(max_tcluster_cores, cores); + max_tcluster_lps = HWY_MAX(max_tcluster_lps, lps); + } + HWY_ASSERT(max_tcluster_cores != 0); + HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores); + + // Populate `clusters` with the subset of clusters in `cluster_slice` that + // have any enabled LPs. + clusters_.reserve(cluster_slice_.Num(tclusters.size())); + cluster_slice_.Foreach("cluster", tclusters.size(), [&](size_t cluster_idx) { + const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx]; + Cluster cluster(enabled_lps, topology_.lps, tcluster); + + // Skip if empty, i.e. too few `enabled_lps`. 
+ if (HWY_LIKELY(cluster.NumWorkers() != 0)) { + clusters_.push_back(cluster); + // Remember NUMA nodes that we are actually using (not just enabled). + nodes_.Set(cluster.Node()); + } + }); + if (HWY_UNLIKELY(clusters_.empty())) { + HWY_WARN("Too restrictive cluster_slice or enabled_lps, no clusters left."); + return false; + } + + if (kSplitLargeClusters && clusters_.size() == 1 && + enabled_lps.Count() >= 16) { + SplitLargeCluster(enabled_lps, tpackage.clusters[0]); + } + + // Sort by descending 'size' so that users who only use one get the largest. + std::sort(clusters_.begin(), clusters_.end(), + [](const Cluster& a, const Cluster& b) { + return a.NumWorkers() > b.NumWorkers(); + }); + + // Largest number of enabled workers in any cluster, for `topology_string_`. + // This may be less than `max_tcluster_cores` if `enabled_lps` excludes some. + size_t max_cluster_workers = 0; + for (const Cluster& c : clusters_) { + max_cluster_workers = HWY_MAX(max_cluster_workers, c.NumWorkers()); + } + HWY_ASSERT(max_cluster_workers <= max_tcluster_cores); + // Do not warn about large clusters: GNR has 40. + + snprintf(topology_string_, sizeof(topology_string_), + "%zuS %zuX %zuC %zuH, using %zuX %zuC (nodes=%zu)", + topology_.packages.size(), tclusters.size(), max_tcluster_cores, + max_tcluster_lps / max_tcluster_cores, NumClusters(), + max_cluster_workers, nodes_.Count()); + return true; +} + +#endif // !GEMMA_DISABLE_TOPOLOGY + +// Called when topology is unknown or `GEMMA_DISABLE_TOPOLOGY`. Uses only the +// given LPs which derive from OS affinity and `lp_slice`. 
+void BoundedTopology::InitFromLPs(const LPS& enabled_lps) { LPS clusters_lps[kMaxClusters]; const size_t num_clusters = kSplitLargeClusters @@ -178,157 +275,11 @@ BoundedTopology::Package::Package(const LPS& enabled_lps) { }); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { - clusters.push_back(Cluster(clusters_lps[cluster_idx])); + clusters_.push_back(Cluster(clusters_lps[cluster_idx])); } -} - -// NOTE: caller is responsible for checking whether `clusters` is empty. -BoundedTopology::Package::Package(const LPS& enabled_lps, - const hwy::Topology& topology, size_t pkg_idx, - BoundedSlice cluster_slice) { - const hwy::Topology::Package& tpackage = topology.packages[pkg_idx]; - // Populate `clusters` with the subset of clusters in `cluster_slice` that - // have any enabled LPs. If `clusters` remains empty, the caller will - // skip this `Package`. - clusters.reserve(cluster_slice.Num(tpackage.clusters.size())); - cluster_slice.Foreach( - "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) { - const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx]; - Cluster cluster(enabled_lps, topology.lps, tcluster); - - // Skip if empty, i.e. too few `enabled_lps`. - if (HWY_LIKELY(cluster.Size() != 0)) { - clusters.push_back(cluster); - } - }); - SortByDescendingSize(clusters); - - // If there is only one large cluster, split it into smaller ones. - if (kSplitLargeClusters && clusters.size() == 1 && - enabled_lps.Count() >= 16) { - const LPS lps = clusters[0].LPSet(); // copy so we can clear - clusters.clear(); - - // Split `lps` into several clusters. - LPS clusters_lps[kMaxClusters]; - const size_t num_clusters = - HWY_MIN(kMaxClusters, hwy::DivCeil(lps.Count(), kMaxLPsPerCluster)); - size_t num_lps = 0; - lps.Foreach( - [&](size_t lp) { clusters_lps[num_lps++ % num_clusters].Set(lp); }); - HWY_DASSERT(num_lps == lps.Count()); - - // Create new clusters, just inserting the new LPS. 
- hwy::Topology::Cluster tcluster = tpackage.clusters[0]; // modifiable copy - for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { - tcluster.lps = clusters_lps[cluster_idx]; - // Keep same `private_kib` and `shared_kib`. - clusters.push_back(Cluster(enabled_lps, topology.lps, tcluster)); - } - } -} - -#if !GEMMA_DISABLE_TOPOLOGY - -static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) { - LPS cores; - lps.Foreach([&](size_t lp) { - if (topology.lps[lp].smt == 0) cores.Set(lp); - }); - return cores.Count(); -} - -// Scans hwy::Topology for clusters and their size, for use by topology_string_. -static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters, - size_t& max_tcluster_cores, - size_t& max_tcluster_lps) { - max_tclusters = 0; - max_tcluster_cores = 0; - max_tcluster_lps = 0; - for (size_t pkg_idx = 0; pkg_idx < topology_.packages.size(); ++pkg_idx) { - const std::vector& tclusters = - topology_.packages[pkg_idx].clusters; - max_tclusters = HWY_MAX(max_tclusters, tclusters.size()); - size_t tcluster_cores = 0; - size_t tcluster_lps = 0; - for (size_t cluster_idx = 0; cluster_idx < tclusters.size(); - ++cluster_idx) { - const size_t cores = CoresFromLPs(tclusters[cluster_idx].lps, topology_); - const size_t lps = tclusters[cluster_idx].lps.Count(); - tcluster_cores = HWY_MAX(tcluster_cores, cores); - tcluster_lps = HWY_MAX(tcluster_lps, lps); - } - - if (tclusters.size() > 1 && tcluster_cores > 8) { - HWY_WARN( - "Package %zu: multiple clusters with max size %zu, whereas CCX " - "only have 8, may indicate a bug in hwy::Topology.", - pkg_idx, tcluster_cores); - } - max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores); - max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps); - } - HWY_ASSERT(max_tclusters != 0); - HWY_ASSERT(max_tcluster_cores != 0); - HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores); -} - -// Main part of ctor, called when topology is known. 
-void BoundedTopology::InitFromTopology(const LPS& enabled_lps) { - size_t max_tclusters, max_tcluster_cores, max_tcluster_lps; - ScanTClusters(topology_, max_tclusters, max_tcluster_cores, max_tcluster_lps); - - // (Possibly empty) subset of `Topology` packages that have `enabled_lps`. - package_slice_.Foreach( - "package", topology_.packages.size(), [&](size_t pkg_idx) { - Package package(enabled_lps, topology_, pkg_idx, cluster_slice_); - // Skip if empty, i.e. too few `enabled_lps`. - if (HWY_LIKELY(!package.clusters.empty())) { - packages_.push_back(std::move(package)); - } - }); - if (NumPackages() == 0) return; - SortByDescendingSize(packages_); - - // Remember NUMA nodes that we are actually using (not just enabled). - for (const Package& p : packages_) { - for (const Cluster& c : p.clusters) { - nodes_.Set(c.Node()); - } - } - - // Scan for max BoundedTopology clusters and their size, for topology_string_. - size_t all_max_cluster_size = 0; - for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) { - size_t max_cluster_size = 0; - for (size_t cluster_idx = 0; cluster_idx < NumClusters(pkg_idx); - ++cluster_idx) { - max_cluster_size = - HWY_MAX(max_cluster_size, GetCluster(pkg_idx, cluster_idx).Size()); - } - if (NumClusters(pkg_idx) > 1 && max_cluster_size > 8) { - HWY_WARN( - "Package %zu: multiple clusters with max size %zu, whereas CCX " - "only have 8, may indicate a bug in BoundedTopology.", - pkg_idx, max_cluster_size); - } - all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size); - } - - snprintf(topology_string_, sizeof(topology_string_), - "%zuS %zuX %zuC %zuH, using %zuS %zuX %zuC (nodes=%zu)", - topology_.packages.size(), max_tclusters, max_tcluster_cores, - max_tcluster_lps / max_tcluster_cores, packages_.size(), - NumClusters(0), all_max_cluster_size, nodes_.Count()); -} - -#endif // !GEMMA_DISABLE_TOPOLOGY - -void BoundedTopology::InitFromLPs(const LPS& enabled_lps) { - packages_.push_back(Package(enabled_lps)); 
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", - GetCluster(0, 0).Size()); + GetCluster(0).NumWorkers()); // Assume a single NUMA node. nodes_.Set(0); diff --git a/util/topology.h b/util/topology.h index b844bd9..d4f80cc 100644 --- a/util/topology.h +++ b/util/topology.h @@ -40,6 +40,7 @@ class BoundedSlice { BoundedSlice(size_t skip = 0, size_t max = 0) : skip_(skip), max_(max) {} size_t Begin() const { return skip_; } + size_t Max() const { return max_; } // STL-style one past the end. size_t End(size_t detected) const { @@ -82,12 +83,11 @@ using LPS = hwy::LogicalProcessorSet; // back to a single package and cluster. class BoundedTopology { public: - // Defaults to "use all detected". - BoundedTopology(BoundedSlice package_slice = BoundedSlice(), + // `package_slice` must have `Max() == 1`. Others default to "use all". + BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice = BoundedSlice(), BoundedSlice lp_slice = BoundedSlice()); - size_t NumPackages() const { return packages_.size(); } size_t NumNodes() const { return nodes_.Count(); } const char* TopologyString() const { return topology_string_; } @@ -98,8 +98,7 @@ class BoundedTopology { const std::vector& all_lps, const hwy::Topology::Cluster& tcluster); - // For SortByDescendingSize. - size_t Size() const { return num_workers_; } + size_t NumWorkers() const { return num_workers_; } // Returns vector with all enabled LPs, used for pinning. 
std::vector LPVector() const { @@ -127,26 +126,11 @@ class BoundedTopology { size_t shared_kib_ = 0; }; // Cluster - size_t NumClusters(size_t pkg_idx) const { - HWY_ASSERT(pkg_idx < NumPackages()); - return packages_[pkg_idx].clusters.size(); + size_t NumClusters() const { return clusters_.size(); } + const Cluster& GetCluster(size_t cluster_idx) const { + HWY_ASSERT(cluster_idx < clusters_.size()); + return clusters_[cluster_idx]; } - const Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) const { - HWY_ASSERT(pkg_idx < NumPackages()); - const Package& package = packages_[pkg_idx]; - HWY_ASSERT(cluster_idx < package.clusters.size()); - return package.clusters[cluster_idx]; - } - Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) { - HWY_ASSERT(pkg_idx < NumPackages()); - Package& package = packages_[pkg_idx]; - HWY_ASSERT(cluster_idx < package.clusters.size()); - return package.clusters[cluster_idx]; - } - -#if !GEMMA_DISABLE_TOPOLOGY - const hwy::Topology& FullTopology() const { return topology_; } -#endif // In case we are running with a subset of packages/clusters, these are added // to the package/cluster indices for purposes of the thread name, so that @@ -155,26 +139,17 @@ class BoundedTopology { size_t SkippedClusters() const { return cluster_slice_.Begin(); } private: - struct Package { - explicit Package(const LPS& enabled_lps); - Package(const LPS& enabled_lps, const hwy::Topology& topology, - size_t pkg_idx, BoundedSlice cluster_slice); - - // For SortByDescendingSize. - size_t Size() const { return clusters.size(); } - - std::vector clusters; - }; // Package - - void InitFromTopology(const LPS& enabled_lps); + void SplitLargeCluster(const LPS& enabled_lps, + hwy::Topology::Cluster tcluster); + bool InitFromTopology(const LPS& enabled_lps); void InitFromLPs(const LPS& enabled_lps); #if !GEMMA_DISABLE_TOPOLOGY hwy::Topology topology_; #endif - BoundedSlice package_slice_; + BoundedSlice package_slice_; // Within the entire detected topology. 
BoundedSlice cluster_slice_; - std::vector packages_; + std::vector clusters_; char topology_string_[96]; LPS nodes_; };