mirror of https://github.com/google/gemma.cpp.git
Avoid hard-coding kPatchSize. Thanks @Somet2mes for reporting. Fixes #762.
PiperOrigin-RevId: 829308896
This commit is contained in:
parent
f8131339a7
commit
3e18db17f4
|
|
@ -295,19 +295,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
const size_t model_dim = model_config.vit_config.model_dim;
|
||||
const size_t patch_width = model_config.vit_config.patch_width;
|
||||
const size_t num_tokens = model_config.vit_config.seq_len;
|
||||
const size_t patch_size = patch_width * patch_width * 3;
|
||||
const size_t patch_area = patch_width * patch_width * 3;
|
||||
const hwy::Divisor div_patch_dim(patch_width);
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area);
|
||||
HWY_DASSERT(activations.x.Cols() == model_dim);
|
||||
(void)model_dim;
|
||||
// img/embedding/kernel has original shape (14, 14, 3, 1152)
|
||||
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
||||
// image_patches is (256, 14 * 14 * 3)
|
||||
// Must be padded, see `DoDecompressA`.
|
||||
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
|
||||
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_area),
|
||||
env.ctx.allocator, MatPadding::kOdd);
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
image.GetPatch(i, image_patches.Row(i));
|
||||
image.GetPatch(i, div_patch_dim, image_patches.Row(i));
|
||||
}
|
||||
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
|
||||
weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ cc_test(
|
|||
deps = [
|
||||
":image",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,8 +37,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
// Hardcoded for PaliGemma ViT input.
|
||||
constexpr size_t kPatchSize = 14;
|
||||
|
||||
// Returns the linearly scaled index in [0, to_size) closest to the
|
||||
// value in [0, from_size).
|
||||
|
|
@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const {
|
|||
}
|
||||
|
||||
// Image.data() is H x W x 3.
|
||||
// We want the N-th patch of size kPatchSize x kPatchSize x 3.
|
||||
void Image::GetPatch(size_t patch_num, float* patch) const {
|
||||
// We want the N-th patch of size patch_dim x patch_dim x 3.
|
||||
void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
|
||||
float* patch) const {
|
||||
PROFILER_FUNC;
|
||||
constexpr size_t kNumChannels = 3;
|
||||
constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float));
|
||||
constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel);
|
||||
const size_t kDataSize = width_ * height_ * kNumChannels;
|
||||
constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float);
|
||||
const size_t patch_dim = div_patch_dim.GetDivisor();
|
||||
const size_t bytes_per_row = (patch_dim * kBytesPerPixel);
|
||||
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
|
||||
HWY_ASSERT(size() == kDataSize);
|
||||
HWY_ASSERT(width_ % kPatchSize == 0);
|
||||
HWY_ASSERT(height_ % kPatchSize == 0);
|
||||
const size_t kNumPatchesPerRow = width_ / kPatchSize;
|
||||
size_t patch_y = patch_num / kNumPatchesPerRow;
|
||||
size_t patch_x = patch_num % kNumPatchesPerRow;
|
||||
HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize);
|
||||
HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow);
|
||||
patch_y *= kPatchSize;
|
||||
patch_x *= kPatchSize;
|
||||
HWY_ASSERT(size() == width_ * height_ * kNumChannels);
|
||||
HWY_ASSERT(div_patch_dim.Remainder(width_) == 0);
|
||||
HWY_ASSERT(div_patch_dim.Remainder(height_) == 0);
|
||||
const size_t patches_x = div_patch_dim.Divide(width_);
|
||||
size_t patch_y = patch_num / patches_x;
|
||||
size_t patch_x = patch_num % patches_x;
|
||||
HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_));
|
||||
HWY_DASSERT(0 <= patch_x && patch_x < patches_x);
|
||||
patch_y *= patch_dim;
|
||||
patch_x *= patch_dim;
|
||||
|
||||
// Move `out` and `in` to the start of the patch.
|
||||
char* out = reinterpret_cast<char*>(patch);
|
||||
|
|
@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
|
|||
in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
|
||||
|
||||
// Copy the patch one row at a time.
|
||||
for (size_t y = 0; y < kPatchSize; ++y) {
|
||||
std::memcpy(out, in, kBytesPerRow);
|
||||
out += kBytesPerRow;
|
||||
for (size_t y = 0; y < patch_dim; ++y) {
|
||||
std::memcpy(out, in, bytes_per_row);
|
||||
out += bytes_per_row;
|
||||
in += in_bytes_to_next_row;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // Divisor
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -44,11 +45,12 @@ class Image {
|
|||
bool WriteBinary(const std::string& filename) const;
|
||||
// Stores the patch for the given patch number in `patch`.
|
||||
// Patches are numbered in usual raster-order. E.g. for an image of size
|
||||
// 224 x 224, there are 16 x 16 = 256 patches.
|
||||
// `patch` should have space for at least 14 * 14 * 3 = 588 floats.
|
||||
// 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches.
|
||||
// `patch` should have space for at least patch_dim * patch_dim * 3.
|
||||
// Requires that Normalize() has been called and that the image width and
|
||||
// height are multiples of 14.
|
||||
void GetPatch(size_t patch_num, float* patch) const;
|
||||
// height are multiples of patch_dim.
|
||||
void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
|
||||
float* patch) const;
|
||||
|
||||
float *data() { return data_.data(); }
|
||||
const float *data() const { return data_.data(); }
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
|
@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) {
|
|||
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
|
||||
// Extract two patches.
|
||||
float patch[588];
|
||||
image.GetPatch(0, patch);
|
||||
const hwy::Divisor div_patch_dim(14);
|
||||
image.GetPatch(0, div_patch_dim, patch);
|
||||
EXPECT_EQ(patch[0], Normalize(160));
|
||||
EXPECT_EQ(patch[1], Normalize(184));
|
||||
EXPECT_EQ(patch[2], Normalize(188));
|
||||
image.GetPatch(18, patch);
|
||||
image.GetPatch(18, div_patch_dim, patch);
|
||||
// Check the first row of the patch.
|
||||
for (size_t i = 0; i < 14 * 3; ++i) {
|
||||
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
|
||||
|
|
@ -108,14 +110,15 @@ TEST(ImageTest, Non224) {
|
|||
// Extract two patches.
|
||||
const size_t kPatchValues = 14 * 14 * 3; // = 588
|
||||
float patch[kPatchValues];
|
||||
const hwy::Divisor div_patch_dim(14);
|
||||
// Patch 0 is just the "start" of the image.
|
||||
image.GetPatch(0, patch);
|
||||
image.GetPatch(0, div_patch_dim, patch);
|
||||
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
|
||||
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
|
||||
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
|
||||
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
|
||||
// pixel coordinates are offset by (14, 28).
|
||||
image.GetPatch(6, patch);
|
||||
image.GetPatch(6, div_patch_dim, patch);
|
||||
for (size_t n = 0; n < kPatchValues; ++n) {
|
||||
size_t k = n % 3;
|
||||
size_t j = ((n - k) / 3) % 14;
|
||||
|
|
|
|||
Loading…
Reference in New Issue