Fix sign comparison warnings

PiperOrigin-RevId: 627299902
This commit is contained in:
Jan Wassenberg 2024-04-23 01:16:15 -07:00 committed by Copybara-Service
parent ca971ef50f
commit 3bf22abb22
2 changed files with 8 additions and 7 deletions

View File

@ -1136,8 +1136,8 @@ template <class TConfig>
void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
size_t k) {
std::vector<std::pair<float, int>> sorted(len);
for (int i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], i);
for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
}
std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
@ -1146,9 +1146,10 @@ void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
}
return a.second < b.second;
});
for (int i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", i + 1, sorted[i].second,
TOKEN(sorted[i].second), sorted[i].first, logits[sorted[i].second]);
for (size_t i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
sorted[i].second, TOKEN(sorted[i].second), sorted[i].first,
logits[sorted[i].second]);
}
}

View File

@ -98,7 +98,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
int verbosity, const gcpp::AcceptFunc& accept_token,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns
size_t abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
int prompt_size{};
@ -181,7 +181,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
// For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
if (abs_pos > 0) {
if (abs_pos != 0) {
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
// continuation.
prompt_string = "<end_of_turn>\n" + prompt_string;