mirror of https://github.com/google/gemma.cpp.git
Further speed up blob_compare: single alloc, use dual sockets
PiperOrigin-RevId: 724947361
This commit is contained in:
parent
b18bd781f6
commit
b0fe9a43e6
|
|
@ -275,9 +275,11 @@ cc_binary(
|
|||
":blob_store",
|
||||
":io",
|
||||
"//:allocator",
|
||||
"//:basics",
|
||||
"//:threading",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,23 +18,25 @@
|
|||
#include <string.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h" // IndexRange
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Returns whether it makes sense to continue comparing.
|
||||
using KeySpan = hwy::Span<const hwy::uint128_t>;
|
||||
|
||||
// Returns false if any keys differ, because then blobs are not comparable.
|
||||
bool CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
|
||||
hwy::Span<const hwy::uint128_t> keys1 = reader1.Keys();
|
||||
hwy::Span<const hwy::uint128_t> keys2 = reader2.Keys();
|
||||
KeySpan keys1 = reader1.Keys();
|
||||
KeySpan keys2 = reader2.Keys();
|
||||
if (keys1.size() != keys2.size()) {
|
||||
fprintf(stderr, "#keys mismatch: %zu vs %zu\n", keys1.size(), keys2.size());
|
||||
return false;
|
||||
|
|
@ -50,25 +52,35 @@ bool CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
|
|||
return true;
|
||||
}
|
||||
|
||||
using BlobMap = std::map<hwy::uint128_t, std::vector<uint8_t>>;
|
||||
|
||||
size_t TotalBytes(hwy::Span<const hwy::uint128_t>& keys, BlobReader& reader) {
|
||||
// Total amount to allocate for all blobs.
|
||||
size_t TotalBytes(BlobReader& reader) {
|
||||
size_t total_bytes = 0;
|
||||
for (const hwy::uint128_t key : keys) {
|
||||
for (const hwy::uint128_t key : reader.Keys()) {
|
||||
total_bytes += reader.BlobSize(key);
|
||||
}
|
||||
return total_bytes;
|
||||
}
|
||||
|
||||
void ParallelRead(BlobReader& reader, BlobMap& blobs, hwy::ThreadPool& pool) {
|
||||
hwy::Span<const hwy::uint128_t> keys = reader.Keys();
|
||||
for (const hwy::uint128_t key : keys) {
|
||||
const auto ib = blobs.insert({key, {}});
|
||||
HWY_ASSERT(ib.second); // newly inserted, no duplicate keys
|
||||
using BytePtr = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
||||
using ByteSpan = hwy::Span<uint8_t>; // Sections within BytePtr
|
||||
using BlobVec = std::vector<ByteSpan>; // in order of keys
|
||||
|
||||
// Allocates memory within the single allocation and updates `pos`.
|
||||
BlobVec ReserveMemory(BlobReader& reader, BytePtr& all_blobs, size_t& pos) {
|
||||
BlobVec blobs;
|
||||
for (const hwy::uint128_t key : reader.Keys()) {
|
||||
const size_t bytes = reader.BlobSize(key);
|
||||
// TODO: AllocateAligned instead, avoids initializing the memory.
|
||||
ib.first->second.resize(bytes);
|
||||
reader.Enqueue(key, ib.first->second.data(), bytes);
|
||||
blobs.push_back(ByteSpan(all_blobs.get() + pos, bytes));
|
||||
pos += bytes;
|
||||
}
|
||||
return blobs;
|
||||
}
|
||||
|
||||
// Reads one set of blobs in parallel (helpful if in disk cache).
|
||||
void ReadBlobs(BlobReader& reader, BlobVec& blobs, hwy::ThreadPool& pool) {
|
||||
HWY_ASSERT(reader.Keys().size() == blobs.size());
|
||||
for (size_t i = 0; i < blobs.size(); ++i) {
|
||||
reader.Enqueue(reader.Keys()[i], blobs[i].data(), blobs[i].size());
|
||||
}
|
||||
const BlobError err = reader.ReadAll(pool);
|
||||
if (err != 0) {
|
||||
|
|
@ -76,13 +88,25 @@ void ParallelRead(BlobReader& reader, BlobMap& blobs, hwy::ThreadPool& pool) {
|
|||
}
|
||||
}
|
||||
|
||||
// Parallelizes ReadBlobs across (two) packages, if available.
|
||||
void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, size_t total_bytes,
|
||||
BlobVec& blobs1, BlobVec& blobs2, NestedPools& pools) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
fprintf(stderr, "Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30,
|
||||
pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers());
|
||||
pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) {
|
||||
ReadBlobs(task ? reader2 : reader1, task ? blobs2 : blobs1,
|
||||
pools.Pool(pkg_idx));
|
||||
});
|
||||
const double t1 = hwy::platform::Now();
|
||||
fprintf(stderr, "%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9);
|
||||
}
|
||||
|
||||
// Returns number of elements with a mismatch. For float and bf16 blobs, uses
|
||||
// L1 and relative error, otherwise byte-wise comparison.
|
||||
size_t BlobDifferences(BlobMap& blobs1, BlobMap& blobs2,
|
||||
size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
|
||||
const hwy::uint128_t key) {
|
||||
std::vector<uint8_t>& data1 = blobs1[key];
|
||||
std::vector<uint8_t>& data2 = blobs2[key];
|
||||
if (data1.size() != data2.size() || data1.empty()) {
|
||||
if (data1.size() != data2.size() || data1.size() == 0) {
|
||||
HWY_ABORT("key %s size mismatch: %zu vs %zu\n", StringFromKey(key).c_str(),
|
||||
data1.size(), data2.size());
|
||||
}
|
||||
|
|
@ -133,8 +157,40 @@ size_t BlobDifferences(BlobMap& blobs1, BlobMap& blobs2,
|
|||
return mismatches;
|
||||
}
|
||||
|
||||
void CompareBlobs(const KeySpan& keys, BlobVec& blobs1, BlobVec& blobs2,
|
||||
size_t total_bytes, NestedPools& pools) {
|
||||
fprintf(stderr, "Comparing %zu blobs in parallel: ", keys.size());
|
||||
const double t0 = hwy::platform::Now();
|
||||
std::atomic<size_t> blobs_equal{};
|
||||
std::atomic<size_t> blobs_diff{};
|
||||
const IndexRangePartition ranges = StaticPartition(
|
||||
IndexRange(0, keys.size()), pools.AllPackages().NumWorkers(), 1);
|
||||
ParallelizeOneRange(
|
||||
ranges, pools.AllPackages(),
|
||||
[&](const IndexRange& range, size_t pkg_idx) {
|
||||
pools.Pool(pkg_idx).Run(
|
||||
range.begin(), range.end(), [&](size_t i, size_t /*thread*/) {
|
||||
const size_t mismatches =
|
||||
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
|
||||
if (mismatches != 0) {
|
||||
fprintf(stderr, "key %s has %zu mismatches in %zu bytes!\n",
|
||||
StringFromKey(keys[i]).c_str(), mismatches,
|
||||
blobs1[i].size());
|
||||
blobs_diff.fetch_add(1);
|
||||
} else {
|
||||
blobs_equal.fetch_add(1);
|
||||
}
|
||||
});
|
||||
});
|
||||
const double t1 = hwy::platform::Now();
|
||||
fprintf(stderr, "%.1f GB/s; total blob matches=%zu, mismatches=%zu\n",
|
||||
total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(),
|
||||
blobs_diff.load());
|
||||
}
|
||||
|
||||
// Compares two sbs files, including blob order.
|
||||
void CompareBlobs(const char* path1, const char* path2) {
|
||||
void ReadAndCompareBlobs(const char* path1, const char* path2) {
|
||||
// Open files.
|
||||
BlobReader reader1;
|
||||
BlobReader reader2;
|
||||
const BlobError err1 = reader1.Open(Path(path1));
|
||||
|
|
@ -142,33 +198,21 @@ void CompareBlobs(const char* path1, const char* path2) {
|
|||
if (err1 != 0 || err2 != 0) {
|
||||
HWY_ABORT("Failed to open files: %s %s: %d %d\n", path1, path2, err1, err2);
|
||||
}
|
||||
|
||||
if (!CompareKeys(reader1, reader2)) return;
|
||||
|
||||
// Single allocation, avoid initializing the memory.
|
||||
NestedPools pools(0);
|
||||
Allocator::Init(pools.Topology());
|
||||
hwy::Span<const hwy::uint128_t> keys1 = reader1.Keys();
|
||||
BlobMap blobs1, blobs2;
|
||||
fprintf(stderr, "Reading 2x %zu GiB, %zu cores...\n",
|
||||
TotalBytes(keys1, reader1) >> 30, pools.Pool().NumWorkers());
|
||||
ParallelRead(reader1, blobs1, pools.Pool());
|
||||
ParallelRead(reader2, blobs2, pools.Pool());
|
||||
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
|
||||
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
|
||||
size_t pos = 0;
|
||||
BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos);
|
||||
BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos);
|
||||
|
||||
fprintf(stderr, "Comparing %zu blobs in parallel...\n", keys1.size());
|
||||
std::atomic<size_t> blobs_equal{};
|
||||
std::atomic<size_t> blobs_diff{};
|
||||
pools.Pool().Run(0, keys1.size(), [&](size_t i, size_t /*thread*/) {
|
||||
const size_t mismatches = BlobDifferences(blobs1, blobs2, keys1[i]);
|
||||
if (mismatches != 0) {
|
||||
fprintf(stderr, "key %s has %zu mismatches in %zu bytes!\n",
|
||||
StringFromKey(keys1[i]).c_str(), mismatches,
|
||||
reader1.BlobSize(keys1[i]));
|
||||
blobs_diff.fetch_add(1);
|
||||
} else {
|
||||
blobs_equal.fetch_add(1);
|
||||
}
|
||||
});
|
||||
fprintf(stderr, "Total blob matches=%zu, mismatches=%zu\n",
|
||||
blobs_equal.load(), blobs_diff.load());
|
||||
ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools);
|
||||
|
||||
CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -180,6 +224,6 @@ int main(int argc, char** argv) {
|
|||
if (strcmp(argv[1], argv[2]) == 0) {
|
||||
HWY_ABORT("Filenames are the same, skipping comparison: %s\n", argv[1]);
|
||||
}
|
||||
gcpp::CompareBlobs(argv[1], argv[2]);
|
||||
gcpp::ReadAndCompareBlobs(argv[1], argv[2]);
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue