mirror of https://github.com/google/gemma.cpp.git
Make attention configurable on the command line.
PiperOrigin-RevId: 842760721
This commit is contained in:
parent
2441ff01bf
commit
78deacc357
|
|
@ -712,4 +712,11 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Translates the command-line string `impl` into its AttentionImpl
// enumerator. Anything other than "flash" resolves to kOld; an
// unrecognized string additionally emits a warning before falling back.
AttentionImpl GetAttentionImpl(const std::string& impl) {
  if (impl == "flash") return AttentionImpl::kFlash;
  if (impl != "old") {
    HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str());
  }
  return AttentionImpl::kOld;
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,8 @@ enum class AttentionImpl {
|
||||||
kFlash,
|
kFlash,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
AttentionImpl GetAttentionImpl(const std::string& impl);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Returns a bitmask of flags to pass to attention functions based on the
|
* Returns a bitmask of flags to pass to attention functions based on the
|
||||||
* attention implementation selected.
|
* attention implementation selected.
|
||||||
|
|
|
||||||
|
|
@ -194,6 +194,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
// For prompts longer than the Linux terminal's 4K line edit buffer.
|
// For prompts longer than the Linux terminal's 4K line edit buffer.
|
||||||
Path prompt_file;
|
Path prompt_file;
|
||||||
std::string eot_line;
|
std::string eot_line;
|
||||||
|
std::string attention_impl;
|
||||||
|
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
|
|
@ -247,6 +248,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
"before the line where only the given string appears.\n Default = "
|
"before the line where only the given string appears.\n Default = "
|
||||||
"When a newline is encountered, that signals the end of the turn.",
|
"When a newline is encountered, that signals the end of the turn.",
|
||||||
2);
|
2);
|
||||||
|
visitor(attention_impl, "attention_impl", std::string("flash"),
|
||||||
|
"Attention implementation to use. See configs.cc for options.", 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CopyTo(RuntimeConfig& runtime_config) const {
|
void CopyTo(RuntimeConfig& runtime_config) const {
|
||||||
|
|
@ -268,6 +271,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
|
|
||||||
runtime_config.temperature = temperature;
|
runtime_config.temperature = temperature;
|
||||||
runtime_config.top_k = top_k;
|
runtime_config.top_k = top_k;
|
||||||
|
runtime_config.attention_impl = GetAttentionImpl(attention_impl);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue