mirror of https://github.com/google/gemma.cpp.git
[WIP] simplify hello world example, add convenience function. TODO: update git hash in CMakeLists.txt of hello world after push
This commit is contained in:
parent b67e28d1a0
commit 42e53e2da8
@@ -21,20 +21,27 @@ std::vector<int> tokenize(
int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  // A rough heuristic for a reasonable number of threads given hardware
  // concurrency estimate

  // A rough heuristic number of threads to use
  size_t num_threads = static_cast<size_t>(std::clamp(
      static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
  hwy::ThreadPool pool(num_threads);
  hwy::ThreadPool inner_pool(0);

  // Instantiate model
  gcpp::Gemma model(loader, pool);

  // Setup random number generator
  std::mt19937 gen;
  std::random_device rd;
  gen.seed(rd());

  // Tokenize instruction
  std::vector<int> tokens =
      tokenize("Write a greeting to the world.", model.Tokenizer());
  size_t ntokens = tokens.size();
  size_t pos = 0;

  // Callback
  auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
                          int token, float) {
    ++pos;
@@ -43,17 +50,16 @@ int main(int argc, char** argv) {
    } else if (token != gcpp::EOS_ID) {
      std::string token_text;
      HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
      if (pos == ntokens + 1) {
        // first token of response
        token_text.erase(0, token_text.find_first_not_of(" \t\n\n"));
      }
      std::cout << token_text << std::flush;
    }
    return true;
  };
  GenerateGemma(
      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
      /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
      [](int) { return true; }, gen, 0);

  GenerateGemma(model,
                {.max_tokens = 2048,
                 .max_generated_tokens = 1024,
                 .temperature = 1.0,
                 .verbosity = 0},
                tokens, /*KV cache position = */ 0, pool, stream_token, gen);
  std::cout << std::endl;
}
gemma.cc (11 changed lines)
@@ -844,5 +844,16 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}

void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                   const std::vector<int>& prompt, size_t start_pos,
                   KVCache& kv_cache, hwy::ThreadPool& pool,
                   const StreamFunc& stream_token, std::mt19937& gen) {
  hwy::ThreadPool inner_pool(0);
  GenerateGemma(
      gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
      runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
      stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
}

}  // namespace gcpp
#endif  // HWY_ONCE
gemma.h (22 changed lines)
@@ -64,17 +64,11 @@ struct KVCache {
enum class Model { GEMMA_2B, GEMMA_7B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };

// TODO: Incorporate this
struct Runtime {
  // TODO: In the future we may fold ModelTraining into Model.
  // As we add more variations of model_type, the cartesian set becomes
  // unwieldy.
  Model model_type;
  ModelTraining model_training;
struct RuntimeConfig {
  size_t max_tokens;
  size_t max_generated_tokens;
  float temperature;
  std::mt19937 gen;
  int verbosity;
};

struct LoaderArgs : public ArgsBase<LoaderArgs> {
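Since RuntimeConfig is a plain aggregate, callers can build it inline with designated initializers, as the updated hello-world example does; members that are left out (such as gen) are value-initialized. A minimal sketch, with illustrative values rather than library defaults:

// Illustrative only: constructs a RuntimeConfig the way the updated
// hello-world call site does. Unnamed members (e.g. `gen`) are
// value-initialized by aggregate initialization.
gcpp::RuntimeConfig config{.max_tokens = 2048,
                           .max_generated_tokens = 1024,
                           .temperature = 1.0,
                           .verbosity = 0};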
@@ -205,9 +199,6 @@ struct Gemma {
  gcpp::ModelTraining model_training;
};

struct LoaderArgs;  // forward declaration
void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model);

KVCache CreateKVCache(Model type);  // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
@@ -223,6 +214,15 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                   const AcceptFunc& accept_token, std::mt19937& gen,
                   int verbosity);

// Convenience function for the common case:
// - Bundle runtime parameters as RuntimeConfig
// - No threadpools within threadpools (inner_pool = dummy)
// - All tokens accepted
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                   const std::vector<int>& prompt, size_t start_pos,
                   KVCache& kv_cache, hwy::ThreadPool& pool,
                   const StreamFunc& stream_token, std::mt19937& gen);

constexpr int EOS_ID = 1;

}  // namespace gcpp
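The new overload forwards to the full GenerateGemma with a single-threaded dummy inner pool and an accept-all callback, so a caller only supplies the pieces that vary. A minimal sketch of the intended call pattern, matching the declaration above; the helper name and config values here are illustrative, and the model, pool, tokens, stream callback, and KV cache are assumed to be set up as in the hello-world example:

#include <random>
#include <vector>
#include "gemma.h"  // assumed include path
#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical helper: generate from an already-tokenized prompt using the
// convenience overload added in this commit.
void Generate(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
              hwy::ThreadPool& pool, const std::vector<int>& tokens,
              const gcpp::StreamFunc& stream_token, std::mt19937& gen) {
  gcpp::GenerateGemma(model,
                      {.max_tokens = 2048,
                       .max_generated_tokens = 1024,
                       .temperature = 1.0,
                       .verbosity = 0},
                      tokens, /*start_pos=*/0, kv_cache, pool, stream_token,
                      gen);
}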