diff --git a/gemma/configs.cc b/gemma/configs.cc index 70fd8fa..cb508e8 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -712,4 +712,11 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { } } +AttentionImpl GetAttentionImpl(const std::string& impl) { + if (impl == "old") return AttentionImpl::kOld; + if (impl == "flash") return AttentionImpl::kFlash; + HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str()); + return AttentionImpl::kOld; +} + } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index 5de74bf..447c246 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -85,6 +85,8 @@ enum class AttentionImpl { kFlash, }; +AttentionImpl GetAttentionImpl(const std::string& impl); + /* * Returns a bitmask of flags to pass to attention functions based on the * attention implementation selected. diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 78e2208..0db32d3 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -194,6 +194,7 @@ struct InferenceArgs : public ArgsBase { // For prompts longer than the Linux terminal's 4K line edit buffer. Path prompt_file; std::string eot_line; + std::string attention_impl; template void ForEach(const Visitor& visitor) { @@ -247,6 +248,8 @@ struct InferenceArgs : public ArgsBase { "before the line where only the given string appears.\n Default = " "When a newline is encountered, that signals the end of the turn.", 2); + visitor(attention_impl, "attention_impl", std::string("flash"), + "Attention implementation to use. See configs.cc for options.", 2); } void CopyTo(RuntimeConfig& runtime_config) const { @@ -268,6 +271,7 @@ struct InferenceArgs : public ArgsBase { runtime_config.temperature = temperature; runtime_config.top_k = top_k; + runtime_config.attention_impl = GetAttentionImpl(attention_impl); } };