mirror of https://github.com/google/gemma.cpp.git
Use a struct to manage the mapping between `AttentionImpl` enum values and their string names, simplifying the `GetAttentionImplName` function. Add a test to ensure that every valid `AttentionImpl` enum value has a corresponding name and can be looked up by that name.
PiperOrigin-RevId: 876124604
This commit is contained in:
parent
c6587efe70
commit
d8a123e4ec
|
|
@ -19,6 +19,7 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // Type
|
||||
|
|
@ -678,7 +679,7 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
|||
return Model::GEMMA3_270M;
|
||||
|
||||
case 26:
|
||||
if (layer_types & (kDeducedViT|kDeducedKqNorm)) {
|
||||
if (layer_types & (kDeducedViT | kDeducedKqNorm)) {
|
||||
return Model::GEMMA3_1B;
|
||||
}
|
||||
return Model::GEMMA2_2B;
|
||||
|
|
@ -712,22 +713,26 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
|||
}
|
||||
}
|
||||
|
||||
// Keep in sync with enum class AttentionImpl.
|
||||
const char* kAttentionImplNames[] = {
|
||||
"old", "flash",
|
||||
"unknown" // keep last
|
||||
constexpr std::pair<const char*, AttentionImpl> kAttentionImplNameToEnum[] = {
|
||||
{"old", AttentionImpl::kOld},
|
||||
{"flash", AttentionImpl::kFlash},
|
||||
{"flash_transposed_qs", AttentionImpl::kFlashTransposedQs},
|
||||
{"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16},
|
||||
};
|
||||
|
||||
std::string GetAttentionImplName(AttentionImpl impl) {
|
||||
return kAttentionImplNames[static_cast<size_t>(impl)];
|
||||
for (const auto& [name, attention_impl] : kAttentionImplNameToEnum) {
|
||||
if (attention_impl == impl) return std::string(name);
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
AttentionImpl GetAttentionImpl(const std::string& impl) {
|
||||
if (impl == GetAttentionImplName(AttentionImpl::kOld))
|
||||
return AttentionImpl::kOld;
|
||||
if (impl == GetAttentionImplName(AttentionImpl::kFlash))
|
||||
return AttentionImpl::kFlash;
|
||||
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str());
|
||||
// Returns the AttentionImpl whose entry in kAttentionImplNameToEnum has the
// name `impl_name`. If no entry matches, emits an HWY_WARN and returns
// AttentionImpl::kOld.
AttentionImpl GetAttentionImpl(const std::string& impl_name) {
|
||||
for (const auto& [name, attention_impl] : kAttentionImplNameToEnum) {
|
||||
// `name` is a const char* and `impl_name` a std::string, so this compares
// string contents rather than pointer values.
if (name == impl_name) return attention_impl;
|
||||
}
|
||||
HWY_WARN("Unknown attention implementation: %s. Using kOld.\n",
|
||||
impl_name.c_str());
|
||||
// No table entry matched: fall back to kOld.
return AttentionImpl::kOld;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -41,4 +41,16 @@ TEST(ConfigsTest, TestAll) {
|
|||
});
|
||||
}
|
||||
|
||||
// Verifies the name <-> enum round trip for AttentionImpl: every value below
// kSentinel must have a name other than "unknown" that maps back to the same
// value, while kSentinel and unrecognized names fall back to "unknown" and
// AttentionImpl::kOld respectively.
TEST(ConfigsTest, TestAttentionImpl) {
|
||||
// Walks the integer values 0 .. kSentinel-1.
// NOTE(review): presumably the enumerators are contiguous starting at 0 with
// kSentinel last - confirm against the enum definition.
for (int i = 0; i < static_cast<int>(AttentionImpl::kSentinel); ++i) {
|
||||
AttentionImpl impl = static_cast<AttentionImpl>(i);
|
||||
std::string name = GetAttentionImplName(impl);
|
||||
ASSERT_NE(name, "unknown");
|
||||
ASSERT_EQ(GetAttentionImpl(name), impl);
|
||||
}
|
||||
// kSentinel itself is expected to have no name.
ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown");
|
||||
// Names that are not valid implementation names are expected to map to kOld.
ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld);
|
||||
ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
Loading…
Reference in New Issue