mirror of https://github.com/google/gemma.cpp.git
parent 1e8642f8f4
commit 9c3e089b09
@@ -572,7 +572,8 @@ ModelConfig::ModelConfig(const Model model, Type weight,
 static Model FindModel(const std::string& specifier) {
   Model found_model = Model::UNKNOWN;
   ForEachModel([&](Model model) {
-    const char* prefix = ModelPrefix(model);
+    // Some model names are prefixes of other model names
+    const std::string prefix = std::string(ModelPrefix(model)) + "-";
     if (specifier.rfind(prefix, 0) == 0) {  // Starts with prefix.
       // We only expect one match.
       HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());
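The hunk above fixes an ambiguity: when one ModelPrefix() string is a prefix of another, a bare starts-with test can match both and trip the single-match assertion. Appending the "-" separator before comparing disambiguates. A minimal self-contained sketch of the check, with invented model names standing in for the real ModelPrefix()/ForEachModel() list:

  #include <cassert>
  #include <string>

  // Returns true if `specifier` begins with `name` followed by '-'.
  // Model names here are invented for illustration.
  static bool StartsWithModelName(const std::string& specifier,
                                  const std::string& name) {
    const std::string prefix = name + "-";
    // rfind(prefix, 0) == 0 is the idiomatic C++11 "starts with" test.
    return specifier.rfind(prefix, 0) == 0;
  }

  int main() {
    assert(StartsWithModelName("gemma2-2b-it", "gemma2"));
    // Without the appended "-", "gemma" would also match "gemma2-2b-it".
    assert(!StartsWithModelName("gemma2-2b-it", "gemma"));
  }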
@@ -176,7 +176,7 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
     Sigmoid(gate_x + head_offset, kHeadDim);
     Sigmoid(a + head_offset, kHeadDim);
     const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
-                            HWY_ATTR { return hn::Mul(x, gate_x); };
+        HWY_ATTR { return hn::Mul(x, gate_x); };
     hn::Transform1(D(), a + head_offset, kHeadDim,
                    layer_weights->griffin.a.PackedScale1() + head_offset,
                    fn_mul);
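For context on the idiom above: hn::Transform1 from Highway's contrib/algo applies a two-argument functor elementwise, writing the result back into its first array; here it scales the sigmoid-squashed recurrence gate `a` by the per-head Griffin weights. A scalar (non-SIMD) sketch of the same arithmetic, with illustrative names:

  #include <cmath>
  #include <cstddef>

  // Scalar equivalent of the vectorized gating: squash both arrays with a
  // sigmoid, then multiply `a` elementwise by the head's weights, as
  // hn::Transform1(..., fn_mul) does one SIMD vector at a time.
  void GriffinGateScalar(float* gate_x, float* a, const float* weights_a,
                         size_t head_dim) {
    for (size_t i = 0; i < head_dim; ++i) {
      gate_x[i] = 1.0f / (1.0f + std::exp(-gate_x[i]));  // Sigmoid(gate_x)
      a[i] = 1.0f / (1.0f + std::exp(-a[i]));            // Sigmoid(a)
      a[i] *= weights_a[i];                              // fn_mul
    }
  }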
@@ -424,51 +424,49 @@ class GemmaAttention {
     const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads;
 
     // For each head (token, query), compute Q.K, softmax, and weighted V.
-    pool_.Run(0, layer_config_.heads * num_interleaved,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-                const size_t head = task % layer_config_.heads;
-                const size_t interleaved_idx = task / layer_config_.heads;
-                const size_t query_idx = interleaved_idx % num_queries_;
-                const size_t batch_idx = interleaved_idx / num_queries_;
-                const size_t qkv_dim = layer_config_.qkv_dim;
-                const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-
-                float* HWY_RESTRICT q =
-                    activations_.q.Row(interleaved_idx) + head * q_stride_;
-                float* HWY_RESTRICT att =
-                    activations_.att.Row(interleaved_idx) +
-                    head * activations_.seq_len;
-                float* HWY_RESTRICT att_out =
-                    activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
-
-                // Make strided views into the kv cache entries for the current
-                // query and head.
-                KVCache& kv_cache = kv_caches_[query_idx];
-                const size_t kv_head_offset =
-                    layer_ * cache_layer_size_ + head_offset;
-                MatPtrT<float> k("k_view",
-                                 Extents2D(kv_cache.seq_len, qkv_dim));
-                k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
-                         /*stride=*/cache_pos_size_);
-                MatPtrT<float> v("v_view",
-                                 Extents2D(kv_cache.seq_len, qkv_dim));
-                v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-                         /*stride=*/cache_pos_size_);
-
-                // Find the token position in the query and calculate the range
-                // of cache positions to attend to.
-                const size_t pos = queries_pos_[query_idx] + batch_idx;
-                const size_t start_pos = StartPos(pos, layer_);
-                size_t last_pos = pos;
-                const size_t prefix_end = queries_prefix_end_[query_idx];
-                if (prefix_end > 0 && prefix_end - 1 > last_pos) {
-                  // last_pos in QDotK and WeightedSumV is inclusive.
-                  last_pos = prefix_end - 1;
-                }
-
-                SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
-                                            pos, start_pos, last_pos);
-              });
+    pool_.Run(
+        0, layer_config_.heads * num_interleaved,
+        [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+          const size_t head = task % layer_config_.heads;
+          const size_t interleaved_idx = task / layer_config_.heads;
+          const size_t query_idx = interleaved_idx % num_queries_;
+          const size_t batch_idx = interleaved_idx / num_queries_;
+          const size_t qkv_dim = layer_config_.qkv_dim;
+          const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
+
+          float* HWY_RESTRICT q =
+              activations_.q.Row(interleaved_idx) + head * q_stride_;
+          float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
+                                    head * activations_.seq_len;
+          float* HWY_RESTRICT att_out =
+              activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
+
+          // Make strided views into the kv cache entries for the current
+          // query and head.
+          KVCache& kv_cache = kv_caches_[query_idx];
+          const size_t kv_head_offset =
+              layer_ * cache_layer_size_ + head_offset;
+          MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
+          k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
+                   /*stride=*/cache_pos_size_);
+          MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
+          v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+                   /*stride=*/cache_pos_size_);
+
+          // Find the token position in the query and calculate the range
+          // of cache positions to attend to.
+          const size_t pos = queries_pos_[query_idx] + batch_idx;
+          const size_t start_pos = StartPos(pos, layer_);
+          size_t last_pos = pos;
+          const size_t prefix_end = queries_prefix_end_[query_idx];
+          if (prefix_end > 0 && prefix_end - 1 > last_pos) {
+            // last_pos in QDotK and WeightedSumV is inclusive.
+            last_pos = prefix_end - 1;
+          }
+
+          SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos,
+                                      start_pos, last_pos);
+        });
   }
 
  private:
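Two pieces of arithmetic in this hunk are worth spelling out. The flat task index decodes as head = task % heads and interleaved_idx = task / heads, which further splits into query_idx and batch_idx; and the K/V views address a single interleaved cache buffer in which each position's row holds K followed immediately by V for every layer and KV head. A small sketch of that offset math, with hypothetical sizes (the real values come from the model's LayerConfig):

  #include <cstddef>
  #include <cstdio>

  int main() {
    // Hypothetical sizes for illustration only.
    const size_t heads = 8, kv_heads = 2, qkv_dim = 256, layers = 26;
    const size_t kHeadGroups = heads / kv_heads;  // query heads per KV head (GQA)
    const size_t cache_layer_size = kv_heads * qkv_dim * 2;   // K and V per layer
    const size_t cache_pos_size = layers * cache_layer_size;  // row stride per position

    // Per-head offset as in the loop above: each KV head stores K then V,
    // hence the factor of 2, and V begins qkv_dim floats after K.
    const size_t layer = 3, head = 5;
    const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
    const size_t k_offset = layer * cache_layer_size + head_offset;
    const size_t v_offset = k_offset + qkv_dim;

    // Element i of K at position `pos` lives at
    // kv_cache[pos * cache_pos_size + k_offset + i]; the strided MatPtrT
    // views encode exactly this with stride = cache_pos_size.
    std::printf("k_offset=%zu v_offset=%zu stride=%zu\n", k_offset, v_offset,
                cache_pos_size);
  }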
@@ -1510,8 +1508,7 @@ void GenerateSingleT(const ModelStore& model,
 }
 
 template <typename T>
-void GenerateBatchT(const ModelStore& model,
-                    const ModelWeightsPtrs<T>& weights,
+void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
                     const RuntimeConfig& runtime_config,
                     const QueriesPromptTokens& queries_prompt,
                     const QueriesPos& queries_pos,
@@ -1536,7 +1533,7 @@ void GenerateBatchT(const ModelStore& model,
                                              qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                        qbatch_size);
+                                       qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT<T>(model, weights, activations, runtime_config, qbatch_prompts,
                  qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
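The two hunks above only rewrap GenerateBatchT, whose structure is: walk the queries in slices of at most qbatch_size, build span views (prompts, positions, prefix ends, KV caches) over each slice starting at qbatch_start, and forward them to GenerateT. A generic sketch of that slicing pattern; the callback and container types here are invented for illustration, standing in for the QueriesPos/KVCaches views:

  #include <algorithm>
  #include <cstddef>
  #include <vector>

  // Minimal stand-in for the qbatch loop in GenerateBatchT: dispatch each
  // contiguous slice of at most `qbatch_size` queries to `fn`, which receives
  // a pointer/size pair like the span views above plus the slice's offset.
  template <typename T, typename Fn>
  void ForEachQueryBatch(const std::vector<T>& queries, size_t qbatch_size,
                         const Fn& fn) {
    for (size_t qbatch_start = 0; qbatch_start < queries.size();
         qbatch_start += qbatch_size) {
      const size_t size = std::min(qbatch_size, queries.size() - qbatch_start);
      fn(&queries[qbatch_start], size, qbatch_start);
    }
  }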