mirror of https://github.com/google/gemma.cpp.git
Fix paligemma, update its test
Must not pass image tokens to the EmbedMMToken used for text. Caught by next presubmit test. paligemma_test: move function bodies into class, regroup variables. PiperOrigin-RevId: 770040014
This commit is contained in:
parent
ec02726cf7
commit
b84149310b
|
|
@ -319,12 +319,9 @@ static HWY_NOINLINE void Transformer(
|
|||
}
|
||||
}
|
||||
|
||||
size_t image_token_position = 0;
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
image_token_position =
|
||||
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
|
||||
/*pos_in_prompt=*/0, config, weights, activations.x,
|
||||
runtime_config.image_tokens, image_token_position);
|
||||
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
|
||||
/*pos_in_prompt=*/0, config, weights, activations.x);
|
||||
}
|
||||
|
||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||
|
|
@ -421,6 +418,7 @@ static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
|||
// User decided to stop: set next token to primary EOS.
|
||||
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
|
||||
token = config.eos_id;
|
||||
HWY_DASSERT(config.IsEOS(token));
|
||||
}
|
||||
|
||||
// Primary or secondary EOS: mark query as EOS.
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstdio>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
@ -40,68 +41,69 @@ GemmaEnv* s_env = nullptr;
|
|||
|
||||
// Test fixture for PaliGemma. Holds the image tokens produced by InitVit() so
// that GemmaReply() can hand them to the model obtained from the global s_env.
class PaliGemmaTest : public ::testing::Test {
|
||||
protected:
|
||||
// NOTE(review): the three declarations below are followed by in-class
// definitions of the same members further down; confirm which set should
// remain.
void InitVit(const std::string& path);
|
||||
std::string GemmaReply(const std::string& prompt_text) const;
|
||||
void TestQuestion(const char* question, const char* expected_substring);
|
||||
// Reads the PPM image at `path`, resizes it to the ViT's square input size and
// fills image_tokens_ via Gemma::GenerateImageTokens. Asserts that the model
// uses PaliGemma prompt wrapping and that the image could be read.
void InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
const Gemma& gemma = *(s_env->GetGemma());
|
||||
const ModelConfig& config = gemma.GetModelConfig();
|
||||
// Checked before anything is sized from config.vit_config.
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
||||
|
||||
// Storage for the image tokens, sized from vit_config.seq_len and model_dim.
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
|
||||
MatPadding::kPacked);
|
||||
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
|
||||
Image image;
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
const size_t image_size = config.vit_config.image_size;
|
||||
// Same value for both dimensions, i.e. a square image.
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
|
||||
.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
||||
}
|
||||
|
||||
// Generates a reply to `prompt_text` using the image tokens stored by
// InitVit() and returns the concatenated decoded text of every token passed to
// the stream callback. Dereferences image_tokens_, so InitVit() must have run.
std::string GemmaReply(const std::string& prompt_text) const {
|
||||
const Gemma& model = *(s_env->GetGemma());
|
||||
// Fixed seed, presumably so the generated reply is reproducible across runs.
s_env->MutableGen().seed(0x12345678);
|
||||
|
||||
std::string response;
|
||||
// Appends each streamed token's decoded text to `response`. Always returns
// true (presumably "keep generating"; confirm against RuntimeConfig).
auto stream_token = [&](int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
response += token_text;
|
||||
return true;
|
||||
};
|
||||
|
||||
// WrapAndTokenize is given a mutable copy because prompt_text is const.
std::string mutable_prompt = prompt_text;
|
||||
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
|
||||
// Prepend image_tokens_->Rows() tokens with id 0 (presumably placeholders for
// the image positions; confirm against Generate).
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
||||
|
||||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||
// PrefixLM sees/attends to all tokens.
|
||||
.prefill_tbatch_size = tokens.size(),
|
||||
.gen = &s_env->MutableGen(),
|
||||
.verbosity = 0,
|
||||
.stream_token = stream_token,
|
||||
.image_tokens = image_tokens_.get()};
|
||||
|
||||
// The prefix spans the entire token vector (zero-id tokens plus text).
const size_t prefix_end = tokens.size();
|
||||
TimingInfo timing_info = {.verbosity = 0};
|
||||
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
|
||||
s_env->MutableKVCache(), timing_info);
|
||||
return response;
|
||||
}
|
||||
|
||||
// Encodes the test image, asks `question`, prints the reply to stderr and
// expects the reply to contain `expected_substring`.
void TestQuestion(const char* question, const char* expected_substring) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
std::string path = "paligemma/testdata/image.ppm";
|
||||
InitVit(path);
|
||||
const std::string reply = GemmaReply(question);
|
||||
fprintf(stderr, "'%s'\n\n", reply.c_str());
|
||||
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
|
||||
}
|
||||
|
||||
// Image tokens written by InitVit() and read by GemmaReply().
std::unique_ptr<ImageTokens> image_tokens_;
|
||||
};
|
||||
|
||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
const Gemma& gemma = *(s_env->GetGemma());
|
||||
const ModelConfig& config = gemma.GetModelConfig();
|
||||
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
|
||||
MatPadding::kPacked);
|
||||
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
|
||||
Image image;
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
const size_t image_size = config.vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
||||
}
|
||||
|
||||
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
||||
const Gemma& model = *(s_env->GetGemma());
|
||||
s_env->MutableGen().seed(0x12345678);
|
||||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||
.gen = &s_env->MutableGen(),
|
||||
.verbosity = 0};
|
||||
runtime_config.image_tokens = image_tokens_.get();
|
||||
size_t abs_pos = 0;
|
||||
std::string mutable_prompt = prompt_text;
|
||||
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
response += token_text;
|
||||
return true;
|
||||
};
|
||||
runtime_config.stream_token = stream_token,
|
||||
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
||||
size_t num_tokens = tokens.size();
|
||||
size_t prefix_end = num_tokens;
|
||||
runtime_config.prefill_tbatch_size = num_tokens;
|
||||
TimingInfo timing_info = {.verbosity = 0};
|
||||
model.Generate(runtime_config, tokens, abs_pos, prefix_end,
|
||||
s_env->MutableKVCache(), timing_info);
|
||||
return response;
|
||||
}
|
||||
|
||||
void PaliGemmaTest::TestQuestion(const char* question,
|
||||
const char* expected_substring) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
std::string path = "paligemma/testdata/image.ppm";
|
||||
InitVit(path);
|
||||
const std::string reply = GemmaReply(question);
|
||||
fprintf(stderr, "'%s'\n\n", reply.c_str());
|
||||
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
|
||||
}
|
||||
|
||||
TEST_F(PaliGemmaTest, QueryObjects) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
const char* question = "answer en What objects are in the image?";
|
||||
|
|
|
|||
Loading…
Reference in New Issue