From 3e18db17f427bb798e8e7f453cae23bb699612d6 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 7 Nov 2025 00:31:59 -0800 Subject: [PATCH] Avoid hard-coding kPatchSize. Thanks @Somet2mes for reporting. Fixes #762. PiperOrigin-RevId: 829308896 --- gemma/vit.cc | 9 +++++---- paligemma/BUILD.bazel | 1 + paligemma/image.cc | 39 +++++++++++++++++++-------------------- paligemma/image.h | 10 ++++++---- paligemma/image_test.cc | 11 +++++++---- 5 files changed, 38 insertions(+), 32 deletions(-) diff --git a/gemma/vit.cc b/gemma/vit.cc index b00efda..1be3123 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -295,19 +295,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, const size_t model_dim = model_config.vit_config.model_dim; const size_t patch_width = model_config.vit_config.patch_width; const size_t num_tokens = model_config.vit_config.seq_len; - const size_t patch_size = patch_width * patch_width * 3; + const size_t patch_area = patch_width * patch_width * 3; + const hwy::Divisor div_patch_dim(patch_width); HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); - HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); + HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area); HWY_DASSERT(activations.x.Cols() == model_dim); (void)model_dim; // img/embedding/kernel has original shape (14, 14, 3, 1152) // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) // image_patches is (256, 14 * 14 * 3) // Must be padded, see `DoDecompressA`. 
- MatStorageT image_patches("patches", Extents2D(num_tokens, patch_size), + MatStorageT image_patches("patches", Extents2D(num_tokens, patch_area), env.ctx.allocator, MatPadding::kOdd); for (size_t i = 0; i < num_tokens; ++i) { - image.GetPatch(i, image_patches.Row(i)); + image.GetPatch(i, div_patch_dim, image_patches.Row(i)); } CallMatMul(image_patches, weights.vit_img_embedding_kernel, weights.vit_img_embedding_bias.PackedScale1(), env, activations.x); diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel index 7a2e870..cc6c6e1 100644 --- a/paligemma/BUILD.bazel +++ b/paligemma/BUILD.bazel @@ -29,6 +29,7 @@ cc_test( deps = [ ":image", "@googletest//:gtest_main", # buildcleaner: keep + "@highway//:hwy", ], ) diff --git a/paligemma/image.cc b/paligemma/image.cc index 20ecad8..d8b0cfc 100644 --- a/paligemma/image.cc +++ b/paligemma/image.cc @@ -37,8 +37,6 @@ namespace gcpp { namespace { -// Hardcoded for PaliGemma ViT input. -constexpr size_t kPatchSize = 14; // Returns the linearly scaled index in [0, to_size) closest to the // value in [0, from_size). @@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const { } // Image.data() is H x W x 3. -// We want the N-th patch of size kPatchSize x kPatchSize x 3. -void Image::GetPatch(size_t patch_num, float* patch) const { +// We want the N-th patch of size patch_dim x patch_dim x 3. 
+void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim, + float* patch) const { PROFILER_FUNC; constexpr size_t kNumChannels = 3; - constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float)); - constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel); - const size_t kDataSize = width_ * height_ * kNumChannels; + constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float); + const size_t patch_dim = div_patch_dim.GetDivisor(); + const size_t bytes_per_row = (patch_dim * kBytesPerPixel); const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel); - HWY_ASSERT(size() == kDataSize); - HWY_ASSERT(width_ % kPatchSize == 0); - HWY_ASSERT(height_ % kPatchSize == 0); - const size_t kNumPatchesPerRow = width_ / kPatchSize; - size_t patch_y = patch_num / kNumPatchesPerRow; - size_t patch_x = patch_num % kNumPatchesPerRow; - HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize); - HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow); - patch_y *= kPatchSize; - patch_x *= kPatchSize; + HWY_ASSERT(size() == width_ * height_ * kNumChannels); + HWY_ASSERT(div_patch_dim.Remainder(width_) == 0); + HWY_ASSERT(div_patch_dim.Remainder(height_) == 0); + const size_t patches_x = div_patch_dim.Divide(width_); + size_t patch_y = patch_num / patches_x; + size_t patch_x = patch_num % patches_x; + HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_)); + HWY_DASSERT(0 <= patch_x && patch_x < patches_x); + patch_y *= patch_dim; + patch_x *= patch_dim; // Move `out` and `in` to the start of the patch. char* out = reinterpret_cast(patch); @@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const { in += (((patch_y * width_) + patch_x) * kBytesPerPixel); // Copy the patch one row at a time. 
- for (size_t y = 0; y < kPatchSize; ++y) { - std::memcpy(out, in, kBytesPerRow); - out += kBytesPerRow; + for (size_t y = 0; y < patch_dim; ++y) { + std::memcpy(out, in, bytes_per_row); + out += bytes_per_row; in += in_bytes_to_next_row; } } diff --git a/paligemma/image.h b/paligemma/image.h index e0b1530..e54bf86 100644 --- a/paligemma/image.h +++ b/paligemma/image.h @@ -21,6 +21,7 @@ #include #include "hwy/aligned_allocator.h" // Span +#include "hwy/base.h" // Divisor namespace gcpp { @@ -44,11 +45,12 @@ class Image { bool WriteBinary(const std::string& filename) const; // Stores the patch for the given patch number in `patch`. // Patches are numbered in usual raster-order. E.g. for an image of size - // 224 x 224, there are 16 x 16 = 256 patches. - // `patch` should have space for at least 14 * 14 * 3 = 588 floats. + // 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches. + // `patch` should have space for at least patch_dim * patch_dim * 3 floats. // Requires that Normalize() has been called and that the image width and - // height are multiples of 14. - void GetPatch(size_t patch_num, float* patch) const; + // height are multiples of patch_dim. + void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim, + float* patch) const; float *data() { return data_.data(); } const float *data() const { return data_.data(); } diff --git a/paligemma/image_test.cc b/paligemma/image_test.cc index e2c4bbf..3721363 100644 --- a/paligemma/image_test.cc +++ b/paligemma/image_test.cc @@ -20,6 +20,7 @@ #include #include "gtest/gtest.h" +#include "hwy/base.h" namespace gcpp { namespace { @@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) { EXPECT_EQ(image.data()[image.size() - 1], Normalize(122)); // Extract two patches. 
float patch[588]; - image.GetPatch(0, patch); + const hwy::Divisor div_patch_dim(14); + image.GetPatch(0, div_patch_dim, patch); EXPECT_EQ(patch[0], Normalize(160)); EXPECT_EQ(patch[1], Normalize(184)); EXPECT_EQ(patch[2], Normalize(188)); - image.GetPatch(18, patch); + image.GetPatch(18, div_patch_dim, patch); // Check the first row of the patch. for (size_t i = 0; i < 14 * 3; ++i) { EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]); @@ -108,14 +110,15 @@ TEST(ImageTest, Non224) { // Extract two patches. const size_t kPatchValues = 14 * 14 * 3; // = 588 float patch[kPatchValues]; + const hwy::Divisor div_patch_dim(14); // Patch 0 is just the "start" of the image. - image.GetPatch(0, patch); + image.GetPatch(0, div_patch_dim, patch); EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6); EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6); EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6); // The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its // pixel coordinates are offset by (14, 28). - image.GetPatch(6, patch); + image.GetPatch(6, div_patch_dim, patch); for (size_t n = 0; n < kPatchValues; ++n) { size_t k = n % 3; size_t j = ((n - k) / 3) % 14;