cpu: support for non-contig q,k,v
commit 15d83e0c87 (parent 54ea122385)
@@ -2474,8 +2474,7 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * g,
             struct ggml_tensor * beta,
-            struct ggml_tensor * state,
-            float eps);
+            struct ggml_tensor * state);
 
     // custom operators
 
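With eps removed, a caller passes just the six tensors. A minimal sketch of a call site under the new declaration, assuming ctx, S_k, S_v, H, n_tokens and n_seqs are already defined and all inputs share the same head count H (the simplest case; the shapes mirror the asserts in ggml_gated_delta_net below):

    // illustrative only; sizes and variable names are assumptions, not part of the commit
    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_k, H, n_tokens, n_seqs);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_k, H, n_tokens, n_seqs);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, H, n_tokens, n_seqs);
    struct ggml_tensor * g     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, n_tokens, n_seqs);
    struct ggml_tensor * beta  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, n_tokens, n_seqs);
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v * S_v * H, n_seqs);
    struct ggml_tensor * out   = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);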
@@ -2912,7 +2912,7 @@ struct ggml_cplan ggml_graph_plan(
                 } break;
             case GGML_OP_GATED_DELTA_NET:
                 {
-                    const int64_t S_v = node->src[0]->ne[0];
+                    const int64_t S_v = node->src[2]->ne[0];
                     cur = (S_v * S_v + S_v) * sizeof(float) * n_tasks;
                 } break;
             case GGML_OP_COUNT:
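S_v for the work-buffer size is now read from src[2] (V) rather than src[0] (Q), since V alone fixes the state dimensions. For reference, a sketch of the per-thread layout this size corresponds to, using the same names as the kernel below (the delta pointer is an assumption based on the layout comment; only s_t appears explicitly in this diff):

    // per-thread scratch: [ s_t : S_v*S_v floats | delta : S_v floats ]
    const int64_t scratch_per_thread = S_v * S_v + S_v;
    float * scratch = (float *) params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
    float * s_t   = scratch;              // transposed state, S_v x S_v
    float * delta = scratch + S_v * S_v;  // per-token update vector, S_v floats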
@@ -10310,22 +10310,35 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
     ggml_tensor * src_beta = dst->src[4];
     ggml_tensor * src_state = dst->src[5];
 
-    const int64_t S_v = src_q->ne[0];
-    const int64_t H = src_q->ne[1];
-    const int64_t n_tokens = src_q->ne[2];
-    const int64_t n_seqs = src_q->ne[3];
+    const int64_t S_v = src_v->ne[0];
+    const int64_t H = src_v->ne[1];
+    const int64_t n_tokens = src_v->ne[2];
+    const int64_t n_seqs = src_v->ne[3];
 
-    GGML_ASSERT(ggml_is_contiguous(src_q));
-    GGML_ASSERT(ggml_is_contiguous(src_k));
-    GGML_ASSERT(ggml_is_contiguous(src_v));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
     GGML_ASSERT(ggml_is_contiguous(src_g));
     GGML_ASSERT(ggml_is_contiguous(src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_state));
 
+    // TODO: to support KDA
+    GGML_ASSERT(ggml_are_same_shape(src_beta, src_g));
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
+    GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
+
     // scratch layout per thread: [s_t(S_v*S_v) | delta(S_v)]
     // s_t holds the transposed (row-major) state for contiguous vector ops
     const int64_t scratch_per_thread = S_v * S_v + S_v;
     const int ith = params->ith;
 
     float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
 
     float * s_t = scratch;
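The q/k/v checks are relaxed from ggml_is_contiguous to ggml_is_contiguous_rows, which only requires the innermost dimension (each S_v-long row) to be dense; the outer strides may now come from views or permuted tensors. The GGML_TENSOR_LOCALS lines pull the sizes (ne) and byte strides (nb) into locals; roughly, each pair expands to something like the sketch below (shown for src_q only; the exact macro lives in ggml's internal headers):

    // approximate expansion of the two GGML_TENSOR_LOCALS lines for src_q
    const int64_t neq0 = src_q->ne[0], neq1 = src_q->ne[1], neq2 = src_q->ne[2], neq3 = src_q->ne[3];
    const size_t  nbq0 = src_q->nb[0], nbq1 = src_q->nb[1], nbq2 = src_q->nb[2], nbq3 = src_q->nb[3];
    // contiguous rows effectively means nbq0 == sizeof(float); nbq1..nbq3 are unconstrained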
@@ -10339,20 +10352,28 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
     float * state_out_base = (float *)dst->data + attn_score_elems;
 
     const float * state_in_base = (const float *)src_state->data;
-    const float * q_base = (const float *)src_q->data;
-    const float * k_base = (const float *)src_k->data;
-    const float * v_base = (const float *)src_v->data;
     const float * g_base = (const float *)src_g->data;
     const float * beta_base = (const float *)src_beta->data;
 
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-        const int64_t h_idx = ir % H;
-        const int64_t sequence = ir / H;
+    const int64_t rq1 = nev1 / neq1;
+    const int64_t rk1 = nev1 / nek1;
+    const int64_t rq3 = nev3 / neq3;
+    const int64_t rk3 = nev3 / nek3;
 
-        float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t iv1 = ir % H; // head_index
+        const int64_t iv3 = ir / H; // sequence
+
+        const int64_t iq1 = iv1 / rq1;
+        const int64_t ik1 = iv1 / rk1;
+
+        const int64_t iq3 = iv3 / rq3;
+        const int64_t ik3 = iv3 / rk3;
+
+        float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
 
         // tranpose
-        const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
+        const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
         for (int64_t j = 0; j < S_v; ++j) {
             for (int64_t i = 0; i < S_v; ++i) {
                 s_t[j * S_v + i] = s_in[j + i * S_v];
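The r* values are broadcast ratios: the outer loop runs over V's heads and sequences (iv1, iv3), and when Q or K carry fewer heads or sequences than V, several V heads map onto the same Q/K head via integer division. A small self-contained worked example with illustrative head counts (not part of the commit):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t nev1 = 8;            // V heads
        const int64_t neq1 = 4;            // Q heads
        const int64_t rq1  = nev1 / neq1;  // 2 V heads per Q head
        for (int64_t iv1 = 0; iv1 < nev1; ++iv1) {
            const int64_t iq1 = iv1 / rq1; // 0,0,1,1,2,2,3,3
            printf("V head %lld reads Q head %lld\n", (long long) iv1, (long long) iq1);
        }
        return 0;
    }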
@@ -10360,17 +10381,14 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
         }
 
         // attn output pointer for first token of this (head, seq)
-        float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v;
+        float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
 
         for (int64_t t = 0; t < n_tokens; t++) {
-            // input pointers for this (head, seq, token)
-            // layout is contiguous [S_v, H, n_tokens, n_seqs]
-            const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
-            const float * q_d = q_base + qkv_offset;
-            const float * k_d = k_base + qkv_offset;
-            const float * v_d = v_base + qkv_offset;
+            const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
+            const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
+            const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
 
-            const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
+            const int64_t gb_offset = iv3 * neg1 * neg0 + t * neg0 + iv1;
             const float beta_val_raw = beta_base[gb_offset];
             const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid
             const float g_val = expf(g_base[gb_offset]);
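The per-token pointers no longer assume the packed [S_v, H, n_tokens, n_seqs] float layout; each one is computed from its tensor's own byte strides (nbq*, nbk*, nbv*), which is what makes non-contiguous q/k/v work. A generic helper in the same spirit, not part of the commit (name and signature are illustrative):

    // pointer to row (i1, i2, i3) of an F32 tensor whose rows are dense
    static inline const float * row_ptr_f32(const struct ggml_tensor * t,
                                            int64_t i1, int64_t i2, int64_t i3) {
        return (const float *) ((const char *) t->data +
                                i1 * t->nb[1] + i2 * t->nb[2] + i3 * t->nb[3]);
    }
    // usage mirroring the kernel: q_d = row_ptr_f32(src_q, iq1, t, iq3);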
@@ -10410,8 +10428,8 @@ static void ggml_compute_forward_gated_delta_net_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
-    ggml_tensor * Q = dst->src[0];
-    int64_t nr = Q->ne[1] * Q->ne[3];
+    ggml_tensor * V = dst->src[2];
+    int64_t nr = V->ne[1] * V->ne[3];
 
     // disable for NUMA
     const bool disable_chunking = ggml_is_numa();
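The row count for threading is now derived from V, since V's heads times sequences is what the kernel iterates over when q/k carry fewer heads. For context, the usual ggml per-thread split over nr rows looks roughly like this (a sketch; ith and nth come from params, and MIN is the usual ggml helper macro):

    const int64_t dr  = (nr + nth - 1) / nth;  // rows per thread, rounded up
    const int64_t ir0 = dr * ith;              // first row handled by this thread
    const int64_t ir1 = MIN(ir0 + dr, nr);     // one past the last row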
@@ -6112,8 +6112,7 @@ struct ggml_tensor * ggml_gated_delta_net(
         struct ggml_tensor * v,
         struct ggml_tensor * g,
         struct ggml_tensor * beta,
-        struct ggml_tensor * state,
-        float eps) {
+        struct ggml_tensor * state) {
     GGML_ASSERT(ggml_is_contiguous(q));
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
@@ -6128,17 +6127,11 @@ struct ggml_tensor * ggml_gated_delta_net(
     GGML_ASSERT(beta->type == GGML_TYPE_F32);
     GGML_ASSERT(state->type == GGML_TYPE_F32);
 
-    const int64_t S_k = q->ne[0];
-    const int64_t H = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == H && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+    const int64_t S_v = v->ne[0];
+    const int64_t H = v->ne[1];
+    const int64_t n_tokens = v->ne[2];
+    const int64_t n_seqs = v->ne[3];
+
     GGML_ASSERT(g->ne[0] == H && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
     GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
     GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
 
     // concat output and new_state into a single tensor
@@ -6146,8 +6139,6 @@ struct ggml_tensor * ggml_gated_delta_net(
     const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_set_op_params_f32(result, 0, eps);
-
     result->op = GGML_OP_GATED_DELTA_NET;
    result->src[0] = q;
    result->src[1] = k;
@@ -3643,23 +3643,24 @@ struct test_gated_delta_net : public test_case {
     const int64_t head_size;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;
+    const int v_repeat;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+        return VARS_TO_STR6(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat);
     }
 
     test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
-            int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1)
-        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+            int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, int v_repeat = 1)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), v_repeat(v_repeat) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * head_size * head_count, n_seqs);
-        ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state, 1e-6f);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
+        ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
+        ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
+        ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
+        ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
         return out;
     }
 };
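The new v_repeat argument gives v, g, beta and state head_count * v_repeat heads while q and k keep head_count, so any value above 1 exercises the broadcast indexing added to the CPU kernel; the eps argument is also dropped from the ggml_gated_delta_net call to match the new API. For example, the case added in the next hunk builds the following tensors (sizes follow directly from the constructor above):

    // test_gated_delta_net(GGML_TYPE_F32, /*head_count=*/8, /*head_size=*/32,
    //                      /*n_seq_tokens=*/4, /*n_seqs=*/2, /*v_repeat=*/2)
    // q, k    : [32,  8, 4, 2]
    // v       : [32, 16, 4, 2]
    // g, beta : [16, 4, 2]
    // state   : [32*2*32*8, 2] = [16384, 2]
    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));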
@@ -8343,7 +8344,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
-    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));
 
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging