diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index f759e2d588..82b27d2343 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -556,6 +556,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
+        GGML_OP_GATED_DELTA_NET,
 
         GGML_OP_UNARY,
 
@@ -2466,6 +2467,15 @@ extern "C" {
             bool                  lower,
             bool                  uni);
 
+    GGML_API struct ggml_tensor * ggml_gated_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index b003fe13fd..f11d7eab8f 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_solve_tri(params, tensor);
             } break;
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                ggml_compute_forward_gated_delta_net(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            } break;
        case GGML_OP_COUNT_EQUAL:
        case GGML_OP_SOLVE_TRI:
+        case GGML_OP_GATED_DELTA_NET:
            {
                n_tasks = n_threads;
            } break;
@@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan(
            {
                cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
            } break;
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                const int64_t S_v = node->src[0]->ne[0];
+                cur = 4 * S_v * sizeof(float) * n_tasks;
+            } break;
        case GGML_OP_COUNT:
            {
                GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index ed45350207..0069d8c0f5 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -10296,6 +10296,186 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
     }
 }
 
+// ggml_compute_forward_gated_delta_net
+
+static void ggml_compute_forward_gated_delta_net_one_chunk(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        int64_t ir0,
+        int64_t ir1) {
+
+    ggml_tensor * src_q     = dst->src[0];
+    ggml_tensor * src_k     = dst->src[1];
+    ggml_tensor * src_v     = dst->src[2];
+    ggml_tensor * src_g     = dst->src[3];
+    ggml_tensor * src_beta  = dst->src[4];
+    ggml_tensor * src_state = dst->src[5];
+
+    const int64_t S_v      = src_q->ne[0];
+    const int64_t H        = src_q->ne[1];
+    const int64_t n_tokens = src_q->ne[2];
+    const int64_t n_seqs   = src_q->ne[3];
+
+    GGML_ASSERT(ggml_is_contiguous(src_q));
+    GGML_ASSERT(ggml_is_contiguous(src_k));
+    GGML_ASSERT(ggml_is_contiguous(src_v));
+    GGML_ASSERT(ggml_is_contiguous(src_g));
+    GGML_ASSERT(ggml_is_contiguous(src_beta));
+    GGML_ASSERT(ggml_is_contiguous(src_state));
+
+    // scratch layout per thread: [q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)], cache-line padded
+    const int64_t scratch_per_thread = 4 * S_v;
+    const int ith = params->ith;
+    float * scratch = (float *) params->wdata + ith * (scratch_per_thread + CACHE_LINE_SIZE_F32);
+
+    float * q_local = scratch;
+    float * k_local = scratch + S_v;
+    float * kv_mem  = scratch + 2 * S_v;
+    float * delta   = scratch + 3 * S_v;
+
+    // output layout: [attn_scores | new_states]
+    // attn_scores: S_v * H * n_tokens * n_seqs floats
+    // new_states:  S_v * S_v * H * n_seqs floats
+    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+    float * attn_out_base  = (float *) dst->data;
+    float * state_out_base = (float *) dst->data + attn_score_elems;
+
+    const float * state_in_base = (const float *) src_state->data;
+    const float * q_base        = (const float *) src_q->data;
+    const float * k_base        = (const float *) src_k->data;
+    const float * v_base        = (const float *) src_v->data;
+    const float * g_base        = (const float *) src_g->data;
+    const float * beta_base     = (const float *) src_beta->data;
+
+    const float eps = ggml_get_op_params_f32(dst, 0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t h_idx    = ir % H;
+        const int64_t sequence = ir / H;
+
+        // output state pointer for this (head, seq)
+        float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;
+
+        // copy input state for this (head, seq) into output
+        const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
+        memcpy(s_out, s_in, S_v * S_v * sizeof(float));
+
+        // attn output pointer for first token of this (head, seq)
+        float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v;
+
+        for (int64_t t = 0; t < n_tokens; t++) {
+            // input pointers for this (head, seq, token)
+            // layout is contiguous [S_v, H, n_tokens, n_seqs]
+            const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
+            const float * q_d = q_base + qkv_offset;
+            const float * k_d = k_base + qkv_offset;
+            const float * v_d = v_base + qkv_offset;
+
+            // g and beta layout: [H, n_tokens, n_seqs]
+            const int64_t gb_offset  = sequence * n_tokens * H + t * H + h_idx;
+            const float beta_val_raw = beta_base[gb_offset];
+            const float beta_val     = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid
+            const float g_val        = expf(g_base[gb_offset]);
+
+            memcpy(q_local, q_d, S_v * sizeof(float));
+            memcpy(k_local, k_d, S_v * sizeof(float));
+
+            // l2-norm q and scale by 1/sqrt(S_v)
+            float norm;
+            ggml_vec_norm_f32(S_v, &norm, q_local);
+            ggml_vec_scale_f32(S_v, q_local, 1.0f / fmaxf(norm, eps));
+            ggml_vec_scale_f32(S_v, q_local, 1.0f / sqrtf((float) S_v));
+
+            // l2-norm k
+            ggml_vec_norm_f32(S_v, &norm, k_local);
+            ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps));
+
+            // state decay: S *= exp(g)
+            ggml_vec_scale_f32(S_v * S_v, s_out, g_val);
+
+            // kv_mem = S @ k
+            for (int64_t i = 0; i < S_v; ++i) {
+                ggml_vec_dot_f32(S_v, &kv_mem[i], 0, &s_out[i * S_v], 0, k_local, 0, 1);
+            }
+
+            // delta = (v - kv_mem) * beta
+            for (int64_t i = 0; i < S_v; ++i) {
+                delta[i] = (v_d[i] - kv_mem[i]) * beta_val;
+            }
+
+            // outer product update: S += k (x) delta
+            for (int64_t i = 0; i < S_v; ++i) {
+                ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_local[i]);
+            }
+
+            // attn output = S @ q
+            for (int64_t i = 0; i < S_v; ++i) {
+                ggml_vec_dot_f32(S_v, &attn_data[i], 0, &s_out[i * S_v], 0, q_local, 0, 1);
+            }
+
+            attn_data += S_v * H; // advance to next token
+        }
+    }
+}
+
+static void ggml_compute_forward_gated_delta_net_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    ggml_tensor * Q = dst->src[0];
+    int64_t nr = Q->ne[1] * Q->ne[3];
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    int nth = params->nth;
+    int ith = params->ith;
+
+    // aim for 4 chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
+void ggml_compute_forward_gated_delta_net(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gated_delta_net_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_rwkv_wkv7
 
 static void ggml_compute_forward_rwkv_wkv7_f32(
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 0fdfee7976..3fa1443abc 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 500cb6b72f..86b7e8741a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
     "SOLVE_TRI",
+    "GATED_DELTA_NET",
 
     "UNARY",
 
@@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "A X = B, A triangular, solve X",
+    "gated_delta_net(q, k, v, g, beta, s)",
 
     "unary(x)",
 
@@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6101,6 +6103,59 @@ struct ggml_tensor * ggml_solve_tri(
     return result;
 }
 
+// ggml_gated_delta_net
+
+struct ggml_tensor * ggml_gated_delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    GGML_ASSERT(q->type     == GGML_TYPE_F32);
+    GGML_ASSERT(k->type     == GGML_TYPE_F32);
+    GGML_ASSERT(v->type     == GGML_TYPE_F32);
+    GGML_ASSERT(g->type     == GGML_TYPE_F32);
+    GGML_ASSERT(beta->type  == GGML_TYPE_F32);
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H        = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == H && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] == H && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
+    GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
+
+    // concat output and new_state into a single tensor
+    // output: S_v * H * n_tokens, state: S_v * S_v * H * n_seqs
+    const int64_t ne[4] = { S_v * H, n_tokens + S_v * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_GATED_DELTA_NET;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g;
+    result->src[4] = beta;
+    result->src[5] = state;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {
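Usage note (not part of the patch): like `ggml_rwkv_wkv6` and `ggml_gated_linear_attn`, the op returns the per-token output and the updated recurrent state concatenated in one F32 tensor, so a caller splits the result with views. `g` and `beta` are passed in raw; the CPU kernel applies `exp()` to `g`, a sigmoid to `beta`, and l2-normalizes `q` and `k` internally. The sketch below is a minimal illustration of that split; the helper name `build_gated_delta_net_node`, the `S`/`H`/`n_tokens`/`n_seqs` symbols, and the assumption that the key and value head sizes are equal are mine, not part of this change.

```c
// Illustrative sketch only: split the fused ggml_gated_delta_net result into
// the attention output and the updated state. Assumes S_k == S_v == S.
#include "ggml.h"

static struct ggml_tensor * build_gated_delta_net_node(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,              // [S, H, n_tokens, n_seqs]
        struct ggml_tensor  * k,              // [S, H, n_tokens, n_seqs]
        struct ggml_tensor  * v,              // [S, H, n_tokens, n_seqs]
        struct ggml_tensor  * g,              // [H, n_tokens, n_seqs], raw (kernel applies exp)
        struct ggml_tensor  * beta,           // [H, n_tokens, n_seqs], raw (kernel applies sigmoid)
        struct ggml_tensor  * state,          // [S, S, H, n_seqs]
        struct ggml_tensor ** new_state_out) {
    const int64_t S        = q->ne[0];
    const int64_t H        = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    struct ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);

    // first S*H*n_tokens floats: per-token attention output, viewed as [S*H, n_tokens]
    struct ggml_tensor * attn_out = ggml_view_2d(ctx, out, S * H, n_tokens, out->nb[1], 0);

    // remaining S*S*H*n_seqs floats: updated recurrent state
    *new_state_out = ggml_view_1d(ctx, out, S * S * H * n_seqs,
            S * H * n_tokens * ggml_element_size(out));

    return attn_out;
}
```

The view offsets mirror the layout written by `ggml_compute_forward_gated_delta_net_one_chunk`: attention scores first, then the per-(head, seq) state matrices.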