diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6b0b832..2c6a74d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -144,6 +144,7 @@ set(GEMMA_TEST_FILES
   backprop/backward_scalar_test.cc
   backprop/backward_test.cc
   backprop/optimize_test.cc
+  compression/blob_store_test.cc
   compression/compress_test.cc
   compression/distortion_test.cc
   compression/nuq_test.cc
diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel
index e832793..b6508cd 100644
--- a/compression/BUILD.bazel
+++ b/compression/BUILD.bazel
@@ -48,6 +48,19 @@ cc_library(
     ],
 )
 
+cc_test(
+    name = "blob_store_test",
+    srcs = ["blob_store_test.cc"],
+    deps = [
+        ":blob_store",
+        ":io",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "@highway//:hwy",
+        "@highway//:hwy_test_util",
+        "@highway//:thread_pool",
+    ],
+)
+
 cc_library(
     name = "distortion",
     hdrs = [
diff --git a/compression/blob_store.cc b/compression/blob_store.cc
index 57f50f5..0ce9996 100644
--- a/compression/blob_store.cc
+++ b/compression/blob_store.cc
@@ -288,6 +288,16 @@ BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data,
   uint64_t offset;
   size_t actual_size;
   if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
+  // The read below uses `actual_size`, not `size`, so refuse to proceed unless
+  // the caller's expected size matches the stored blob size exactly.
+  if (actual_size != size) {
+    fprintf(stderr,
+            "Mismatch between expected %zu and actual %zu byte size of blob "
+            "%s. Please see README.md on how to update the weights.\n",
+            static_cast<size_t>(size), static_cast<size_t>(actual_size),
+            StringFromKey(key).c_str());
+    return __LINE__;
+  }
   if (!file_->Read(offset, actual_size, data)) {
     return __LINE__;
   }
diff --git a/compression/blob_store_test.cc b/compression/blob_store_test.cc
new file mode 100644
index 0000000..6464756
--- /dev/null
+++ b/compression/blob_store_test.cc
@@ -0,0 +1,87 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compression/blob_store.h"
+
+#include <stdio.h>
+
+#include <algorithm>
+#include <array>
+
+#include "compression/io.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/tests/hwy_gtest.h"
+#include "hwy/tests/test_util-inl.h"  // HWY_ASSERT_EQ
+
+namespace gcpp {
+namespace {
+
+#if !HWY_TEST_STANDALONE
+class BlobStoreTest : public testing::Test {};
+#endif
+
+#if !HWY_OS_WIN
+// Writes two blobs to a temporary file, reads them back, and checks that
+// ReadOne rejects a size that differs from the stored blob size.
+TEST(BlobStoreTest, TestReadWrite) {
+  static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
+
+  // mkstemp will modify path_str so it holds a newly-created temporary file.
+  char path_str[] = "/tmp/blob_store_test.sbs-XXXXXX";
+  const int fd = mkstemp(path_str);
+  // mkstemp returns -1 on failure; 0 is a valid file descriptor.
+  HWY_ASSERT(fd >= 0);
+
+  hwy::ThreadPool pool(4);
+  const Path path(path_str);
+  std::array<float, 4> buffer = kOriginalData;
+
+  const hwy::uint128_t keyA = MakeKey("0123456789abcdef");
+  const hwy::uint128_t keyB = MakeKey("q");
+  BlobWriter writer;
+  writer.Add(keyA, "DATA", 5);
+  writer.Add(keyB, buffer.data(), sizeof(buffer));
+  HWY_ASSERT_EQ(writer.WriteAll(pool, path), 0);
+  // Writing must not modify the source buffer.
+  HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
+
+  // Clear the buffer so the comparison below proves the data was read back.
+  std::fill(buffer.begin(), buffer.end(), 0);
+  BlobReader reader;
+  HWY_ASSERT_EQ(reader.Open(path), 0);
+  HWY_ASSERT_EQ(reader.BlobSize(keyA), 5);
+  HWY_ASSERT_EQ(reader.BlobSize(keyB), sizeof(buffer));
+
+  HWY_ASSERT_EQ(reader.Enqueue(keyB, buffer.data(), sizeof(buffer)), 0);
+  HWY_ASSERT_EQ(reader.ReadAll(pool), 0);
+  HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
+
+  {
+    std::array<char, 5> blob_a;
+    // A requested size that differs from the stored 5 bytes must fail.
+    HWY_ASSERT(reader.ReadOne(keyA, blob_a.data(), 1) != 0);
+    HWY_ASSERT_EQ(reader.ReadOne(keyA, blob_a.data(), 5), 0);
+    HWY_ASSERT_STRING_EQ("DATA", blob_a.data());
+  }
+
+  close(fd);
+  unlink(path_str);
+}
+#endif
+
+}  // namespace
+}  // namespace gcpp
+
+HWY_TEST_MAIN();