transpose
parent 8666c546f9
commit 4be60a28b6
@@ -2913,7 +2913,7 @@ struct ggml_cplan ggml_graph_plan(
             case GGML_OP_GATED_DELTA_NET:
                 {
                     const int64_t S_v = node->src[0]->ne[0];
-                    cur = 4 * S_v * sizeof(float) * n_tasks;
+                    cur = (S_v * S_v + 4 * S_v) * sizeof(float) * n_tasks;
                 } break;
            case GGML_OP_COUNT:
                {
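The planner change reserves room for a full S_v x S_v state copy per task on top of the four existing per-thread vectors. A minimal sketch of the sizing arithmetic; the helper name gated_delta_net_wsize is illustrative, not part of the patch:

#include <stddef.h>
#include <stdint.h>

// Hypothetical helper mirroring the new planner math: each task gets the
// S_v x S_v row-major state copy plus four S_v-sized vectors
// (q_local, k_local, kv_mem, delta), all stored as floats.
static size_t gated_delta_net_wsize(int64_t S_v, int n_tasks) {
    return (size_t)(S_v * S_v + 4 * S_v) * sizeof(float) * (size_t)n_tasks;
}

// e.g. S_v = 128, n_tasks = 8: (16384 + 512) * 4 * 8 = 540672 bytes for this node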
@@ -10322,15 +10322,17 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
     GGML_ASSERT(ggml_is_contiguous(src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_state));

-    // scratch layout per thread: [q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)]
-    const int64_t scratch_per_thread = 4 * S_v;
+    // scratch layout per thread: [s_t(S_v*S_v) | q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)]
+    // s_t holds the transposed (row-major) state for contiguous vector ops
+    const int64_t scratch_per_thread = S_v * S_v + 4 * S_v;
     const int ith = params->ith;
     float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;

-    float * q_local = scratch;
-    float * k_local = scratch + S_v;
-    float * kv_mem = scratch + 2 * S_v;
-    float * delta = scratch + 3 * S_v;
+    float * s_t = scratch;
+    float * q_local = scratch + S_v * S_v;
+    float * k_local = scratch + S_v * S_v + S_v;
+    float * kv_mem = scratch + S_v * S_v + 2 * S_v;
+    float * delta = scratch + S_v * S_v + 3 * S_v;

     // output layout: [attn_scores | new_states]
     // attn_scores: S_v * H * n_tokens * n_seqs floats
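The five pointers above are a manual carve of one flat per-thread region. The same layout, written out as a named view for clarity; struct gdn_scratch and gdn_scratch_carve are illustrative names, not ggml API, and base is assumed to be this thread's slice of wdata after the cache-line pad:

#include <stdint.h>

// Hypothetical view over one thread's scratch, matching the offsets in the patch.
struct gdn_scratch {
    float * s_t;     // S_v*S_v floats: row-major state copy
    float * q_local; // S_v floats: normalized query
    float * k_local; // S_v floats: normalized key
    float * kv_mem;  // S_v floats: S @ k for the current token
    float * delta;   // S_v floats: (v - kv_mem) * beta for the current token
};

static struct gdn_scratch gdn_scratch_carve(float * base, int64_t S_v) {
    struct gdn_scratch s;
    s.s_t     = base;
    s.q_local = base + S_v * S_v;
    s.k_local = base + S_v * S_v + S_v;
    s.kv_mem  = base + S_v * S_v + 2 * S_v;
    s.delta   = base + S_v * S_v + 3 * S_v;
    return s;
}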
@@ -10352,12 +10354,19 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
         const int64_t h_idx = ir % H;
         const int64_t sequence = ir / H;

-        // output state pointer for this (head, seq)
+        // output state pointer for this (head, seq), column-major (ggml layout)
         float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;

-        // copy input state for this (head, seq) into output
+        // Copy state into scratch in row-major layout of S (not S^T)
+        // ggml column-major: s_in[j + i*S_v] = S[j][i]  (j = dim0, i = dim1)
+        // row-major of S:    s_t[j*S_v + i]  = S[j][i]  (row j is contiguous over i)
+        // This makes kv_mem[j] = dot(s_t[j*S_v:], k) a contiguous dot product
         const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
-        memcpy(s_out, s_in, S_v * S_v * sizeof(float));
+        for (int64_t j = 0; j < S_v; ++j) {
+            for (int64_t i = 0; i < S_v; ++i) {
+                s_t[j * S_v + i] = s_in[j + i * S_v];
+            }
+        }

         // attn output pointer for first token of this (head, seq)
         float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v;
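To see why the copy is worth the extra pass: in ggml's column-major layout, row j of S has stride S_v, so every row-wise dot product walks memory with a large stride. After the transpose copy, row j of s_t is unit-stride and S @ k decomposes into S_v contiguous dot products. A toy illustration in plain C (not the ggml_vec_* kernels):

#include <stdint.h>

// Column-major source: src[j + i*S_v] = S[j][i]. Row-major dest:
// dst[j*S_v + i] = S[j][i], so each row of S becomes contiguous.
static void transpose_col_to_row(float * dst, const float * src, int64_t S_v) {
    for (int64_t j = 0; j < S_v; ++j) {
        for (int64_t i = 0; i < S_v; ++i) {
            dst[j * S_v + i] = src[j + i * S_v];
        }
    }
}

// With the row-major copy, (S @ k)[j] is a unit-stride inner loop,
// which is what ggml_vec_dot_f32 vectorizes well.
static void matvec_rowmajor(float * out, const float * s_t, const float * k, int64_t S_v) {
    for (int64_t j = 0; j < S_v; ++j) {
        float acc = 0.0f;
        for (int64_t i = 0; i < S_v; ++i) {
            acc += s_t[j * S_v + i] * k[i];
        }
        out[j] = acc;
    }
}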
@@ -10390,30 +10399,42 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
             ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps));

             // state decay: S *= exp(g)
-            ggml_vec_scale_f32(S_v * S_v, s_out, g_val);
+            // s_t is row-major, but scaling all elements is layout-agnostic
+            ggml_vec_scale_f32(S_v * S_v, s_t, g_val);

             // kv_mem = S @ k
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_dot_f32(S_v, &kv_mem[i], 0, &s_out[i * S_v], 0, k_local, 0, 1);
+            // kv_mem[j] = sum_i S[j][i] * k[i] = dot(s_t[j*S_v:], k)
+            // row j of s_t is contiguous -> use ggml_vec_dot_f32
+            for (int64_t j = 0; j < S_v; ++j) {
+                ggml_vec_dot_f32(S_v, &kv_mem[j], 0, &s_t[j * S_v], 0, k_local, 0, 1);
             }

             // delta = (v - kv_mem) * beta
-            for (int64_t i = 0; i < S_v; ++i) {
-                delta[i] = (v_d[i] - kv_mem[i]) * beta_val;
+            for (int64_t j = 0; j < S_v; ++j) {
+                delta[j] = (v_d[j] - kv_mem[j]) * beta_val;
             }

             // outer product update: S += k (x) delta
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_local[i]);
+            // outer product: S[j][i] += k[i] * delta[j]
+            // s_t[j * S_v + i] += k[i] * delta[j]
+            // row j gets k[:] scaled by delta[j] -> contiguous ggml_vec_mad_f32
+            for (int64_t j = 0; j < S_v; ++j) {
+                ggml_vec_mad_f32(S_v, &s_t[j * S_v], k_local, delta[j]);
             }

             // attn output = S @ q
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_dot_f32(S_v, &attn_data[i], 0, &s_out[i * S_v], 0, q_local, 0, 1);
+            // attn_out[j] = sum_i S[j][i] * q[i] = dot(s_t[j*S_v:], q)
+            for (int64_t j = 0; j < S_v; ++j) {
+                ggml_vec_dot_f32(S_v, &attn_data[j], 0, &s_t[j * S_v], 0, q_local, 0, 1);
             }

             attn_data += S_v * H; // advance to next token
         }

+        // copy scratch back to output: row-major of S -> column-major (ggml layout)
+        // s_t[j * S_v + i] = S[j][i] -> s_out[j + i * S_v] = S[j][i]
+        for (int64_t j = 0; j < S_v; ++j) {
+            for (int64_t i = 0; i < S_v; ++i) {
+                s_out[j + i * S_v] = s_t[j * S_v + i];
+            }
+        }
     }
 }
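Putting the loop body together: per token, the state follows the gated delta rule from the comments above (S *= exp(g); kv_mem = S @ k; delta = (v - kv_mem) * beta; S += k (x) delta; attn = S @ q). A naive single-function reference in plain C, assuming pre-normalized q and k and per-token scalars g_val (already exponentiated) and beta_val; the steps are fused per row, which is equivalent because each row's update only reads that row:

#include <stdint.h>

// Naive reference for one token's update on the row-major state s_t.
// Names mirror the patch; this sketch replaces the ggml_vec_* calls
// with plain loops so the recurrence is visible in one place.
static void gdn_token_update(float * s_t, float * attn_out,
                             const float * q, const float * k, const float * v,
                             float g_val, float beta_val, int64_t S_v) {
    for (int64_t j = 0; j < S_v; ++j) {
        float kv = 0.0f;
        for (int64_t i = 0; i < S_v; ++i) {
            s_t[j * S_v + i] *= g_val;          // state decay: S *= exp(g)
            kv += s_t[j * S_v + i] * k[i];      // kv_mem[j] = (S @ k)[j]
        }
        const float delta_j = (v[j] - kv) * beta_val; // delta = (v - kv_mem) * beta
        for (int64_t i = 0; i < S_v; ++i) {
            s_t[j * S_v + i] += k[i] * delta_j; // S += k (x) delta
        }
        float acc = 0.0f;
        for (int64_t i = 0; i < S_v; ++i) {
            acc += s_t[j * S_v + i] * q[i];     // attn_out[j] = (S @ q)[j]
        }
        attn_out[j] = acc;
    }
}

The patch computes the same thing with ggml_vec_scale_f32, ggml_vec_dot_f32 and ggml_vec_mad_f32 over whole rows, which is where the row-major copy pays off: every inner loop above runs at unit stride.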