Avoid hard-coding kPatchSize. Thanks @Somet2mes for reporting. Fixes #762.

PiperOrigin-RevId: 829308896
This commit is contained in:
Jan Wassenberg 2025-11-07 00:31:59 -08:00 committed by Copybara-Service
parent f8131339a7
commit 3e18db17f4
5 changed files with 38 additions and 32 deletions

View File

@ -295,19 +295,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
const size_t model_dim = model_config.vit_config.model_dim;
const size_t patch_width = model_config.vit_config.patch_width;
const size_t num_tokens = model_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
const size_t patch_area = patch_width * patch_width * 3;
const hwy::Divisor div_patch_dim(patch_width);
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area);
HWY_DASSERT(activations.x.Cols() == model_dim);
(void)model_dim;
// img/embedding/kernel has original shape (14, 14, 3, 1152)
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3)
// Must be padded, see `DoDecompressA`.
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_area),
env.ctx.allocator, MatPadding::kOdd);
for (size_t i = 0; i < num_tokens; ++i) {
image.GetPatch(i, image_patches.Row(i));
image.GetPatch(i, div_patch_dim, image_patches.Row(i));
}
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);

View File

@ -29,6 +29,7 @@ cc_test(
deps = [
":image",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
],
)

View File

@ -37,8 +37,6 @@
namespace gcpp {
namespace {
// Hardcoded for PaliGemma ViT input.
constexpr size_t kPatchSize = 14;
// Returns the linearly scaled index in [0, to_size) closest to the
// value in [0, from_size).
@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const {
}
// Image.data() is H x W x 3.
// We want the N-th patch of size kPatchSize x kPatchSize x 3.
void Image::GetPatch(size_t patch_num, float* patch) const {
// We want the N-th patch of size patch_dim x patch_dim x 3.
void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
float* patch) const {
PROFILER_FUNC;
constexpr size_t kNumChannels = 3;
constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float));
constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel);
const size_t kDataSize = width_ * height_ * kNumChannels;
constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float);
const size_t patch_dim = div_patch_dim.GetDivisor();
const size_t bytes_per_row = (patch_dim * kBytesPerPixel);
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
HWY_ASSERT(size() == kDataSize);
HWY_ASSERT(width_ % kPatchSize == 0);
HWY_ASSERT(height_ % kPatchSize == 0);
const size_t kNumPatchesPerRow = width_ / kPatchSize;
size_t patch_y = patch_num / kNumPatchesPerRow;
size_t patch_x = patch_num % kNumPatchesPerRow;
HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize);
HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow);
patch_y *= kPatchSize;
patch_x *= kPatchSize;
HWY_ASSERT(size() == width_ * height_ * kNumChannels);
HWY_ASSERT(div_patch_dim.Remainder(width_) == 0);
HWY_ASSERT(div_patch_dim.Remainder(height_) == 0);
const size_t patches_x = div_patch_dim.Divide(width_);
size_t patch_y = patch_num / patches_x;
size_t patch_x = patch_num % patches_x;
HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_));
HWY_DASSERT(0 <= patch_x && patch_x < patches_x);
patch_y *= patch_dim;
patch_x *= patch_dim;
// Move `out` and `in` to the start of the patch.
char* out = reinterpret_cast<char*>(patch);
@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
// Copy the patch one row at a time.
for (size_t y = 0; y < kPatchSize; ++y) {
std::memcpy(out, in, kBytesPerRow);
out += kBytesPerRow;
for (size_t y = 0; y < patch_dim; ++y) {
std::memcpy(out, in, bytes_per_row);
out += bytes_per_row;
in += in_bytes_to_next_row;
}
}

View File

@ -21,6 +21,7 @@
#include <vector>
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // Divisor
namespace gcpp {
@ -44,11 +45,12 @@ class Image {
bool WriteBinary(const std::string& filename) const;
// Stores the patch for the given patch number in `patch`.
// Patches are numbered in usual raster-order. E.g. for an image of size
// 224 x 224, there are 16 x 16 = 256 patches.
// `patch` should have space for at least 14 * 14 * 3 = 588 floats.
// 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches.
// `patch` should have space for at least patch_dim * patch_dim * 3.
// Requires that Normalize() has been called and that the image width and
// height are multiples of 14.
void GetPatch(size_t patch_num, float* patch) const;
// height are multiples of patch_dim.
void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
float* patch) const;
float *data() { return data_.data(); }
const float *data() const { return data_.data(); }

View File

@ -20,6 +20,7 @@
#include <vector>
#include "gtest/gtest.h"
#include "hwy/base.h"
namespace gcpp {
namespace {
@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) {
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
// Extract two patches.
float patch[588];
image.GetPatch(0, patch);
const hwy::Divisor div_patch_dim(14);
image.GetPatch(0, div_patch_dim, patch);
EXPECT_EQ(patch[0], Normalize(160));
EXPECT_EQ(patch[1], Normalize(184));
EXPECT_EQ(patch[2], Normalize(188));
image.GetPatch(18, patch);
image.GetPatch(18, div_patch_dim, patch);
// Check the first row of the patch.
for (size_t i = 0; i < 14 * 3; ++i) {
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
@ -108,14 +110,15 @@ TEST(ImageTest, Non224) {
// Extract two patches.
const size_t kPatchValues = 14 * 14 * 3; // = 588
float patch[kPatchValues];
const hwy::Divisor div_patch_dim(14);
// Patch 0 is just the "start" of the image.
image.GetPatch(0, patch);
image.GetPatch(0, div_patch_dim, patch);
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
// pixel coordinates are offset by (14, 28).
image.GetPatch(6, patch);
image.GetPatch(6, div_patch_dim, patch);
for (size_t n = 0; n < kPatchValues; ++n) {
size_t k = n % 3;
size_t j = ((n - k) / 3) % 14;