mirror of https://github.com/google/gemma.cpp.git
Avoid hard-coding kPatchSize. Thanks @Somet2mes for reporting. Fixes #762.
PiperOrigin-RevId: 829308896
This commit is contained in:
parent
f8131339a7
commit
3e18db17f4
|
|
@ -295,19 +295,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
||||||
const size_t model_dim = model_config.vit_config.model_dim;
|
const size_t model_dim = model_config.vit_config.model_dim;
|
||||||
const size_t patch_width = model_config.vit_config.patch_width;
|
const size_t patch_width = model_config.vit_config.patch_width;
|
||||||
const size_t num_tokens = model_config.vit_config.seq_len;
|
const size_t num_tokens = model_config.vit_config.seq_len;
|
||||||
const size_t patch_size = patch_width * patch_width * 3;
|
const size_t patch_area = patch_width * patch_width * 3;
|
||||||
|
const hwy::Divisor div_patch_dim(patch_width);
|
||||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
||||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
|
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area);
|
||||||
HWY_DASSERT(activations.x.Cols() == model_dim);
|
HWY_DASSERT(activations.x.Cols() == model_dim);
|
||||||
(void)model_dim;
|
(void)model_dim;
|
||||||
// img/embedding/kernel has original shape (14, 14, 3, 1152)
|
// img/embedding/kernel has original shape (14, 14, 3, 1152)
|
||||||
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
||||||
// image_patches is (256, 14 * 14 * 3)
|
// image_patches is (256, 14 * 14 * 3)
|
||||||
// Must be padded, see `DoDecompressA`.
|
// Must be padded, see `DoDecompressA`.
|
||||||
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
|
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_area),
|
||||||
env.ctx.allocator, MatPadding::kOdd);
|
env.ctx.allocator, MatPadding::kOdd);
|
||||||
for (size_t i = 0; i < num_tokens; ++i) {
|
for (size_t i = 0; i < num_tokens; ++i) {
|
||||||
image.GetPatch(i, image_patches.Row(i));
|
image.GetPatch(i, div_patch_dim, image_patches.Row(i));
|
||||||
}
|
}
|
||||||
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
|
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
|
||||||
weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);
|
weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":image",
|
":image",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,6 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace {
|
namespace {
|
||||||
// Hardcoded for PaliGemma ViT input.
|
|
||||||
constexpr size_t kPatchSize = 14;
|
|
||||||
|
|
||||||
// Returns the linearly scaled index in [0, to_size) closest to the
|
// Returns the linearly scaled index in [0, to_size) closest to the
|
||||||
// value in [0, from_size).
|
// value in [0, from_size).
|
||||||
|
|
@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Image.data() is H x W x 3.
|
// Image.data() is H x W x 3.
|
||||||
// We want the N-th patch of size kPatchSize x kPatchSize x 3.
|
// We want the N-th patch of size patch_dim x patch_dim x 3.
|
||||||
void Image::GetPatch(size_t patch_num, float* patch) const {
|
void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
|
||||||
|
float* patch) const {
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
constexpr size_t kNumChannels = 3;
|
constexpr size_t kNumChannels = 3;
|
||||||
constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float));
|
constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float);
|
||||||
constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel);
|
const size_t patch_dim = div_patch_dim.GetDivisor();
|
||||||
const size_t kDataSize = width_ * height_ * kNumChannels;
|
const size_t bytes_per_row = (patch_dim * kBytesPerPixel);
|
||||||
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
|
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
|
||||||
HWY_ASSERT(size() == kDataSize);
|
HWY_ASSERT(size() == width_ * height_ * kNumChannels);
|
||||||
HWY_ASSERT(width_ % kPatchSize == 0);
|
HWY_ASSERT(div_patch_dim.Remainder(width_) == 0);
|
||||||
HWY_ASSERT(height_ % kPatchSize == 0);
|
HWY_ASSERT(div_patch_dim.Remainder(height_) == 0);
|
||||||
const size_t kNumPatchesPerRow = width_ / kPatchSize;
|
const size_t patches_x = div_patch_dim.Divide(width_);
|
||||||
size_t patch_y = patch_num / kNumPatchesPerRow;
|
size_t patch_y = patch_num / patches_x;
|
||||||
size_t patch_x = patch_num % kNumPatchesPerRow;
|
size_t patch_x = patch_num % patches_x;
|
||||||
HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize);
|
HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_));
|
||||||
HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow);
|
HWY_DASSERT(0 <= patch_x && patch_x < patches_x);
|
||||||
patch_y *= kPatchSize;
|
patch_y *= patch_dim;
|
||||||
patch_x *= kPatchSize;
|
patch_x *= patch_dim;
|
||||||
|
|
||||||
// Move `out` and `in` to the start of the patch.
|
// Move `out` and `in` to the start of the patch.
|
||||||
char* out = reinterpret_cast<char*>(patch);
|
char* out = reinterpret_cast<char*>(patch);
|
||||||
|
|
@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
|
||||||
in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
|
in += (((patch_y * width_) + patch_x) * kBytesPerPixel);
|
||||||
|
|
||||||
// Copy the patch one row at a time.
|
// Copy the patch one row at a time.
|
||||||
for (size_t y = 0; y < kPatchSize; ++y) {
|
for (size_t y = 0; y < patch_dim; ++y) {
|
||||||
std::memcpy(out, in, kBytesPerRow);
|
std::memcpy(out, in, bytes_per_row);
|
||||||
out += kBytesPerRow;
|
out += bytes_per_row;
|
||||||
in += in_bytes_to_next_row;
|
in += in_bytes_to_next_row;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "hwy/aligned_allocator.h" // Span
|
#include "hwy/aligned_allocator.h" // Span
|
||||||
|
#include "hwy/base.h" // Divisor
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -44,11 +45,12 @@ class Image {
|
||||||
bool WriteBinary(const std::string& filename) const;
|
bool WriteBinary(const std::string& filename) const;
|
||||||
// Stores the patch for the given patch number in `patch`.
|
// Stores the patch for the given patch number in `patch`.
|
||||||
// Patches are numbered in usual raster-order. E.g. for an image of size
|
// Patches are numbered in usual raster-order. E.g. for an image of size
|
||||||
// 224 x 224, there are 16 x 16 = 256 patches.
|
// 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches.
|
||||||
// `patch` should have space for at least 14 * 14 * 3 = 588 floats.
|
// `patch` should have space for at least patch_dim * patch_dim * 3 floats.
|
||||||
// Requires that Normalize() has been called and that the image width and
|
// Requires that Normalize() has been called and that the image width and
|
||||||
// height are multiples of 14.
|
// height are multiples of patch_dim.
|
||||||
void GetPatch(size_t patch_num, float* patch) const;
|
void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
|
||||||
|
float* patch) const;
|
||||||
|
|
||||||
float *data() { return data_.data(); }
|
float *data() { return data_.data(); }
|
||||||
const float *data() const { return data_.data(); }
|
const float *data() const { return data_.data(); }
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) {
|
||||||
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
|
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
|
||||||
// Extract two patches.
|
// Extract two patches.
|
||||||
float patch[588];
|
float patch[588];
|
||||||
image.GetPatch(0, patch);
|
const hwy::Divisor div_patch_dim(14);
|
||||||
|
image.GetPatch(0, div_patch_dim, patch);
|
||||||
EXPECT_EQ(patch[0], Normalize(160));
|
EXPECT_EQ(patch[0], Normalize(160));
|
||||||
EXPECT_EQ(patch[1], Normalize(184));
|
EXPECT_EQ(patch[1], Normalize(184));
|
||||||
EXPECT_EQ(patch[2], Normalize(188));
|
EXPECT_EQ(patch[2], Normalize(188));
|
||||||
image.GetPatch(18, patch);
|
image.GetPatch(18, div_patch_dim, patch);
|
||||||
// Check the first row of the patch.
|
// Check the first row of the patch.
|
||||||
for (size_t i = 0; i < 14 * 3; ++i) {
|
for (size_t i = 0; i < 14 * 3; ++i) {
|
||||||
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
|
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
|
||||||
|
|
@ -108,14 +110,15 @@ TEST(ImageTest, Non224) {
|
||||||
// Extract two patches.
|
// Extract two patches.
|
||||||
const size_t kPatchValues = 14 * 14 * 3; // = 588
|
const size_t kPatchValues = 14 * 14 * 3; // = 588
|
||||||
float patch[kPatchValues];
|
float patch[kPatchValues];
|
||||||
|
const hwy::Divisor div_patch_dim(14);
|
||||||
// Patch 0 is just the "start" of the image.
|
// Patch 0 is just the "start" of the image.
|
||||||
image.GetPatch(0, patch);
|
image.GetPatch(0, div_patch_dim, patch);
|
||||||
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
|
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
|
||||||
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
|
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
|
||||||
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
|
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
|
||||||
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
|
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
|
||||||
// pixel coordinates are offset by (14, 28).
|
// pixel coordinates are offset by (14, 28).
|
||||||
image.GetPatch(6, patch);
|
image.GetPatch(6, div_patch_dim, patch);
|
||||||
for (size_t n = 0; n < kPatchValues; ++n) {
|
for (size_t n = 0; n < kPatchValues; ++n) {
|
||||||
size_t k = n % 3;
|
size_t k = n % 3;
|
||||||
size_t j = ((n - k) / 3) % 14;
|
size_t j = ((n - k) / 3) % 14;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue