diff --git a/gemma/configs.cc b/gemma/configs.cc
index 613b31a..c01ec88 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -572,7 +572,8 @@ ModelConfig::ModelConfig(const Model model, Type weight,
 static Model FindModel(const std::string& specifier) {
   Model found_model = Model::UNKNOWN;
   ForEachModel([&](Model model) {
-    const char* prefix = ModelPrefix(model);
+    // Some model names are prefixes of other model names
+    const std::string prefix = std::string(ModelPrefix(model)) + "-";
     if (specifier.rfind(prefix, 0) == 0) {  // Starts with prefix.
       // We only expect one match.
       HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index dae2294..bd80e17 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -176,7 +176,7 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
     Sigmoid(gate_x + head_offset, kHeadDim);
     Sigmoid(a + head_offset, kHeadDim);
     const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
-        HWY_ATTR { return hn::Mul(x, gate_x); };
+                            HWY_ATTR { return hn::Mul(x, gate_x); };
     hn::Transform1(D(), a + head_offset, kHeadDim,
                    layer_weights->griffin.a.PackedScale1() + head_offset,
                    fn_mul);
@@ -424,51 +424,49 @@ class GemmaAttention {
 
     const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads;
     // For each head (token, query), compute Q.K, softmax, and weighted V.
-    pool_.Run(0, layer_config_.heads * num_interleaved,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-                const size_t head = task % layer_config_.heads;
-                const size_t interleaved_idx = task / layer_config_.heads;
-                const size_t query_idx = interleaved_idx % num_queries_;
-                const size_t batch_idx = interleaved_idx / num_queries_;
-                const size_t qkv_dim = layer_config_.qkv_dim;
-                const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
+    pool_.Run(
+        0, layer_config_.heads * num_interleaved,
+        [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+          const size_t head = task % layer_config_.heads;
+          const size_t interleaved_idx = task / layer_config_.heads;
+          const size_t query_idx = interleaved_idx % num_queries_;
+          const size_t batch_idx = interleaved_idx / num_queries_;
+          const size_t qkv_dim = layer_config_.qkv_dim;
+          const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
 
-                float* HWY_RESTRICT q =
-                    activations_.q.Row(interleaved_idx) + head * q_stride_;
-                float* HWY_RESTRICT att =
-                    activations_.att.Row(interleaved_idx) +
-                    head * activations_.seq_len;
-                float* HWY_RESTRICT att_out =
-                    activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
+          float* HWY_RESTRICT q =
+              activations_.q.Row(interleaved_idx) + head * q_stride_;
+          float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
+                                    head * activations_.seq_len;
+          float* HWY_RESTRICT att_out =
+              activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
 
-                // Make strided views into the kv cache entries for the current
-                // query and head.
-                KVCache& kv_cache = kv_caches_[query_idx];
-                const size_t kv_head_offset =
-                    layer_ * cache_layer_size_ + head_offset;
-                MatPtrT k("k_view",
-                          Extents2D(kv_cache.seq_len, qkv_dim));
-                k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
-                         /*stride=*/cache_pos_size_);
-                MatPtrT v("v_view",
-                          Extents2D(kv_cache.seq_len, qkv_dim));
-                v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-                         /*stride=*/cache_pos_size_);
+          // Make strided views into the kv cache entries for the current
+          // query and head.
+          KVCache& kv_cache = kv_caches_[query_idx];
+          const size_t kv_head_offset =
+              layer_ * cache_layer_size_ + head_offset;
+          MatPtrT k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
+          k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
+                   /*stride=*/cache_pos_size_);
+          MatPtrT v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
+          v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+                   /*stride=*/cache_pos_size_);
 
-                // Find the token position in the query and calculate the range
-                // of cache positions to attend to.
-                const size_t pos = queries_pos_[query_idx] + batch_idx;
-                const size_t start_pos = StartPos(pos, layer_);
-                size_t last_pos = pos;
-                const size_t prefix_end = queries_prefix_end_[query_idx];
-                if (prefix_end > 0 && prefix_end - 1 > last_pos) {
-                  // last_pos in QDotK and WeightedSumV is inclusive.
-                  last_pos = prefix_end - 1;
-                }
+          // Find the token position in the query and calculate the range
+          // of cache positions to attend to.
+          const size_t pos = queries_pos_[query_idx] + batch_idx;
+          const size_t start_pos = StartPos(pos, layer_);
+          size_t last_pos = pos;
+          const size_t prefix_end = queries_prefix_end_[query_idx];
+          if (prefix_end > 0 && prefix_end - 1 > last_pos) {
+            // last_pos in QDotK and WeightedSumV is inclusive.
+            last_pos = prefix_end - 1;
+          }
 
-                SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
-                                            pos, start_pos, last_pos);
-              });
+          SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos,
+                                      start_pos, last_pos);
+        });
   }
 
  private:
@@ -1510,8 +1508,7 @@ void GenerateSingleT(const ModelStore& model,
 }
 
 template
-void GenerateBatchT(const ModelStore& model,
-                    const ModelWeightsPtrs& weights,
+void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs& weights,
                     const RuntimeConfig& runtime_config,
                     const QueriesPromptTokens& queries_prompt,
                     const QueriesPos& queries_pos,
@@ -1536,7 +1533,7 @@ void GenerateBatchT(const ModelStore& model,
                                           qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                      qbatch_size);
+                                       qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT(model, weights, activations, runtime_config, qbatch_prompts,
               qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,