mirror of https://github.com/google/gemma.cpp.git
57 lines
1.9 KiB
C++
57 lines
1.9 KiB
C++
#include "gemma/configs.h"
|
|
|
|
#include <stdio.h>
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "gtest/gtest.h"
|
|
#include "compression/types.h" // Type
|
|
#include "io/fields.h" // Type
|
|
|
|
namespace gcpp {
|
|
|
|
TEST(ConfigsTest, TestAll) {
|
|
ForEachModel([&](Model model) {
|
|
ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
|
|
fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
|
|
config.Specifier().c_str());
|
|
HWY_ASSERT(config.model == model);
|
|
|
|
// We can deduce the model/display_name from all other fields.
|
|
config.model = Model::UNKNOWN;
|
|
const std::string saved_display_name = config.display_name;
|
|
config.display_name.clear();
|
|
HWY_ASSERT(config.OverwriteWithCanonical());
|
|
HWY_ASSERT(config.model == model);
|
|
HWY_ASSERT(config.display_name == saved_display_name);
|
|
|
|
const std::vector<uint32_t> serialized = config.Write();
|
|
ModelConfig deserialized;
|
|
const IFields::ReadResult result =
|
|
deserialized.Read(hwy::Span<const uint32_t>(serialized), /*pos=*/0);
|
|
HWY_ASSERT(result.pos == serialized.size());
|
|
// We wrote it, so all fields should be known, and no extra.
|
|
HWY_ASSERT(result.extra_u32 == 0);
|
|
HWY_ASSERT(result.missing_fields == 0);
|
|
// All fields should match.
|
|
HWY_ASSERT(deserialized.TestEqual(config, /*print=*/true));
|
|
HWY_ASSERT(deserialized.model == model);
|
|
HWY_ASSERT(deserialized.display_name == saved_display_name);
|
|
});
|
|
}
|
|
|
|
TEST(ConfigsTest, TestAttentionImpl) {
|
|
for (int i = 0; i < static_cast<int>(AttentionImpl::kSentinel); ++i) {
|
|
AttentionImpl impl = static_cast<AttentionImpl>(i);
|
|
std::string name = GetAttentionImplName(impl);
|
|
ASSERT_NE(name, "unknown");
|
|
ASSERT_EQ(GetAttentionImpl(name), impl);
|
|
}
|
|
ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown");
|
|
ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld);
|
|
ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld);
|
|
}
|
|
|
|
} // namespace gcpp
|