Use a table of name/enum pairs to manage the mapping between `AttentionImpl` enum values and their string names, simplifying the `GetAttentionImplName` function. Add a test to ensure that every valid `AttentionImpl` enum value has a corresponding name and can be looked up by that name.

PiperOrigin-RevId: 876124604
This commit is contained in:
Viktor Shipitsin 2026-02-27 01:30:40 -08:00 committed by Copybara-Service
parent c6587efe70
commit d8a123e4ec
2 changed files with 29 additions and 12 deletions

View File

@ -19,6 +19,7 @@
#include <stdio.h>
#include <string>
#include <utility>
#include <vector>
#include "compression/types.h" // Type
@ -678,7 +679,7 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
return Model::GEMMA3_270M;
case 26:
if (layer_types & (kDeducedViT|kDeducedKqNorm)) {
if (layer_types & (kDeducedViT | kDeducedKqNorm)) {
return Model::GEMMA3_1B;
}
return Model::GEMMA2_2B;
@ -712,22 +713,26 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
}
}
// Keep in sync with enum class AttentionImpl.
const char* kAttentionImplNames[] = {
"old", "flash",
"unknown" // keep last
// Table pairing each attention implementation's string name with its
// AttentionImpl value. GetAttentionImplName scans it by enum value and
// GetAttentionImpl scans it by name, so each implementation needs exactly one
// entry here to be convertible in both directions.
// NOTE(review): this listing is a diff with the +/- markers stripped; the
// kAttentionImplNames lines above appear to be the removed predecessor of this
// table -- confirm against the commit.
constexpr std::pair<const char*, AttentionImpl> kAttentionImplNameToEnum[] = {
{"old", AttentionImpl::kOld},
{"flash", AttentionImpl::kFlash},
{"flash_transposed_qs", AttentionImpl::kFlashTransposedQs},
{"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16},
};
// Returns the string name paired with `impl` in kAttentionImplNameToEnum, or
// "unknown" when no table entry carries that enum value.
// NOTE(review): this listing is a diff with the +/- markers stripped; the
// direct kAttentionImplNames[...] return below appears to be the removed
// implementation and the loop its replacement -- confirm against the commit.
std::string GetAttentionImplName(AttentionImpl impl) {
return kAttentionImplNames[static_cast<size_t>(impl)];
// Linear scan over the table; the first entry whose enum value matches wins.
for (const auto& [name, attention_impl] : kAttentionImplNameToEnum) {
if (attention_impl == impl) return std::string(name);
}
// No entry matched `impl`.
return "unknown";
}
// Parses an attention implementation name into its AttentionImpl value.
// Unrecognized names are not an error: a warning is emitted via HWY_WARN and
// AttentionImpl::kOld is returned as the fallback.
// NOTE(review): this listing is a diff with the +/- markers stripped; the
// `impl` overload and its if-chain appear to be the removed version, and the
// `impl_name` overload with the table scan its replacement -- confirm against
// the commit.
AttentionImpl GetAttentionImpl(const std::string& impl) {
if (impl == GetAttentionImplName(AttentionImpl::kOld))
return AttentionImpl::kOld;
if (impl == GetAttentionImplName(AttentionImpl::kFlash))
return AttentionImpl::kFlash;
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str());
AttentionImpl GetAttentionImpl(const std::string& impl_name) {
for (const auto& [name, attention_impl] : kAttentionImplNameToEnum) {
// `name` is a const char* and `impl_name` a std::string, so this compares
// string contents (std::string's operator==), not pointer values.
if (name == impl_name) return attention_impl;
}
// No table entry has this name: warn and fall back to kOld.
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n",
impl_name.c_str());
return AttentionImpl::kOld;
}

View File

@ -41,4 +41,16 @@ TEST(ConfigsTest, TestAll) {
});
}
// Checks the name <-> enum mapping of AttentionImpl:
//  - every value below kSentinel has a name other than "unknown", and looking
//    that name up yields the same value again;
//  - kSentinel itself is reported as "unknown";
//  - names that are not in the table ("unknown", "invalid") resolve to kOld.
TEST(ConfigsTest, TestAttentionImpl) {
  const int num_impls = static_cast<int>(AttentionImpl::kSentinel);
  for (int idx = 0; idx < num_impls; ++idx) {
    const AttentionImpl expected = static_cast<AttentionImpl>(idx);
    const std::string impl_name = GetAttentionImplName(expected);
    // A valid enum value must never fall through to the placeholder name.
    ASSERT_NE(impl_name, "unknown");
    // Round trip: name -> enum must give back the value we started from.
    ASSERT_EQ(GetAttentionImpl(impl_name), expected);
  }

  // The sentinel is not a real implementation and has no table entry.
  ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown");

  // Unrecognized names fall back to kOld.
  ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld);
  ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld);
}
} // namespace gcpp