mirror of https://github.com/google/gemma.cpp.git
parent 1e8642f8f4
commit 9c3e089b09
@@ -572,7 +572,8 @@ ModelConfig::ModelConfig(const Model model, Type weight,
 static Model FindModel(const std::string& specifier) {
   Model found_model = Model::UNKNOWN;
   ForEachModel([&](Model model) {
-    const char* prefix = ModelPrefix(model);
+    // Some model names are prefixes of other model names
+    const std::string prefix = std::string(ModelPrefix(model)) + "-";
     if (specifier.rfind(prefix, 0) == 0) {  // Starts with prefix.
       // We only expect one match.
       HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());
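The behavioral change above is the trailing "-" appended to each prefix before the starts-with test. A minimal standalone sketch of why that disambiguates overlapping prefixes; the prefix strings and specifier below are hypothetical stand-ins, not the actual set returned by ForEachModel()/ModelPrefix():

#include <cassert>
#include <string>
#include <vector>

// Hypothetical prefixes standing in for ModelPrefix(); the real list comes
// from ForEachModel() in gemma.cpp.
static const std::vector<std::string> kPrefixes = {"gemma", "gemma2"};

// Counts how many prefixes the specifier starts with, optionally appending
// the '-' separator as the patched FindModel() does.
static int CountMatches(const std::string& specifier, bool append_dash) {
  int matches = 0;
  for (const std::string& p : kPrefixes) {
    const std::string prefix = append_dash ? p + "-" : p;
    if (specifier.rfind(prefix, 0) == 0) ++matches;  // Starts with prefix.
  }
  return matches;
}

int main() {
  // "gemma" is itself a prefix of "gemma2", so a bare starts-with test is
  // ambiguous; requiring the '-' separator makes the match unique.
  assert(CountMatches("gemma2-2b-it", /*append_dash=*/false) == 2);
  assert(CountMatches("gemma2-2b-it", /*append_dash=*/true) == 1);
  return 0;
}

Requiring the separator means a specifier only matches the prefix it names exactly, so the single-match assumption checked by HWY_ASSERT_M holds.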
@@ -424,7 +424,8 @@ class GemmaAttention {
   const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads;

   // For each head (token, query), compute Q.K, softmax, and weighted V.
-  pool_.Run(0, layer_config_.heads * num_interleaved,
+  pool_.Run(
+      0, layer_config_.heads * num_interleaved,
       [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % layer_config_.heads;
         const size_t interleaved_idx = task / layer_config_.heads;
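As an aside, the flattened task index handed out by pool_.Run() encodes a (head, interleaved query/token) pair; a tiny sketch of that decomposition with made-up sizes, independent of the real ThreadPool API:

#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
  // Hypothetical sizes; in GemmaAttention these are layer_config_.heads and
  // num_interleaved (queries interleaved with tokens in the batch).
  const size_t heads = 8;
  const size_t num_interleaved = 3;
  // pool_.Run(0, heads * num_interleaved, ...) gives each task one
  // (head, interleaved_idx) pair, recovered by mod/div as in the lambda.
  for (uint64_t task = 0; task < heads * num_interleaved; ++task) {
    const size_t head = task % heads;
    const size_t interleaved_idx = task / heads;
    assert(head < heads && interleaved_idx < num_interleaved);
  }
  return 0;
}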
@@ -435,8 +436,7 @@ class GemmaAttention {

         float* HWY_RESTRICT q =
             activations_.q.Row(interleaved_idx) + head * q_stride_;
-        float* HWY_RESTRICT att =
-            activations_.att.Row(interleaved_idx) +
+        float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
             head * activations_.seq_len;
         float* HWY_RESTRICT att_out =
             activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
@@ -446,12 +446,10 @@ class GemmaAttention {
         KVCache& kv_cache = kv_caches_[query_idx];
         const size_t kv_head_offset =
             layer_ * cache_layer_size_ + head_offset;
-        MatPtrT<float> k("k_view",
-                         Extents2D(kv_cache.seq_len, qkv_dim));
+        MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
         k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
                  /*stride=*/cache_pos_size_);
-        MatPtrT<float> v("v_view",
-                         Extents2D(kv_cache.seq_len, qkv_dim));
+        MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
         v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
                  /*stride=*/cache_pos_size_);

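For orientation, k_view and v_view are strided views into one interleaved KV cache buffer: rows are cache positions (stride cache_pos_size_), V follows K for a given head, and kv_head_offset selects the (layer, head) slot. A minimal sketch of that addressing, using plain floats and assumed layout constants (cache_layer_size_ = kv_heads * 2 * qkv_dim is an assumption here, inferred from how the views are built, not stated in the diff):

#include <cstddef>
#include <vector>

int main() {
  // Hypothetical sizes; the real values come from the model config.
  const size_t layers = 2, kv_heads = 4, qkv_dim = 16, seq_len = 8;
  const size_t cache_layer_size = kv_heads * 2 * qkv_dim;   // K and V per head.
  const size_t cache_pos_size = layers * cache_layer_size;  // One position.
  std::vector<float> kv_cache(seq_len * cache_pos_size, 0.0f);

  const size_t layer = 1, head = 2;
  const size_t head_offset = head * 2 * qkv_dim;
  const size_t kv_head_offset = layer * cache_layer_size + head_offset;

  // Equivalent of k.Row(pos)[d] / v.Row(pos)[d] for the strided views:
  // consecutive positions are cache_pos_size floats apart, and V sits
  // qkv_dim floats after K within the same (layer, head) slot.
  auto k_at = [&](size_t pos, size_t d) -> float& {
    return kv_cache[pos * cache_pos_size + kv_head_offset + d];
  };
  auto v_at = [&](size_t pos, size_t d) -> float& {
    return kv_cache[pos * cache_pos_size + kv_head_offset + qkv_dim + d];
  };
  k_at(0, 0) = 1.0f;  // First K element of this (layer, head) at position 0.
  v_at(0, 0) = 2.0f;  // First V element, qkv_dim floats later.
  return 0;
}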
@@ -466,8 +464,8 @@ class GemmaAttention {
           last_pos = prefix_end - 1;
         }

-        SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
-                                    pos, start_pos, last_pos);
+        SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos,
+                                    start_pos, last_pos);
       });
 }

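For reference, the per-head work that SingleDotSoftmaxWeightedSum performs (as the earlier comment says: Q.K, softmax, then weighted V) can be sketched in scalar form. This is a plain-vector approximation: it ignores the strided cache views, the prefix/window clamping details, and the SIMD/Highway implementation:

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of one head's attention: att[pos] = softmax(scale * q.k_pos)
// over [start_pos, last_pos], att_out = sum_pos att[pos] * v_pos. Plain
// vectors stand in for the strided k/v cache views and activation rows.
// Requires start_pos <= last_pos and att.size() >= last_pos + 1.
void DotSoftmaxWeightedSum(const std::vector<float>& q,
                           const std::vector<std::vector<float>>& k,
                           const std::vector<std::vector<float>>& v,
                           float query_scale, size_t start_pos, size_t last_pos,
                           std::vector<float>& att,
                           std::vector<float>& att_out) {
  const size_t qkv_dim = q.size();
  // Scaled Q.K for each attended position; track the max for stability.
  float max_logit = -1e30f;
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    float dot = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[pos][d];
    att[pos] = dot * query_scale;
    if (att[pos] > max_logit) max_logit = att[pos];
  }
  // Softmax over the attended positions.
  float sum = 0.0f;
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    att[pos] = std::exp(att[pos] - max_logit);
    sum += att[pos];
  }
  // Weighted sum of V rows.
  att_out.assign(qkv_dim, 0.0f);
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    const float w = att[pos] / sum;
    for (size_t d = 0; d < qkv_dim; ++d) att_out[d] += w * v[pos][d];
  }
}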
@@ -1510,8 +1508,7 @@ void GenerateSingleT(const ModelStore& model,
 }

 template <typename T>
-void GenerateBatchT(const ModelStore& model,
-                    const ModelWeightsPtrs<T>& weights,
+void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
                     const RuntimeConfig& runtime_config,
                     const QueriesPromptTokens& queries_prompt,
                     const QueriesPos& queries_pos,