mirror of https://github.com/google/gemma.cpp.git
Fix paligemma, update its test
Must not pass image tokens to the EmbedMMToken call used for text. This was caught by the next presubmit test.
paligemma_test: move function bodies into the class and regroup variables.
PiperOrigin-RevId: 770040014
This commit is contained in:
parent
ec02726cf7
commit
b84149310b
|
|
@ -319,12 +319,9 @@ static HWY_NOINLINE void Transformer(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t image_token_position = 0;
|
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
image_token_position =
|
|
||||||
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
|
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
|
||||||
/*pos_in_prompt=*/0, config, weights, activations.x,
|
/*pos_in_prompt=*/0, config, weights, activations.x);
|
||||||
runtime_config.image_tokens, image_token_position);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||||
|
|
@ -421,6 +418,7 @@ static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
||||||
// User decided to stop: set next token to primary EOS.
|
// User decided to stop: set next token to primary EOS.
|
||||||
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
|
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
|
||||||
token = config.eos_id;
|
token = config.eos_id;
|
||||||
|
HWY_DASSERT(config.IsEOS(token));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Primary or secondary EOS: mark query as EOS.
|
// Primary or secondary EOS: mark query as EOS.
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <cstdio>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -40,67 +41,68 @@ GemmaEnv* s_env = nullptr;
|
||||||
|
|
||||||
class PaliGemmaTest : public ::testing::Test {
|
class PaliGemmaTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void InitVit(const std::string& path);
|
void InitVit(const std::string& path) {
|
||||||
std::string GemmaReply(const std::string& prompt_text) const;
|
|
||||||
void TestQuestion(const char* question, const char* expected_substring);
|
|
||||||
|
|
||||||
std::unique_ptr<ImageTokens> image_tokens_;
|
|
||||||
};
|
|
||||||
|
|
||||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
|
||||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||||
const Gemma& gemma = *(s_env->GetGemma());
|
const Gemma& gemma = *(s_env->GetGemma());
|
||||||
const ModelConfig& config = gemma.GetModelConfig();
|
const ModelConfig& config = gemma.GetModelConfig();
|
||||||
|
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
||||||
|
|
||||||
image_tokens_ = std::make_unique<ImageTokens>(
|
image_tokens_ = std::make_unique<ImageTokens>(
|
||||||
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
|
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
|
||||||
MatPadding::kPacked);
|
MatPadding::kPacked);
|
||||||
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
|
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
|
||||||
Image image;
|
Image image;
|
||||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
|
||||||
HWY_ASSERT(image.ReadPPM(path));
|
HWY_ASSERT(image.ReadPPM(path));
|
||||||
const size_t image_size = config.vit_config.image_size;
|
const size_t image_size = config.vit_config.image_size;
|
||||||
image.Resize(image_size, image_size);
|
image.Resize(image_size, image_size);
|
||||||
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
|
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
|
||||||
|
.verbosity = 0};
|
||||||
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
std::string GemmaReply(const std::string& prompt_text) const {
|
||||||
const Gemma& model = *(s_env->GetGemma());
|
const Gemma& model = *(s_env->GetGemma());
|
||||||
s_env->MutableGen().seed(0x12345678);
|
s_env->MutableGen().seed(0x12345678);
|
||||||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
|
||||||
.gen = &s_env->MutableGen(),
|
|
||||||
.verbosity = 0};
|
|
||||||
runtime_config.image_tokens = image_tokens_.get();
|
|
||||||
size_t abs_pos = 0;
|
|
||||||
std::string mutable_prompt = prompt_text;
|
|
||||||
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
|
|
||||||
std::string response;
|
std::string response;
|
||||||
auto stream_token = [&](int token, float) {
|
auto stream_token = [&](int token, float) {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
HWY_ASSERT(
|
||||||
|
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||||
response += token_text;
|
response += token_text;
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
runtime_config.stream_token = stream_token,
|
|
||||||
|
std::string mutable_prompt = prompt_text;
|
||||||
|
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
|
||||||
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
||||||
size_t num_tokens = tokens.size();
|
|
||||||
size_t prefix_end = num_tokens;
|
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||||
runtime_config.prefill_tbatch_size = num_tokens;
|
// PrefixLM sees/attends to all tokens.
|
||||||
|
.prefill_tbatch_size = tokens.size(),
|
||||||
|
.gen = &s_env->MutableGen(),
|
||||||
|
.verbosity = 0,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.image_tokens = image_tokens_.get()};
|
||||||
|
|
||||||
|
const size_t prefix_end = tokens.size();
|
||||||
TimingInfo timing_info = {.verbosity = 0};
|
TimingInfo timing_info = {.verbosity = 0};
|
||||||
model.Generate(runtime_config, tokens, abs_pos, prefix_end,
|
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
|
||||||
s_env->MutableKVCache(), timing_info);
|
s_env->MutableKVCache(), timing_info);
|
||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PaliGemmaTest::TestQuestion(const char* question,
|
void TestQuestion(const char* question, const char* expected_substring) {
|
||||||
const char* expected_substring) {
|
|
||||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||||
std::string path = "paligemma/testdata/image.ppm";
|
std::string path = "paligemma/testdata/image.ppm";
|
||||||
InitVit(path);
|
InitVit(path);
|
||||||
const std::string reply = GemmaReply(question);
|
const std::string reply = GemmaReply(question);
|
||||||
fprintf(stderr, "'%s'\n\n", reply.c_str());
|
fprintf(stderr, "'%s'\n\n", reply.c_str());
|
||||||
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
|
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ImageTokens> image_tokens_;
|
||||||
|
};
|
||||||
|
|
||||||
TEST_F(PaliGemmaTest, QueryObjects) {
|
TEST_F(PaliGemmaTest, QueryObjects) {
|
||||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue