mirror of https://github.com/google/gemma.cpp.git
[WIP] Simplify the hello world example and add a convenience function. TODO: update the git hash in the hello world CMakeLists.txt after pushing.
This commit is contained in:
parent
b67e28d1a0
commit
42e53e2da8
|
|
@ -21,20 +21,27 @@ std::vector<int> tokenize(
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
gcpp::LoaderArgs loader(argc, argv);
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
// A rough heuristic for a reasonable number of threads given hardware
|
|
||||||
// concurrency estimate
|
// A rough heuristic for the number of threads to use
|
||||||
size_t num_threads = static_cast<size_t>(std::clamp(
|
size_t num_threads = static_cast<size_t>(std::clamp(
|
||||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||||
hwy::ThreadPool pool(num_threads);
|
hwy::ThreadPool pool(num_threads);
|
||||||
hwy::ThreadPool inner_pool(0);
|
|
||||||
|
// Instantiate model
|
||||||
gcpp::Gemma model(loader, pool);
|
gcpp::Gemma model(loader, pool);
|
||||||
|
|
||||||
|
// Setup random number generator
|
||||||
std::mt19937 gen;
|
std::mt19937 gen;
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
gen.seed(rd());
|
gen.seed(rd());
|
||||||
|
|
||||||
|
// Tokenize instruction
|
||||||
std::vector<int> tokens =
|
std::vector<int> tokens =
|
||||||
tokenize("Write a greeting to the world.", model.Tokenizer());
|
tokenize("Write a greeting to the world.", model.Tokenizer());
|
||||||
size_t ntokens = tokens.size();
|
size_t ntokens = tokens.size();
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
|
|
||||||
|
// Callback
|
||||||
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
|
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
|
||||||
int token, float) {
|
int token, float) {
|
||||||
++pos;
|
++pos;
|
||||||
|
|
@ -43,17 +50,16 @@ int main(int argc, char** argv) {
|
||||||
} else if (token != gcpp::EOS_ID) {
|
} else if (token != gcpp::EOS_ID) {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
|
||||||
if (pos == ntokens + 1) {
|
|
||||||
// first token of response
|
|
||||||
token_text.erase(0, token_text.find_first_not_of(" \t\n\n"));
|
|
||||||
}
|
|
||||||
std::cout << token_text << std::flush;
|
std::cout << token_text << std::flush;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
GenerateGemma(
|
|
||||||
model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
|
GenerateGemma(model,
|
||||||
/*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
|
{.max_tokens = 2048,
|
||||||
[](int) { return true; }, gen, 0);
|
.max_generated_tokens = 1024,
|
||||||
|
.temperature = 1.0,
|
||||||
|
.verbosity = 0},
|
||||||
|
tokens, /*KV cache position = */ 0, pool, stream_token, gen);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
11
gemma.cc
11
gemma.cc
|
|
@ -844,5 +844,16 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||||
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
|
const StreamFunc& stream_token, std::mt19937& gen) {
|
||||||
|
hwy::ThreadPool inner_pool(0);
|
||||||
|
GenerateGemma(
|
||||||
|
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
|
||||||
|
runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
|
||||||
|
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
22
gemma.h
22
gemma.h
|
|
@ -64,17 +64,11 @@ struct KVCache {
|
||||||
enum class Model { GEMMA_2B, GEMMA_7B };
|
enum class Model { GEMMA_2B, GEMMA_7B };
|
||||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||||
|
|
||||||
// TODO: Incorporate this
|
struct RuntimeConfig {
|
||||||
struct Runtime {
|
|
||||||
// TODO: In the future we may fold ModelTraining into Model.
|
|
||||||
// As we add more variations of model_type, the cartesian set becomes
|
|
||||||
// unwieldy.
|
|
||||||
Model model_type;
|
|
||||||
ModelTraining model_training;
|
|
||||||
size_t max_tokens;
|
size_t max_tokens;
|
||||||
size_t max_generated_tokens;
|
size_t max_generated_tokens;
|
||||||
float temperature;
|
float temperature;
|
||||||
std::mt19937 gen;
|
int verbosity;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
|
|
@ -205,9 +199,6 @@ struct Gemma {
|
||||||
gcpp::ModelTraining model_training;
|
gcpp::ModelTraining model_training;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LoaderArgs; // forward declaration
|
|
||||||
void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model);
|
|
||||||
|
|
||||||
KVCache CreateKVCache(Model type); // convenient workaround for now
|
KVCache CreateKVCache(Model type); // convenient workaround for now
|
||||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
|
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
|
||||||
|
|
||||||
|
|
@ -223,6 +214,15 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
int verbosity);
|
int verbosity);
|
||||||
|
|
||||||
|
// Convenience function for the common case:
|
||||||
|
// - Bundle runtime parameters as RuntimeConfig
|
||||||
|
// - No threadpools within threadpools (inner_pool = dummy)
|
||||||
|
// - All tokens accepted
|
||||||
|
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||||
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
|
const StreamFunc& stream_token, std::mt19937& gen);
|
||||||
|
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue