Optimize Image::GetPatch() to copy rows instead of pixels at a time.

PiperOrigin-RevId: 767436146
This commit is contained in:
The gemma.cpp Authors 2025-06-04 22:30:34 -07:00 committed by Copybara-Service
parent eff0213e88
commit dd7d4a7717
1 changed files with 24 additions and 21 deletions

View File

@ -22,6 +22,7 @@
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
@ -83,7 +84,7 @@ const char* ParseUnsigned(const char* pos, const char* end, size_t& num) {
return nullptr;
}
num = 0;
for ( ; pos < end && std::isdigit(*pos); ++pos) {
for (; pos < end && std::isdigit(*pos); ++pos) {
num *= 10;
num += *pos - '0';
}
@ -211,30 +212,32 @@ bool Image::WriteBinary(const std::string& filename) const {
// We want the N-th patch of size kPatchSize x kPatchSize x 3.
void Image::GetPatch(size_t patch_num, float* patch) const {
PROFILER_FUNC;
const size_t kDataSize = width_ * height_ * 3;
constexpr size_t kNumChannels = 3;
constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float));
constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel);
const size_t kDataSize = width_ * height_ * kNumChannels;
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
HWY_ASSERT(size() == kDataSize);
HWY_ASSERT(width_ % kPatchSize == 0);
HWY_ASSERT(height_ % kPatchSize == 0);
const size_t kNumPatchesPerRow = width_ / kPatchSize;
size_t i_offs = patch_num / kNumPatchesPerRow;
size_t j_offs = patch_num % kNumPatchesPerRow;
HWY_ASSERT(0 <= i_offs && i_offs < height_ / kPatchSize);
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatchesPerRow);
i_offs *= kPatchSize;
j_offs *= kPatchSize;
// This can be made faster, but let's first see whether it matters.
const float* image_data = data();
for (size_t i = 0; i < kPatchSize; ++i) {
for (size_t j = 0; j < kPatchSize; ++j) {
for (size_t k = 0; k < 3; ++k) {
const size_t patch_index = (i * kPatchSize + j) * 3 + k;
HWY_DASSERT(patch_index < kPatchSize * kPatchSize * 3);
const size_t image_index =
((i + i_offs) * width_ + (j + j_offs)) * 3 + k;
HWY_DASSERT(image_index < kDataSize);
patch[patch_index] = image_data[image_index];
}
}
size_t patch_y = patch_num / kNumPatchesPerRow;
size_t patch_x = patch_num % kNumPatchesPerRow;
HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize);
HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow);
patch_y *= kPatchSize;
patch_x *= kPatchSize;
// Move `out` and `in` to the start of the patch.
char* out = reinterpret_cast<char*>(patch);
const char* in = reinterpret_cast<const char*>(data());
in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
// Copy the patch one row at a time.
for (size_t y = 0; y < kPatchSize; ++y) {
std::memcpy(out, in, kBytesPerRow);
out += kBytesPerRow;
in += in_bytes_to_next_row;
}
}