mirror of https://github.com/google/gemma.cpp.git
Merge 51a708e957 into 5bc356f18f
This commit is contained in:
commit
29e3a1bba9
|
|
@ -155,6 +155,15 @@ struct LayerConfig {
|
||||||
size_t conv1d_width = 0; // griffin only
|
size_t conv1d_width = 0; // griffin only
|
||||||
bool ff_biases = false;
|
bool ff_biases = false;
|
||||||
bool softmax_attn_output_biases = false;
|
bool softmax_attn_output_biases = false;
|
||||||
|
/**
|
||||||
|
* Self-extend
|
||||||
|
* Jin, Hongye, et al. "Llm maybe longlm: Self-extend llm context window without tuning." arXiv preprint arXiv:2401.01325 (2024).
|
||||||
|
*/
|
||||||
|
bool self_extend = false;
|
||||||
|
// Self-extend neighbor size
|
||||||
|
size_t se_neighbor_size = std::numeric_limits<size_t>::max();
|
||||||
|
// Self-extend group window size
|
||||||
|
size_t se_group_size = 1;
|
||||||
bool optimized_gating = true;
|
bool optimized_gating = true;
|
||||||
PostNormType post_norm = PostNormType::None;
|
PostNormType post_norm = PostNormType::None;
|
||||||
LayerAttentionType type = LayerAttentionType::kGemma;
|
LayerAttentionType type = LayerAttentionType::kGemma;
|
||||||
|
|
|
||||||
|
|
@ -300,6 +300,9 @@ class GemmaAttention {
|
||||||
}
|
}
|
||||||
} // !is_mha_
|
} // !is_mha_
|
||||||
|
|
||||||
|
// Self-extension
|
||||||
|
const hwy::Divisor div_grp_size(
|
||||||
|
static_cast<uint32_t>(layer_config_.se_group_size));
|
||||||
// Apply positional encodings for K (and copy KV to cache if MHA).
|
// Apply positional encodings for K (and copy KV to cache if MHA).
|
||||||
pool_.Run(0, kv_heads * num_interleaved,
|
pool_.Run(0, kv_heads * num_interleaved,
|
||||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
|
|
@ -307,21 +310,29 @@ class GemmaAttention {
|
||||||
const size_t interleaved_idx = task / kv_heads;
|
const size_t interleaved_idx = task / kv_heads;
|
||||||
const size_t query_idx = interleaved_idx % num_queries_;
|
const size_t query_idx = interleaved_idx % num_queries_;
|
||||||
const size_t batch_idx = interleaved_idx / num_queries_;
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||||
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
||||||
const size_t kv_offset = cache_pos * cache_pos_size_ +
|
const size_t kv_offset = cache_pos * cache_pos_size_ +
|
||||||
layer_ * cache_layer_size_ +
|
layer_ * cache_layer_size_ +
|
||||||
head * qkv_dim * 2;
|
head * qkv_dim * 2;
|
||||||
KVCache& kv_cache = kv_caches_[query_idx];
|
KVCache& kv_cache = kv_caches_[query_idx];
|
||||||
|
|
||||||
|
const size_t se_neighbor_size = layer_config_.se_neighbor_size;
|
||||||
|
const bool enable_self_extend = layer_config_.self_extend;
|
||||||
|
|
||||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||||
const float* HWY_RESTRICT mha_kv =
|
const float* HWY_RESTRICT mha_kv =
|
||||||
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
||||||
qkv_dim;
|
qkv_dim;
|
||||||
|
|
||||||
|
// In self-extend, when embedding position,
|
||||||
|
// we will use grouped key position
|
||||||
|
if (enable_self_extend && pos > se_neighbor_size) {
|
||||||
|
pos = div_grp_size.Divide(pos);
|
||||||
|
}
|
||||||
// Copy from `q` if MHA, or apply in-place.
|
// Copy from `q` if MHA, or apply in-place.
|
||||||
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
|
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
|
||||||
kv);
|
kv);
|
||||||
|
|
||||||
// If MHA, also copy V into KVCache.
|
// If MHA, also copy V into KVCache.
|
||||||
if (is_mha_) {
|
if (is_mha_) {
|
||||||
hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
|
hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
|
||||||
|
|
@ -405,12 +416,25 @@ class GemmaAttention {
|
||||||
const size_t batch_idx = interleaved_idx / num_queries_;
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||||
const size_t head_offset =
|
const size_t head_offset =
|
||||||
(head / kHeadGroups) * layer_config_.qkv_dim * 2;
|
(head / kHeadGroups) * layer_config_.qkv_dim * 2;
|
||||||
|
|
||||||
|
const size_t se_group_size = layer_config_.se_group_size;
|
||||||
|
const size_t se_neighbor_size = layer_config_.se_neighbor_size;
|
||||||
|
const bool enable_self_extend =
|
||||||
|
layer_config_.self_extend;
|
||||||
|
|
||||||
KVCache& kv_cache = kv_caches_[query_idx];
|
KVCache& kv_cache = kv_caches_[query_idx];
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations_.q.Batch(interleaved_idx) + head * q_stride_;
|
activations_.q.Batch(interleaved_idx) + head * q_stride_;
|
||||||
|
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||||
|
if (enable_self_extend && pos > se_neighbor_size) {
|
||||||
|
const size_t grp_pos = pos / se_group_size;
|
||||||
|
const size_t shift =
|
||||||
|
se_neighbor_size - se_neighbor_size / se_group_size;
|
||||||
|
const size_t shifted_grouped_pos = grp_pos + shift;
|
||||||
|
pos = shifted_grouped_pos;
|
||||||
|
}
|
||||||
PositionalEncodingQK(q, pos, layer_, query_scale, q);
|
PositionalEncodingQK(q, pos, layer_, query_scale, q);
|
||||||
|
|
||||||
const size_t start_pos = StartPos(pos, layer_);
|
const size_t start_pos = StartPos(pos, layer_);
|
||||||
|
|
@ -1401,7 +1425,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
|
||||||
qbatch_size);
|
qbatch_size);
|
||||||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
||||||
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
||||||
qbatch_size);
|
qbatch_size);
|
||||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
||||||
GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
|
GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
|
||||||
qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);
|
qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,7 @@ class Gemma {
|
||||||
~Gemma();
|
~Gemma();
|
||||||
|
|
||||||
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
||||||
|
ModelConfig& GetMutableModelConfig() { return model_.MutableConfig(); }
|
||||||
const ModelInfo& Info() const { return info_; }
|
const ModelInfo& Info() const { return info_; }
|
||||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||||
const ModelWeightsStorage& Weights() const { return model_; }
|
const ModelWeightsStorage& Weights() const { return model_; }
|
||||||
|
|
|
||||||
21
gemma/run.cc
21
gemma/run.cc
|
|
@ -77,6 +77,26 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
||||||
return prompt_string;
|
return prompt_string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract args from the loader and modify model config
|
||||||
|
void ApplySelfExtendIfGiven(Gemma& model, LoaderArgs loader) {
|
||||||
|
ModelConfig& config = model.GetMutableModelConfig();
|
||||||
|
if (loader.self_extend != Tristate::kTrue) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify layer config in-place
|
||||||
|
auto& layer_configs = config.layer_configs;
|
||||||
|
std::transform(layer_configs.begin(), layer_configs.end(), layer_configs.begin(),
|
||||||
|
[&loader](LayerConfig& layer_config) {
|
||||||
|
layer_config.self_extend =
|
||||||
|
loader.self_extend == Tristate::kTrue;
|
||||||
|
layer_config.se_group_size = loader.se_group_size;
|
||||||
|
layer_config.se_neighbor_size = loader.se_neighbor_size;
|
||||||
|
|
||||||
|
return layer_config;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// The main Read-Eval-Print Loop.
|
// The main Read-Eval-Print Loop.
|
||||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||||
const InferenceArgs& args, const AcceptFunc& accept_token,
|
const InferenceArgs& args, const AcceptFunc& accept_token,
|
||||||
|
|
@ -243,6 +263,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
Allocator::Init(pools.Topology());
|
Allocator::Init(pools.Topology());
|
||||||
|
|
||||||
Gemma model = CreateGemma(loader, pools);
|
Gemma model = CreateGemma(loader, pools);
|
||||||
|
ApplySelfExtendIfGiven(model, loader);
|
||||||
KVCache kv_cache =
|
KVCache kv_cache =
|
||||||
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
|
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -536,6 +536,7 @@ class ModelWeightsStorage {
|
||||||
void CopyWithTranspose(hwy::ThreadPool& pool);
|
void CopyWithTranspose(hwy::ThreadPool& pool);
|
||||||
void LogWeightStats();
|
void LogWeightStats();
|
||||||
const ModelConfig& Config() const { return config_; }
|
const ModelConfig& Config() const { return config_; }
|
||||||
|
ModelConfig& MutableConfig() { return config_; }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ModelWeightsPtrs<T>* GetWeightsOfType() const {
|
ModelWeightsPtrs<T>* GetWeightsOfType() const {
|
||||||
|
|
|
||||||
11
util/app.h
11
util/app.h
|
|
@ -173,6 +173,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
std::string model_type_str;
|
std::string model_type_str;
|
||||||
std::string weight_type_str;
|
std::string weight_type_str;
|
||||||
|
|
||||||
|
// Self-extend
|
||||||
|
Tristate self_extend;
|
||||||
|
size_t se_group_size;
|
||||||
|
size_t se_neighbor_size;
|
||||||
|
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
visitor(tokenizer, "tokenizer", Path(),
|
visitor(tokenizer, "tokenizer", Path(),
|
||||||
|
|
@ -191,6 +196,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
|
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
|
||||||
" Required argument.");
|
" Required argument.");
|
||||||
|
visitor(self_extend, "self_extend", Tristate::kDefault,
|
||||||
|
"Apply self extend ? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||||
|
visitor(se_group_size, "se_group_size", size_t{1}, "Group size for self extend");
|
||||||
|
visitor(se_neighbor_size, "se_neighbor_size",
|
||||||
|
std::numeric_limits<size_t>::max(),
|
||||||
|
"Neighbor window size for self extend");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uninitialized before Validate, must call after that.
|
// Uninitialized before Validate, must call after that.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue