[WIP] simplify hello world example, add convenience function. TODO: update git hash in CMakeLists.txt of hello world after push

This commit is contained in:
austinvhuang 2024-03-08 14:55:35 -05:00
parent b67e28d1a0
commit 42e53e2da8
3 changed files with 39 additions and 22 deletions

View File

@ -21,20 +21,27 @@ std::vector<int> tokenize(
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
// A rough heuristic for a reasonable number of threads given hardware
// concurrency estimate
// A rough heuristic number of threads to use
size_t num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
hwy::ThreadPool pool(num_threads);
hwy::ThreadPool inner_pool(0);
// Instantiate model
gcpp::Gemma model(loader, pool);
// Setup random number generator
std::mt19937 gen;
std::random_device rd;
gen.seed(rd());
// Tokenize instruction
std::vector<int> tokens =
tokenize("Write a greeting to the world.", model.Tokenizer());
size_t ntokens = tokens.size();
size_t pos = 0;
// Callback
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
int token, float) {
++pos;
@ -43,17 +50,16 @@ int main(int argc, char** argv) {
} else if (token != gcpp::EOS_ID) {
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
if (pos == ntokens + 1) {
// first token of response
token_text.erase(0, token_text.find_first_not_of(" \t\n\n"));
}
std::cout << token_text << std::flush;
}
return true;
};
GenerateGemma(
model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
/*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
[](int) { return true; }, gen, 0);
GenerateGemma(model,
{.max_tokens = 2048,
.max_generated_tokens = 1024,
.temperature = 1.0,
.verbosity = 0},
tokens, /*KV cache position = */ 0, pool, stream_token, gen);
std::cout << std::endl;
}

View File

@ -844,5 +844,16 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, std::mt19937& gen) {
hwy::ThreadPool inner_pool(0);
GenerateGemma(
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
}
} // namespace gcpp
#endif // HWY_ONCE

22
gemma.h
View File

@ -64,17 +64,11 @@ struct KVCache {
enum class Model { GEMMA_2B, GEMMA_7B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// TODO: Incorporate this
struct Runtime {
// TODO: In the future we may fold ModelTraining into Model.
// As we add more variations of model_type, the cartesian set becomes
// unwieldy.
Model model_type;
ModelTraining model_training;
struct RuntimeConfig {
size_t max_tokens;
size_t max_generated_tokens;
float temperature;
std::mt19937 gen;
int verbosity;
};
struct LoaderArgs : public ArgsBase<LoaderArgs> {
@ -205,9 +199,6 @@ struct Gemma {
gcpp::ModelTraining model_training;
};
struct LoaderArgs; // forward declaration
void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model);
KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
@ -223,6 +214,15 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity);
// Convenience function for the common case:
// - Bundle runtime parameters as RuntimeConfig
// - No threadpools within threadpools (inner_pool = dummy)
// - All tokens accepted
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, std::mt19937& gen);
constexpr int EOS_ID = 1;
} // namespace gcpp