mirror of https://github.com/google/gemma.cpp.git
New blob_store_test, ensure ReadOne checks actual size against requested size
PiperOrigin-RevId: 688974390
This commit is contained in:
parent
91bf2317ff
commit
4197d69dfc
|
|
@ -144,6 +144,7 @@ set(GEMMA_TEST_FILES
|
||||||
backprop/backward_scalar_test.cc
|
backprop/backward_scalar_test.cc
|
||||||
backprop/backward_test.cc
|
backprop/backward_test.cc
|
||||||
backprop/optimize_test.cc
|
backprop/optimize_test.cc
|
||||||
|
compression/blob_store_test.cc
|
||||||
compression/compress_test.cc
|
compression/compress_test.cc
|
||||||
compression/distortion_test.cc
|
compression/distortion_test.cc
|
||||||
compression/nuq_test.cc
|
compression/nuq_test.cc
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,19 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "blob_store_test",
|
||||||
|
srcs = ["blob_store_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":blob_store",
|
||||||
|
":io",
|
||||||
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:hwy_test_util",
|
||||||
|
"@highway//:thread_pool",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "distortion",
|
name = "distortion",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
|
|
||||||
|
|
@ -288,6 +288,14 @@ BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data,
|
||||||
uint64_t offset;
|
uint64_t offset;
|
||||||
size_t actual_size;
|
size_t actual_size;
|
||||||
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
||||||
|
if (actual_size != size) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"Mismatch between expected %d and actual %d KiB size of blob %s. "
|
||||||
|
"Please see README.md on how to update the weights.\n",
|
||||||
|
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
|
||||||
|
StringFromKey(key).c_str());
|
||||||
|
return __LINE__;
|
||||||
|
}
|
||||||
if (!file_->Read(offset, actual_size, data)) {
|
if (!file_->Read(offset, actual_size, data)) {
|
||||||
return __LINE__;
|
return __LINE__;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "compression/blob_store.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
#include "compression/io.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#if !HWY_TEST_STANDALONE
|
||||||
|
class BlobStoreTest : public testing::Test {};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if !HWY_OS_WIN
|
||||||
|
TEST(BlobStoreTest, TestReadWrite) {
|
||||||
|
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
|
||||||
|
|
||||||
|
// mkstemp will modify path_str so it holds a newly-created temporary file.
|
||||||
|
char path_str[] = "/tmp/blob_store_test.sbs-XXXXXX";
|
||||||
|
const int fd = mkstemp(path_str);
|
||||||
|
HWY_ASSERT(fd > 0);
|
||||||
|
|
||||||
|
hwy::ThreadPool pool(4);
|
||||||
|
const Path path(path_str);
|
||||||
|
std::array<float, 4> buffer = kOriginalData;
|
||||||
|
|
||||||
|
const hwy::uint128_t keyA = MakeKey("0123456789abcdef");
|
||||||
|
const hwy::uint128_t keyB = MakeKey("q");
|
||||||
|
BlobWriter writer;
|
||||||
|
writer.Add(keyA, "DATA", 5);
|
||||||
|
writer.Add(keyB, buffer.data(), sizeof(buffer));
|
||||||
|
HWY_ASSERT_EQ(writer.WriteAll(pool, path), 0);
|
||||||
|
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||||
|
|
||||||
|
std::fill(buffer.begin(), buffer.end(), 0);
|
||||||
|
BlobReader reader;
|
||||||
|
HWY_ASSERT_EQ(reader.Open(path), 0);
|
||||||
|
HWY_ASSERT_EQ(reader.BlobSize(keyA), 5);
|
||||||
|
HWY_ASSERT_EQ(reader.BlobSize(keyB), sizeof(buffer));
|
||||||
|
|
||||||
|
HWY_ASSERT_EQ(reader.Enqueue(keyB, buffer.data(), sizeof(buffer)), 0);
|
||||||
|
HWY_ASSERT_EQ(reader.ReadAll(pool), 0);
|
||||||
|
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||||
|
|
||||||
|
{
|
||||||
|
std::array<char, 5> buffer;
|
||||||
|
HWY_ASSERT(reader.ReadOne(keyA, buffer.data(), 1) != 0);
|
||||||
|
HWY_ASSERT_EQ(reader.ReadOne(keyA, buffer.data(), 5), 0);
|
||||||
|
HWY_ASSERT_STRING_EQ("DATA", buffer.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
close(fd);
|
||||||
|
unlink(path_str);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
HWY_TEST_MAIN();
|
||||||
Loading…
Reference in New Issue