Fix sign comparison warnings

PiperOrigin-RevId: 627299902
This commit is contained in:
Jan Wassenberg 2024-04-23 01:16:15 -07:00 committed by Copybara-Service
parent ca971ef50f
commit 3bf22abb22
2 changed files with 8 additions and 7 deletions

View File

@ -1136,8 +1136,8 @@ template <class TConfig>
void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len, void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
size_t k) { size_t k) {
std::vector<std::pair<float, int>> sorted(len); std::vector<std::pair<float, int>> sorted(len);
for (int i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], i); sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
} }
std::sort(sorted.begin(), sorted.end(), std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) { [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
@ -1146,9 +1146,10 @@ void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
} }
return a.second < b.second; return a.second < b.second;
}); });
for (int i = 0; i < k; ++i) { for (size_t i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", i + 1, sorted[i].second, printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
TOKEN(sorted[i].second), sorted[i].first, logits[sorted[i].second]); sorted[i].second, TOKEN(sorted[i].second), sorted[i].first,
logits[sorted[i].second]);
} }
} }

View File

@ -98,7 +98,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
int verbosity, const gcpp::AcceptFunc& accept_token, int verbosity, const gcpp::AcceptFunc& accept_token,
std::string& eot_line) { std::string& eot_line) {
PROFILER_ZONE("Gen.misc"); PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns size_t abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn int current_pos = 0; // token index within the current turn
int prompt_size{}; int prompt_size{};
@ -181,7 +181,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
// For instruction-tuned models: add control tokens. // For instruction-tuned models: add control tokens.
prompt_string = "<start_of_turn>user\n" + prompt_string + prompt_string = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n"; "<end_of_turn>\n<start_of_turn>model\n";
if (abs_pos > 0) { if (abs_pos != 0) {
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
// continuation. // continuation.
prompt_string = "<end_of_turn>\n" + prompt_string; prompt_string = "<end_of_turn>\n" + prompt_string;