// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

#include <stddef.h>
#include <stdio.h>

#include <functional>
#include <string>

#include "gemma/configs.h"
#include "io/io.h"  // Path
#include "util/args.h"
#include "util/basics.h"  // Tristate
#include "util/mat.h"
#include "hwy/aligned_allocator.h"  // Span
#include "hwy/base.h"  // HWY_ABORT
#include "hwy/profiler.h"

namespace gcpp {
struct LoaderArgs : public ArgsBase<LoaderArgs> {
  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  LoaderArgs(const std::string& tokenizer_path,
             const std::string& weights_path) {
    Init();  // Init sets the defaults, so assignments must come after Init().
    tokenizer.path = tokenizer_path;
    weights.path = weights_path;
  }

  Path tokenizer;
  Path weights;  // weights file location
  Tristate map;
  Tristate to_bf16;
  Tristate wrapping;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(tokenizer, "tokenizer", Path(),
            "Path name of tokenizer model; only required for pre-2025 format.");
    visitor(weights, "weights", Path(),
            "Path name of model weights (.sbs) file.\n    Required argument.\n");
    visitor(map, "map", Tristate::kDefault,
            "Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
    visitor(to_bf16, "to_bf16", Tristate::kDefault,
            "Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
    visitor(wrapping, "wrapping", Tristate::kDefault,
            "Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
  }
};
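// Example (sketch): constructing LoaderArgs programmatically when embedding
// gemma.cpp rather than parsing argv. The file names are hypothetical
// placeholders.
//
//   LoaderArgs loader(/*tokenizer_path=*/"tokenizer.spm",
//                     /*weights_path=*/"model.sbs");
//   // map/to_bf16/wrapping remain Tristate::kDefault unless assigned.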
using PromptTokens = hwy::Span<const int>;

// Batches of independent queries have their own prompt, previous token,
// position in the sequence, and KVCache.
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
using QueriesToken = hwy::Span<const int>;
using QueriesPos = hwy::Span<const size_t>;

// ImageTokens are represented as a matrix, where each row corresponds
// to a token for an image patch as computed by the image encoder.
using ImageTokens = MatStorageT<float>;
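// Example (sketch): wrapping per-query token vectors in the Span types above.
// The token IDs are hypothetical placeholders; requires <vector>.
//
//   std::vector<int> q0 = {2, 651, 3456};  // tokens for query 0
//   std::vector<int> q1 = {2, 9731};       // tokens for query 1
//   const PromptTokens prompts[2] = {PromptTokens(q0.data(), q0.size()),
//                                    PromptTokens(q1.data(), q1.size())};
//   const QueriesPromptTokens queries_prompt(prompts, 2);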
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// BatchStreamFunc is called with (query_idx, pos, token, probability).
// For prompt tokens, probability is 0.0f. Generation continues if this returns
// true and stops if it returns false. Note that query_idx is absolute, not
// relative to the batch.
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with (token, probability). It should
// return false for tokens you don't want to generate and true for tokens you
// want to generate.
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called concurrently from worker thread(s) with
// query_idx, pos, logits for the next token (which it may modify/overwrite),
// and worker. It returns the next generated token and its probability.
using SampleFunc = std::function<TokenAndProb(size_t, size_t, Logits, size_t)>;
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
// - index of query within containing batch (if any); zero otherwise.
// - position in the tokens sequence
// - name of the data, e.g. "tokens" for token IDs
// - layer index (or -1 for global outputs)
// - pointer to the data array
// - size of the data array
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
                                            int, const float*, size_t)>;
// If not empty, ActivationsObserverFunc is invoked after each layer with:
// - per-query position within the tokens sequence
// - layer index (or -1 for post-norm output)
// - activations
struct Activations;
using ActivationsObserverFunc =
    std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
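// Example (sketch): a StreamFunc that logs each token and never stops
// generation early. Mapping token IDs back to text requires the model's
// tokenizer, which this header does not expose.
//
//   StreamFunc stream = [](int token, float prob) {
//     fprintf(stderr, "token=%d p=%.3f\n", token, prob);
//     return true;  // false would stop generation
//   };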
// RuntimeConfig holds configuration for a single generation run.
// TODO: move into InferenceArgs, use that directly.
struct RuntimeConfig {
  // If non-null, `batch_stream_token` is called for each token in the batch,
  // otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
  // This is called sequentially from the main thread.
  bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
    PROFILER_ZONE("Gen.StreamToken");
    if (batch_stream_token) {
      return batch_stream_token(query_idx, pos, token, prob);
    }
    return stream_token(token, prob);
  }

  // Limit on the number of tokens generated.
  size_t max_generated_tokens;

  // These defaults are overridden by InferenceArgs::CopyTo(*this):
  // Max tokens per batch during prefill.
  size_t prefill_tbatch_size = 256;
  // Max queries per batch (one token from each) during decode.
  size_t decode_qbatch_size = 16;

  // Sampling-related parameters.
  float temperature;  // Temperature for sampling.

  size_t top_k = 1;  // Top-k for sampling.

  int verbosity;  // Controls verbosity of printed messages.

  // Which attention implementation to use.
  AttentionImpl attention_impl = AttentionImpl::kFlash;

  // Functions operating on the generated tokens.
  StreamFunc stream_token;
  BatchStreamFunc batch_stream_token;
  AcceptFunc accept_token;  // if empty, accepts all tokens.
  SampleFunc sample_func;   // if empty, uses SampleTopK.

  // Observer callbacks for intermediate data.
  LayersOutputFunc layers_output;  // if not empty, called after each layer.
  ActivationsObserverFunc activations_observer;  // if set, called per-layer.

  // If non-null, this points to the image tokens, which are used in the
  // PaliGemma prefix-LM style attention.
  const ImageTokens* image_tokens = nullptr;

  // Whether to use thread spinning to reduce barrier synchronization latency.
  // Mutable so we can change kDefault to kTrue/kFalse during Generate, because
  // RuntimeConfig is const there and is not passed to the Gemma ctor. This
  // default decision is likely sufficient because it is based on whether
  // threads are successfully pinned.
  mutable Tristate use_spinning = Tristate::kDefault;

  // Whether to use continuous batching.
  bool use_continuous_batching = false;
};
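// Example (sketch): a minimal RuntimeConfig. Fields without in-class defaults
// (max_generated_tokens, temperature, verbosity) must be set explicitly,
// either by hand as below or via InferenceArgs::CopyTo.
//
//   RuntimeConfig runtime;
//   runtime.max_generated_tokens = 128;
//   runtime.temperature = 1.0f;  // with the default top_k = 1, effectively greedy
//   runtime.verbosity = 0;
//   runtime.stream_token = stream;  // e.g. the StreamFunc sketch above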
struct InferenceArgs : public ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  InferenceArgs() { Init(); }

  bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }

  int verbosity;

  size_t seq_len;
  size_t max_generated_tokens;

  size_t prefill_tbatch_size;
  size_t decode_qbatch_size;

  float temperature;
  size_t top_k;
  bool deterministic;
  bool multiturn;
  Path image_file;

  int port;            // Server port
  std::string model;   // Model name for API endpoints
  std::string prompt;  // Bypasses std::getline
  // For prompts longer than the Linux terminal's 4K line edit buffer.
  Path prompt_file;
  std::string eot_line;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(verbosity, "verbosity", 1,
            "Show verbose developer information\n    0 = only print generation "
            "output\n    1 = standard user-facing terminal UI\n    2 = show "
            "developer/debug info.\n    Default = 1.",
            1);

    visitor(seq_len, "seq_len", size_t{8192},
            "Sequence length, capped by ModelConfig.max_seq_len.");
    visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
            "Maximum number of tokens to generate.");

    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
            "Prefill: max tokens per batch.");
    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
            "Decode: max queries per batch.");

    visitor(temperature, "temperature", 1.0f, "Temperature for top-K.", 2);
    visitor(top_k, "top_k", size_t{1},
            "Number of top-K tokens to sample from.", 2);
    visitor(deterministic, "deterministic", false,
            "Make top-K sampling deterministic.", 2);
    visitor(multiturn, "multiturn", false,
            "Multiturn mode\n    0 = clear KV cache after every "
            "interaction\n    1 = continue KV cache after every interaction\n"
            "    Default = 0 (conversation resets every turn).");
    visitor(image_file, "image_file", Path(), "Image file to load.");

    // Since these are not used in the CLI version, their print_verbosity is
    // set higher than the others.
    visitor(port, "port", 8080, "Server port (default: 8080)", 3);
    visitor(model, "model", std::string("gemma3-4b"),
            "Model name for API endpoints (default: gemma3-4b)", 3);

    visitor(prompt, "prompt", std::string(""),
            "Initial prompt for non-interactive mode. When specified, "
            "generates a response and exits.",
            1);
    visitor(prompt_file, "prompt_file", Path(),
            "Path to file containing the prompt for non-interactive mode. "
            "When specified, generates a response and exits.",
            1);

    visitor(
        eot_line, "eot_line", std::string(""),
        "End of turn line. "
        "When you specify this, the prompt will be all lines "
        "before the line where only the given string appears.\n    Default = "
        "when a newline is encountered, that signals the end of the turn.",
        2);
  }

  void CopyTo(RuntimeConfig& runtime_config) const {
    runtime_config.max_generated_tokens = max_generated_tokens;
    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
    runtime_config.decode_qbatch_size = decode_qbatch_size;
    if (prefill_tbatch_size > kMaxBatchSize) {
      HWY_ABORT(
          "prefill_tbatch_size %zu > kMaxBatchSize %zu: specify a "
          "smaller value, or increase kMaxBatchSize.\n",
          prefill_tbatch_size, kMaxBatchSize);
    }
    if (decode_qbatch_size > kMaxBatchSize) {
      HWY_ABORT(
          "decode_qbatch_size %zu > kMaxBatchSize %zu: specify a "
          "smaller value, or increase kMaxBatchSize.\n",
          decode_qbatch_size, kMaxBatchSize);
    }

    runtime_config.temperature = temperature;
    runtime_config.top_k = top_k;
  }
};
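// Example (sketch): typical frontend flow from parsed flags to a run
// configuration. CopyTo aborts if either batch size exceeds kMaxBatchSize.
//
//   InferenceArgs inference(argc, argv);  // parses --seq_len, --top_k, ...
//   RuntimeConfig runtime;
//   runtime.verbosity = inference.verbosity;  // CopyTo does not copy this
//   inference.CopyTo(runtime);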
struct ClientArgs : public ArgsBase<ClientArgs> {
  ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  ClientArgs() { Init(); }

  std::string host;
  int port;
  std::string api_key;
  std::string model;
  std::string prompt;
  bool interactive;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(host, "host", std::string("localhost"),
            "Server host (default: localhost)");
    visitor(port, "port", 8080, "Server port (default: 8080)");
    visitor(api_key, "api_key", std::string(""),
            "Use public API with key (changes host to "
            "generativelanguage.googleapis.com:443)");
    visitor(model, "model", std::string("gemma3-4b"),
            "Model name to use (default: gemma3-4b)");
    visitor(prompt, "prompt", std::string("Hello! How are you?"),
            "Prompt for generation (default: 'Hello! How are you?')");
    visitor(interactive, "interactive", false,
            "Start interactive chat mode (0 = no, 1 = yes)");
  }
};
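// Example (sketch): flags accepted by a client binary built on ClientArgs.
// The binary name is hypothetical, and the flag syntax assumes the
// "--name value" convention of util/args.h.
//
//   ./gemma_client --host localhost --port 8080 --model gemma3-4b \
//     --prompt "Hello! How are you?" --interactive 1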
}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_