Added support for a mutable ModelConfig; run.cc can now apply a self-extend config at runtime

Nanubala Gnana Sai 2024-11-19 22:33:27 +05:30
parent 397952f918
commit 14d62b0098
6 changed files with 56 additions and 12 deletions


@@ -135,9 +135,16 @@ struct LayerConfig {
   size_t conv1d_width = 0;
   bool ff_biases = false;
   bool softmax_attn_output_biases = false;
+  /**
+   * Self-extend
+   * Jin, Hongye, et al. "LLM Maybe LongLM: Self-Extend LLM Context Window
+   * Without Tuning." arXiv preprint arXiv:2401.01325 (2024).
+   */
   bool self_extend = false;
-  size_t ngb_size = 0;
-  size_t grp_size = 1;
+  // Self-extend neighbor size
+  size_t se_neighbor_size = std::numeric_limits<size_t>::max();
+  // Self-extend group window size
+  size_t se_group_size = 1;
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
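
For background, self-extend (per the citation above) keeps exact relative positions inside a neighbor window and maps positions beyond it onto coarser groups, stretching the usable context window without retraining. A minimal standalone sketch of the key-side mapping these fields drive (function and parameter names here are illustrative; the real code uses hwy::Divisor, as the attention hunks below show):

    #include <cstddef>

    // Keys past the neighbor window collapse onto shared grouped positions;
    // nearer keys keep their exact positions. The query side adds a shift on
    // top of this (see the second attention hunk below).
    size_t KeySelfExtendPos(size_t pos, size_t se_neighbor_size,
                            size_t se_group_size) {
      if (pos <= se_neighbor_size) return pos;
      return pos / se_group_size;
    }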


@@ -302,7 +302,7 @@ class GemmaAttention {
     // Self-extension
     const hwy::Divisor div_grp_size(
-        static_cast<uint32_t>(layer_config_.grp_size));
+        static_cast<uint32_t>(layer_config_.se_group_size));
     // Apply positional encodings for K (and copy KV to cache if MHA).
     pool_.Run(0, kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
@@ -317,8 +317,8 @@ class GemmaAttention {
                       head * qkv_dim * 2;
                   KVCache& kv_cache = kv_caches_[query_idx];
-                  const size_t ngb_size = layer_config_.ngb_size;
-                  const bool self_extend = layer_config_.self_extend;
+                  const size_t se_neighbor_size = layer_config_.se_neighbor_size;
+                  const bool enable_self_extend = layer_config_.self_extend;
                   float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
                   const float* HWY_RESTRICT mha_kv =
@@ -327,7 +327,7 @@ class GemmaAttention {
                   // In self-extend, the position used to embed K becomes the
                   // grouped key position once it passes the neighbor window.
-                  if (self_extend && pos > ngb_size) {
+                  if (enable_self_extend && pos > se_neighbor_size) {
                     pos = div_grp_size.Divide(pos);
                   }
                   // Copy from `q` if MHA, or apply in-place.
@@ -417,18 +417,21 @@ class GemmaAttention {
                 const size_t head_offset =
                     (head / kHeadGroups) * layer_config_.qkv_dim * 2;
-                const size_t grp_size = layer_config_.grp_size;
-                const size_t ngb_size = layer_config_.ngb_size;
-                const bool self_extend = layer_config_.self_extend;
+                const size_t se_group_size = layer_config_.se_group_size;
+                const size_t se_neighbor_size = layer_config_.se_neighbor_size;
+                const bool enable_self_extend = layer_config_.self_extend;
                 KVCache& kv_cache = kv_caches_[query_idx];
                 float* HWY_RESTRICT q =
                     activations_.q.Batch(interleaved_idx) + head * q_stride_;
                 // Apply rope and scaling to Q.
                 size_t pos = queries_pos_[query_idx] + batch_idx;
-                if (self_extend && pos > ngb_size) {
-                  const size_t grp_pos = pos / grp_size;
-                  const size_t shift = ngb_size - ngb_size / grp_size;
+                if (enable_self_extend && pos > se_neighbor_size) {
+                  const size_t grp_pos = pos / se_group_size;
+                  const size_t shift =
+                      se_neighbor_size - se_neighbor_size / se_group_size;
                   const size_t shifted_grouped_pos = grp_pos + shift;
                   pos = shifted_grouped_pos;
                 }
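
To make the query-side arithmetic concrete, here is the same remapping as a standalone function with a worked example; the function name and sample values (neighbor window 512, group size 4) are illustrative only:

    #include <cstddef>
    #include <cstdio>

    // Shifted grouped position for queries: inside the neighbor window the
    // position is exact; beyond it, the grouped position is shifted so the
    // two regions remain contiguous.
    size_t QuerySelfExtendPos(size_t pos, size_t neighbor, size_t group) {
      if (pos <= neighbor) return pos;
      const size_t grp_pos = pos / group;                // 1000 / 4 = 250
      const size_t shift = neighbor - neighbor / group;  // 512 - 128 = 384
      return grp_pos + shift;                            // 250 + 384 = 634
    }

    int main() {
      std::printf("%zu\n", QuerySelfExtendPos(1000, 512, 4));  // 634
      std::printf("%zu\n", QuerySelfExtendPos(300, 512, 4));   // 300 (unchanged)
      return 0;
    }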


@@ -194,6 +194,7 @@ class Gemma {
   ~Gemma();
   const ModelConfig& GetModelConfig() const { return model_.Config(); }
+  ModelConfig& GetMutableModelConfig() { return model_.MutableConfig(); }
   const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ModelWeightsStorage& Weights() const { return model_; }


@@ -77,6 +77,26 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }
 
+// Extract args from the loader and modify the model config accordingly.
+void ApplySelfExtendIfGiven(Gemma& model, const LoaderArgs& loader) {
+  if (loader.self_extend != Tristate::kTrue) {
+    return;
+  }
+  ModelConfig& config = model.GetMutableModelConfig();
+  // Modify each layer config in-place.
+  for (LayerConfig& layer_config : config.layer_configs) {
+    layer_config.self_extend = true;
+    layer_config.se_group_size = loader.se_group_size;
+    layer_config.se_neighbor_size = loader.se_neighbor_size;
+  }
+}
+
 // The main Read-Eval-Print Loop.
 void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                const InferenceArgs& args, const AcceptFunc& accept_token,
@@ -206,6 +226,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   Allocator::Init(pools.Topology());
   Gemma model = CreateGemma(loader, pools);
+  ApplySelfExtendIfGiven(model, loader);
   KVCache kv_cache =
       KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);


@@ -550,6 +550,7 @@ class ModelWeightsStorage {
   void CopyWithTranspose(hwy::ThreadPool& pool);
   void LogWeightStats();
   const ModelConfig& Config() const { return config_; }
+  ModelConfig& MutableConfig() { return config_; }
 
   template <typename T>
   ModelWeightsPtrs<T>* GetWeightsOfType() const {


@@ -171,6 +171,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   std::string model_type_str;
   std::string weight_type_str;
+  // Self-extend
+  Tristate self_extend;
+  size_t se_group_size;
+  size_t se_neighbor_size;
 
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(tokenizer, "tokenizer", Path(),
@ -189,6 +194,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
" Required argument.");
visitor(self_extend, "self_extend", Tristate::kDefault,
"Apply self extend ? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(se_group_size, "se_group_size", size_t{1}, "Group size for self extend");
visitor(se_neighbor_size, "se_neighbor_size",
std::numeric_limits<size_t>::max(),
"Neighbor window size for self extend");
}
// Uninitialized before Validate, must call after that.
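
Assuming the standard gemma.cpp tokenizer/weights flags, enabling self-extend from the command line might look like this (paths, model name, and sizes are placeholders):

    ./gemma --tokenizer tokenizer.spm --weights weights.sfp --model <model> \
        --self_extend 1 --se_group_size 4 --se_neighbor_size 512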