mirror of https://github.com/google/gemma.cpp.git
Fix sign comparison warnings
PiperOrigin-RevId: 627299902
This commit is contained in:
parent
ca971ef50f
commit
3bf22abb22
|
|
@ -1136,8 +1136,8 @@ template <class TConfig>
|
||||||
void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
|
void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
|
||||||
size_t k) {
|
size_t k) {
|
||||||
std::vector<std::pair<float, int>> sorted(len);
|
std::vector<std::pair<float, int>> sorted(len);
|
||||||
for (int i = 0; i < len; ++i) {
|
for (size_t i = 0; i < len; ++i) {
|
||||||
sorted[i] = std::make_pair(dist[i], i);
|
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
|
||||||
}
|
}
|
||||||
std::sort(sorted.begin(), sorted.end(),
|
std::sort(sorted.begin(), sorted.end(),
|
||||||
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
||||||
|
|
@ -1146,9 +1146,10 @@ void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
|
||||||
}
|
}
|
||||||
return a.second < b.second;
|
return a.second < b.second;
|
||||||
});
|
});
|
||||||
for (int i = 0; i < k; ++i) {
|
for (size_t i = 0; i < k; ++i) {
|
||||||
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", i + 1, sorted[i].second,
|
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
|
||||||
TOKEN(sorted[i].second), sorted[i].first, logits[sorted[i].second]);
|
sorted[i].second, TOKEN(sorted[i].second), sorted[i].first,
|
||||||
|
logits[sorted[i].second]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
int verbosity, const gcpp::AcceptFunc& accept_token,
|
||||||
std::string& eot_line) {
|
std::string& eot_line) {
|
||||||
PROFILER_ZONE("Gen.misc");
|
PROFILER_ZONE("Gen.misc");
|
||||||
int abs_pos = 0; // absolute token index over all turns
|
size_t abs_pos = 0; // absolute token index over all turns
|
||||||
int current_pos = 0; // token index within the current turn
|
int current_pos = 0; // token index within the current turn
|
||||||
int prompt_size{};
|
int prompt_size{};
|
||||||
|
|
||||||
|
|
@ -181,7 +181,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
// For instruction-tuned models: add control tokens.
|
// For instruction-tuned models: add control tokens.
|
||||||
prompt_string = "<start_of_turn>user\n" + prompt_string +
|
prompt_string = "<start_of_turn>user\n" + prompt_string +
|
||||||
"<end_of_turn>\n<start_of_turn>model\n";
|
"<end_of_turn>\n<start_of_turn>model\n";
|
||||||
if (abs_pos > 0) {
|
if (abs_pos != 0) {
|
||||||
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
|
||||||
// continuation.
|
// continuation.
|
||||||
prompt_string = "<end_of_turn>\n" + prompt_string;
|
prompt_string = "<end_of_turn>\n" + prompt_string;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue