diff --git a/paligemma/image.cc b/paligemma/image.cc index 6c95437..f5bf4f2 100644 --- a/paligemma/image.cc +++ b/paligemma/image.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -83,7 +84,7 @@ const char* ParseUnsigned(const char* pos, const char* end, size_t& num) { return nullptr; } num = 0; - for ( ; pos < end && std::isdigit(*pos); ++pos) { + for (; pos < end && std::isdigit(*pos); ++pos) { num *= 10; num += *pos - '0'; } @@ -211,30 +212,32 @@ bool Image::WriteBinary(const std::string& filename) const { // We want the N-th patch of size kPatchSize x kPatchSize x 3. void Image::GetPatch(size_t patch_num, float* patch) const { PROFILER_FUNC; - const size_t kDataSize = width_ * height_ * 3; + constexpr size_t kNumChannels = 3; + constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float)); + constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel); + const size_t kDataSize = width_ * height_ * kNumChannels; + const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel); HWY_ASSERT(size() == kDataSize); HWY_ASSERT(width_ % kPatchSize == 0); HWY_ASSERT(height_ % kPatchSize == 0); const size_t kNumPatchesPerRow = width_ / kPatchSize; - size_t i_offs = patch_num / kNumPatchesPerRow; - size_t j_offs = patch_num % kNumPatchesPerRow; - HWY_ASSERT(0 <= i_offs && i_offs < height_ / kPatchSize); - HWY_ASSERT(0 <= j_offs && j_offs < kNumPatchesPerRow); - i_offs *= kPatchSize; - j_offs *= kPatchSize; - // This can be made faster, but let's first see whether it matters. - const float* image_data = data(); - for (size_t i = 0; i < kPatchSize; ++i) { - for (size_t j = 0; j < kPatchSize; ++j) { - for (size_t k = 0; k < 3; ++k) { - const size_t patch_index = (i * kPatchSize + j) * 3 + k; - HWY_DASSERT(patch_index < kPatchSize * kPatchSize * 3); - const size_t image_index = - ((i + i_offs) * width_ + (j + j_offs)) * 3 + k; - HWY_DASSERT(image_index < kDataSize); - patch[patch_index] = image_data[image_index]; - } - } + size_t patch_y = patch_num / kNumPatchesPerRow; + size_t patch_x = patch_num % kNumPatchesPerRow; + HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize); + HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow); + patch_y *= kPatchSize; + patch_x *= kPatchSize; + + // Move `out` and `in` to the start of the patch. + char* out = reinterpret_cast(patch); + const char* in = reinterpret_cast(data()); + in += (((patch_y * width_) + patch_x) * kBytesPerPixel); + + // Copy the patch one row at a time. + for (size_t y = 0; y < kPatchSize; ++y) { + std::memcpy(out, in, kBytesPerRow); + out += kBytesPerRow; + in += in_bytes_to_next_row; } }