From 78deacc357ef224d5493e071d478659701e66f47 Mon Sep 17 00:00:00 2001 From: Martin Stolle Date: Wed, 10 Dec 2025 09:33:21 -0800 Subject: [PATCH] Make attention configurable on the command line. PiperOrigin-RevId: 842760721 --- gemma/configs.cc | 7 +++++++ gemma/configs.h | 2 ++ gemma/gemma_args.h | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/gemma/configs.cc b/gemma/configs.cc index 70fd8fa..cb508e8 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -712,4 +712,11 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { } } +AttentionImpl GetAttentionImpl(const std::string& impl) { /* Translates a textual implementation name into an AttentionImpl value: old selects kOld and flash selects kFlash. The comparison is an exact string match. */ + if (impl == "old") return AttentionImpl::kOld; + if (impl == "flash") return AttentionImpl::kFlash; /* Any other name, including the empty string, is not rejected: a warning is emitted and kOld is returned as the fallback. */ + HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str()); + return AttentionImpl::kOld; +} + } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index 5de74bf..447c246 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -85,6 +85,8 @@ enum class AttentionImpl { kFlash, }; +AttentionImpl GetAttentionImpl(const std::string& impl); /* Returns the AttentionImpl named by impl (old or flash); unrecognized names produce a warning and yield kOld. */ + /* * Returns a bitmask of flags to pass to attention functions based on the * attention implementation selected. diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 78e2208..0db32d3 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -194,6 +194,7 @@ struct InferenceArgs : public ArgsBase { // For prompts longer than the Linux terminal's 4K line edit buffer. Path prompt_file; std::string eot_line; + std::string attention_impl; template void ForEach(const Visitor& visitor) { @@ -247,6 +248,8 @@ struct InferenceArgs : public ArgsBase { "before the line where only the given string appears.\n Default = " "When a newline is encountered, that signals the end of the turn.", 2); + visitor(attention_impl, "attention_impl", std::string("flash"), + "Attention implementation to use. 
See configs.cc for options.", 2); } void CopyTo(RuntimeConfig& runtime_config) const { @@ -268,6 +271,7 @@ struct InferenceArgs : public ArgsBase { runtime_config.temperature = temperature; runtime_config.top_k = top_k; + runtime_config.attention_impl = GetAttentionImpl(attention_impl); } };