Fix paligemma, update its test

Must not pass image tokens to the EmbedMMToken used for text.
Caught by next presubmit test.

paligemma_test: move function bodies into class, regroup variables
PiperOrigin-RevId: 770040014
This commit is contained in:
Jan Wassenberg 2025-06-11 02:11:29 -07:00 committed by Copybara-Service
parent ec02726cf7
commit b84149310b
2 changed files with 64 additions and 64 deletions

View File

@ -319,12 +319,9 @@ static HWY_NOINLINE void Transformer(
}
}
size_t image_token_position = 0;
for (size_t qi = 0; qi < num_queries; ++qi) {
image_token_position =
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
/*pos_in_prompt=*/0, config, weights, activations.x,
runtime_config.image_tokens, image_token_position);
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
/*pos_in_prompt=*/0, config, weights, activations.x);
}
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
@ -421,6 +418,7 @@ static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
// User decided to stop: set next token to primary EOS.
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
token = config.eos_id;
HWY_DASSERT(config.IsEOS(token));
}
// Primary or secondary EOS: mark query as EOS.

View File

@ -13,7 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdio>
#include <stdio.h>
#include <memory>
#include <string>
#include <vector>
@ -40,68 +41,69 @@ GemmaEnv* s_env = nullptr;
class PaliGemmaTest : public ::testing::Test {
protected:
void InitVit(const std::string& path);
std::string GemmaReply(const std::string& prompt_text) const;
void TestQuestion(const char* question, const char* expected_substring);
void InitVit(const std::string& path) {
ASSERT_NE(s_env->GetGemma(), nullptr);
const Gemma& gemma = *(s_env->GetGemma());
const ModelConfig& config = gemma.GetModelConfig();
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked);
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
Image image;
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
}
std::string GemmaReply(const std::string& prompt_text) const {
const Gemma& model = *(s_env->GetGemma());
s_env->MutableGen().seed(0x12345678);
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;
HWY_ASSERT(
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
std::string mutable_prompt = prompt_text;
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
// PrefixLM sees/attends to all tokens.
.prefill_tbatch_size = tokens.size(),
.gen = &s_env->MutableGen(),
.verbosity = 0,
.stream_token = stream_token,
.image_tokens = image_tokens_.get()};
const size_t prefix_end = tokens.size();
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
s_env->MutableKVCache(), timing_info);
return response;
}
void TestQuestion(const char* question, const char* expected_substring) {
ASSERT_NE(s_env->GetGemma(), nullptr);
std::string path = "paligemma/testdata/image.ppm";
InitVit(path);
const std::string reply = GemmaReply(question);
fprintf(stderr, "'%s'\n\n", reply.c_str());
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
}
std::unique_ptr<ImageTokens> image_tokens_;
};
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetGemma(), nullptr);
const Gemma& gemma = *(s_env->GetGemma());
const ModelConfig& config = gemma.GetModelConfig();
image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked);
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
Image image;
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
}
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
const Gemma& model = *(s_env->GetGemma());
s_env->MutableGen().seed(0x12345678);
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
.gen = &s_env->MutableGen(),
.verbosity = 0};
runtime_config.image_tokens = image_tokens_.get();
size_t abs_pos = 0;
std::string mutable_prompt = prompt_text;
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
runtime_config.stream_token = stream_token,
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
size_t num_tokens = tokens.size();
size_t prefix_end = num_tokens;
runtime_config.prefill_tbatch_size = num_tokens;
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, abs_pos, prefix_end,
s_env->MutableKVCache(), timing_info);
return response;
}
void PaliGemmaTest::TestQuestion(const char* question,
const char* expected_substring) {
ASSERT_NE(s_env->GetGemma(), nullptr);
std::string path = "paligemma/testdata/image.ppm";
InitVit(path);
const std::string reply = GemmaReply(question);
fprintf(stderr, "'%s'\n\n", reply.c_str());
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
}
TEST_F(PaliGemmaTest, QueryObjects) {
ASSERT_NE(s_env->GetGemma(), nullptr);
const char* question = "answer en What objects are in the image?";