Use a struct to manage the mapping between `AttentionImpl` enum values and their string names, simplifying the `GetAttentionImplName` function. Add a test to ensure that every valid `AttentionImpl` enum value has a corresponding name and can be looked up by that name.

PiperOrigin-RevId: 876124604
This commit is contained in:
Viktor Shipitsin 2026-02-27 01:30:40 -08:00 committed by Copybara-Service
parent c6587efe70
commit d8a123e4ec
2 changed files with 29 additions and 12 deletions

View File

@@ -19,6 +19,7 @@
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "compression/types.h" // Type #include "compression/types.h" // Type
@@ -678,7 +679,7 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
return Model::GEMMA3_270M; return Model::GEMMA3_270M;
case 26: case 26:
if (layer_types & (kDeducedViT|kDeducedKqNorm)) { if (layer_types & (kDeducedViT | kDeducedKqNorm)) {
return Model::GEMMA3_1B; return Model::GEMMA3_1B;
} }
return Model::GEMMA2_2B; return Model::GEMMA2_2B;
@@ -712,22 +713,26 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
} }
} }
// Keep in sync with enum class AttentionImpl. constexpr std::pair<const char*, AttentionImpl> kAttentionImplNameToEnum[] = {
const char* kAttentionImplNames[] = { {"old", AttentionImpl::kOld},
"old", "flash", {"flash", AttentionImpl::kFlash},
"unknown" // keep last {"flash_transposed_qs", AttentionImpl::kFlashTransposedQs},
{"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16},
}; };
std::string GetAttentionImplName(AttentionImpl impl) { std::string GetAttentionImplName(AttentionImpl impl) {
return kAttentionImplNames[static_cast<size_t>(impl)]; for (const auto& [name, attention_impl] : kAttentionImplNameToEnum) {
if (attention_impl == impl) return std::string(name);
}
return "unknown";
} }
AttentionImpl GetAttentionImpl(const std::string& impl) { AttentionImpl GetAttentionImpl(const std::string& impl_name) {
if (impl == GetAttentionImplName(AttentionImpl::kOld)) for (const auto& [name, attention_impl] : kAttentionImplNameToEnum) {
return AttentionImpl::kOld; if (name == impl_name) return attention_impl;
if (impl == GetAttentionImplName(AttentionImpl::kFlash)) }
return AttentionImpl::kFlash; HWY_WARN("Unknown attention implementation: %s. Using kOld.\n",
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str()); impl_name.c_str());
return AttentionImpl::kOld; return AttentionImpl::kOld;
} }

View File

@@ -41,4 +41,16 @@ TEST(ConfigsTest, TestAll) {
}); });
} }
TEST(ConfigsTest, TestAttentionImpl) {
for (int i = 0; i < static_cast<int>(AttentionImpl::kSentinel); ++i) {
AttentionImpl impl = static_cast<AttentionImpl>(i);
std::string name = GetAttentionImplName(impl);
ASSERT_NE(name, "unknown");
ASSERT_EQ(GetAttentionImpl(name), impl);
}
ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown");
ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld);
ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld);
}
} // namespace gcpp } // namespace gcpp