diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 8f49717..b72e31f 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -319,12 +319,9 @@ static HWY_NOINLINE void Transformer( } } - size_t image_token_position = 0; for (size_t qi = 0; qi < num_queries; ++qi) { - image_token_position = - EmbedMMToken(queries_token[qi], qi, queries_pos[qi], - /*pos_in_prompt=*/0, config, weights, activations.x, - runtime_config.image_tokens, image_token_position); + EmbedMMToken(queries_token[qi], qi, queries_pos[qi], + /*pos_in_prompt=*/0, config, weights, activations.x); } for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { @@ -421,6 +418,7 @@ static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token, // User decided to stop: set next token to primary EOS. if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) { token = config.eos_id; + HWY_DASSERT(config.IsEOS(token)); } // Primary or secondary EOS: mark query as EOS. diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index e681e00..9491475 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -13,7 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include + #include #include #include @@ -40,68 +41,69 @@ GemmaEnv* s_env = nullptr; class PaliGemmaTest : public ::testing::Test { protected: - void InitVit(const std::string& path); - std::string GemmaReply(const std::string& prompt_text) const; - void TestQuestion(const char* question, const char* expected_substring); + void InitVit(const std::string& path) { + ASSERT_NE(s_env->GetGemma(), nullptr); + const Gemma& gemma = *(s_env->GetGemma()); + const ModelConfig& config = gemma.GetModelConfig(); + HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA); + + image_tokens_ = std::make_unique( + "image", Extents2D(config.vit_config.seq_len, config.model_dim), + MatPadding::kPacked); + image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs); + Image image; + HWY_ASSERT(image.ReadPPM(path)); + const size_t image_size = config.vit_config.image_size; + image.Resize(image_size, image_size); + RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), + .verbosity = 0}; + gemma.GenerateImageTokens(runtime_config, image, *image_tokens_); + } + + std::string GemmaReply(const std::string& prompt_text) const { + const Gemma& model = *(s_env->GetGemma()); + s_env->MutableGen().seed(0x12345678); + + std::string response; + auto stream_token = [&](int token, float) { + std::string token_text; + HWY_ASSERT( + model.Tokenizer().Decode(std::vector{token}, &token_text)); + response += token_text; + return true; + }; + + std::string mutable_prompt = prompt_text; + std::vector tokens = s_env->WrapAndTokenize(mutable_prompt); + tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); + + RuntimeConfig runtime_config = {.max_generated_tokens = 512, + // PrefixLM sees/attends to all tokens. + .prefill_tbatch_size = tokens.size(), + .gen = &s_env->MutableGen(), + .verbosity = 0, + .stream_token = stream_token, + .image_tokens = image_tokens_.get()}; + + const size_t prefix_end = tokens.size(); + TimingInfo timing_info = {.verbosity = 0}; + model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end, + s_env->MutableKVCache(), timing_info); + return response; + } + + void TestQuestion(const char* question, const char* expected_substring) { + ASSERT_NE(s_env->GetGemma(), nullptr); + std::string path = "paligemma/testdata/image.ppm"; + InitVit(path); + const std::string reply = GemmaReply(question); + fprintf(stderr, "'%s'\n\n", reply.c_str()); + EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT + } std::unique_ptr image_tokens_; }; -void PaliGemmaTest::InitVit(const std::string& path) { - ASSERT_NE(s_env->GetGemma(), nullptr); - const Gemma& gemma = *(s_env->GetGemma()); - const ModelConfig& config = gemma.GetModelConfig(); - image_tokens_ = std::make_unique( - "image", Extents2D(config.vit_config.seq_len, config.model_dim), - MatPadding::kPacked); - image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs); - Image image; - HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA); - HWY_ASSERT(image.ReadPPM(path)); - const size_t image_size = config.vit_config.image_size; - image.Resize(image_size, image_size); - RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0}; - gemma.GenerateImageTokens(runtime_config, image, *image_tokens_); -} - -std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ - const Gemma& model = *(s_env->GetGemma()); - s_env->MutableGen().seed(0x12345678); - RuntimeConfig runtime_config = {.max_generated_tokens = 512, - .gen = &s_env->MutableGen(), - .verbosity = 0}; - runtime_config.image_tokens = image_tokens_.get(); - size_t abs_pos = 0; - std::string mutable_prompt = prompt_text; - std::vector tokens = s_env->WrapAndTokenize(mutable_prompt); - std::string response; - auto stream_token = [&](int token, float) { - std::string token_text; - HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); - response += token_text; - return true; - }; - runtime_config.stream_token = stream_token, - tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); - size_t num_tokens = tokens.size(); - size_t prefix_end = num_tokens; - runtime_config.prefill_tbatch_size = num_tokens; - TimingInfo timing_info = {.verbosity = 0}; - model.Generate(runtime_config, tokens, abs_pos, prefix_end, - s_env->MutableKVCache(), timing_info); - return response; -} - -void PaliGemmaTest::TestQuestion(const char* question, - const char* expected_substring) { - ASSERT_NE(s_env->GetGemma(), nullptr); - std::string path = "paligemma/testdata/image.ppm"; - InitVit(path); - const std::string reply = GemmaReply(question); - fprintf(stderr, "'%s'\n\n", reply.c_str()); - EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT -} - TEST_F(PaliGemmaTest, QueryObjects) { ASSERT_NE(s_env->GetGemma(), nullptr); const char* question = "answer en What objects are in the image?";