Make attention configurable on the command line.

PiperOrigin-RevId: 842760721
This commit is contained in:
Martin Stolle 2025-12-10 09:33:21 -08:00 committed by Copybara-Service
parent 2441ff01bf
commit 78deacc357
3 changed files with 13 additions and 0 deletions

View File

@ -712,4 +712,11 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
}
}
AttentionImpl GetAttentionImpl(const std::string& impl) {
if (impl == "old") return AttentionImpl::kOld;
if (impl == "flash") return AttentionImpl::kFlash;
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str());
return AttentionImpl::kOld;
}
} // namespace gcpp

View File

@ -85,6 +85,8 @@ enum class AttentionImpl {
kFlash,
};
AttentionImpl GetAttentionImpl(const std::string& impl);
/*
* Returns a bitmask of flags to pass to attention functions based on the
* attention implementation selected.

View File

@ -194,6 +194,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
// For prompts longer than the Linux terminal's 4K line edit buffer.
Path prompt_file;
std::string eot_line;
std::string attention_impl;
template <class Visitor>
void ForEach(const Visitor& visitor) {
@ -247,6 +248,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.",
2);
visitor(attention_impl, "attention_impl", std::string("flash"),
"Attention implementation to use. See configs.cc for options.", 2);
}
void CopyTo(RuntimeConfig& runtime_config) const {
@ -268,6 +271,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
runtime_config.temperature = temperature;
runtime_config.top_k = top_k;
runtime_config.attention_impl = GetAttentionImpl(attention_impl);
}
};