Avoid hard-coding kPatchSize. Thanks @Somet2mes for reporting. Fixes #762.

PiperOrigin-RevId: 829308896
This commit is contained in:
Jan Wassenberg 2025-11-07 00:31:59 -08:00 committed by Copybara-Service
parent f8131339a7
commit 3e18db17f4
5 changed files with 38 additions and 32 deletions

View File

@ -295,19 +295,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
const size_t model_dim = model_config.vit_config.model_dim; const size_t model_dim = model_config.vit_config.model_dim;
const size_t patch_width = model_config.vit_config.patch_width; const size_t patch_width = model_config.vit_config.patch_width;
const size_t num_tokens = model_config.vit_config.seq_len; const size_t num_tokens = model_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3; const size_t patch_area = patch_width * patch_width * 3;
const hwy::Divisor div_patch_dim(patch_width);
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area);
HWY_DASSERT(activations.x.Cols() == model_dim); HWY_DASSERT(activations.x.Cols() == model_dim);
(void)model_dim; (void)model_dim;
// img/embedding/kernel has original shape (14, 14, 3, 1152) // img/embedding/kernel has original shape (14, 14, 3, 1152)
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3) // image_patches is (256, 14 * 14 * 3)
// Must be padded, see `DoDecompressA`. // Must be padded, see `DoDecompressA`.
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size), MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_area),
env.ctx.allocator, MatPadding::kOdd); env.ctx.allocator, MatPadding::kOdd);
for (size_t i = 0; i < num_tokens; ++i) { for (size_t i = 0; i < num_tokens; ++i) {
image.GetPatch(i, image_patches.Row(i)); image.GetPatch(i, div_patch_dim, image_patches.Row(i));
} }
CallMatMul(image_patches, weights.vit_img_embedding_kernel, CallMatMul(image_patches, weights.vit_img_embedding_kernel,
weights.vit_img_embedding_bias.PackedScale1(), env, activations.x); weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);

View File

@ -29,6 +29,7 @@ cc_test(
deps = [ deps = [
":image", ":image",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
], ],
) )

View File

@ -37,8 +37,6 @@
namespace gcpp { namespace gcpp {
namespace { namespace {
// Hardcoded for PaliGemma ViT input.
constexpr size_t kPatchSize = 14;
// Returns the linearly scaled index in [0, to_size) closest to the // Returns the linearly scaled index in [0, to_size) closest to the
// value in [0, from_size). // value in [0, from_size).
@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const {
} }
// Image.data() is H x W x 3. // Image.data() is H x W x 3.
// We want the N-th patch of size kPatchSize x kPatchSize x 3. // We want the N-th patch of size patch_dim x patch_dim x 3.
void Image::GetPatch(size_t patch_num, float* patch) const { void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
float* patch) const {
PROFILER_FUNC; PROFILER_FUNC;
constexpr size_t kNumChannels = 3; constexpr size_t kNumChannels = 3;
constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float)); constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float);
constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel); const size_t patch_dim = div_patch_dim.GetDivisor();
const size_t kDataSize = width_ * height_ * kNumChannels; const size_t bytes_per_row = (patch_dim * kBytesPerPixel);
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel); const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
HWY_ASSERT(size() == kDataSize); HWY_ASSERT(size() == width_ * height_ * kNumChannels);
HWY_ASSERT(width_ % kPatchSize == 0); HWY_ASSERT(div_patch_dim.Remainder(width_) == 0);
HWY_ASSERT(height_ % kPatchSize == 0); HWY_ASSERT(div_patch_dim.Remainder(height_) == 0);
const size_t kNumPatchesPerRow = width_ / kPatchSize; const size_t patches_x = div_patch_dim.Divide(width_);
size_t patch_y = patch_num / kNumPatchesPerRow; size_t patch_y = patch_num / patches_x;
size_t patch_x = patch_num % kNumPatchesPerRow; size_t patch_x = patch_num % patches_x;
HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize); HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_));
HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow); HWY_DASSERT(0 <= patch_x && patch_x < patches_x);
patch_y *= kPatchSize; patch_y *= patch_dim;
patch_x *= kPatchSize; patch_x *= patch_dim;
// Move `out` and `in` to the start of the patch. // Move `out` and `in` to the start of the patch.
char* out = reinterpret_cast<char*>(patch); char* out = reinterpret_cast<char*>(patch);
@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
in += (((patch_y * width_) + patch_x) * kBytesPerPixel); in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
// Copy the patch one row at a time. // Copy the patch one row at a time.
for (size_t y = 0; y < kPatchSize; ++y) { for (size_t y = 0; y < patch_dim; ++y) {
std::memcpy(out, in, kBytesPerRow); std::memcpy(out, in, bytes_per_row);
out += kBytesPerRow; out += bytes_per_row;
in += in_bytes_to_next_row; in += in_bytes_to_next_row;
} }
} }

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // Divisor
namespace gcpp { namespace gcpp {
@ -44,11 +45,12 @@ class Image {
bool WriteBinary(const std::string& filename) const; bool WriteBinary(const std::string& filename) const;
// Stores the patch for the given patch number in `patch`. // Stores the patch for the given patch number in `patch`.
// Patches are numbered in usual raster-order. E.g. for an image of size // Patches are numbered in usual raster-order. E.g. for an image of size
// 224 x 224, there are 16 x 16 = 256 patches. // 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches.
// `patch` should have space for at least 14 * 14 * 3 = 588 floats. // `patch` should have space for at least patch_dim * patch_dim * 3.
// Requires that Normalize() has been called and that the image width and // Requires that Normalize() has been called and that the image width and
// height are multiples of 14. // height are multiples of patch_dim.
void GetPatch(size_t patch_num, float* patch) const; void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
float* patch) const;
float *data() { return data_.data(); } float *data() { return data_.data(); }
const float *data() const { return data_.data(); } const float *data() const { return data_.data(); }

View File

@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "hwy/base.h"
namespace gcpp { namespace gcpp {
namespace { namespace {
@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) {
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122)); EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
// Extract two patches. // Extract two patches.
float patch[588]; float patch[588];
image.GetPatch(0, patch); const hwy::Divisor div_patch_dim(14);
image.GetPatch(0, div_patch_dim, patch);
EXPECT_EQ(patch[0], Normalize(160)); EXPECT_EQ(patch[0], Normalize(160));
EXPECT_EQ(patch[1], Normalize(184)); EXPECT_EQ(patch[1], Normalize(184));
EXPECT_EQ(patch[2], Normalize(188)); EXPECT_EQ(patch[2], Normalize(188));
image.GetPatch(18, patch); image.GetPatch(18, div_patch_dim, patch);
// Check the first row of the patch. // Check the first row of the patch.
for (size_t i = 0; i < 14 * 3; ++i) { for (size_t i = 0; i < 14 * 3; ++i) {
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]); EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
@ -108,14 +110,15 @@ TEST(ImageTest, Non224) {
// Extract two patches. // Extract two patches.
const size_t kPatchValues = 14 * 14 * 3; // = 588 const size_t kPatchValues = 14 * 14 * 3; // = 588
float patch[kPatchValues]; float patch[kPatchValues];
const hwy::Divisor div_patch_dim(14);
// Patch 0 is just the "start" of the image. // Patch 0 is just the "start" of the image.
image.GetPatch(0, patch); image.GetPatch(0, div_patch_dim, patch);
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6); EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6); EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6); EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its // The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
// pixel coordinates are offset by (14, 28). // pixel coordinates are offset by (14, 28).
image.GetPatch(6, patch); image.GetPatch(6, div_patch_dim, patch);
for (size_t n = 0; n < kPatchValues; ++n) { for (size_t n = 0; n < kPatchValues; ++n) {
size_t k = n % 3; size_t k = n % 3;
size_t j = ((n - k) / 3) % 14; size_t j = ((n - k) / 3) % 14;