// Delta-net graph builders: a chunked variant and a single-token (autoregressive) variant.
#include "models.h"
// Number of token positions per chunk in build_delta_net_chunking(): the token
// dimension is padded up to a multiple of this value and then split into
// (n_tokens + pad) / CHUNK_SIZE chunks.
#define CHUNK_SIZE 64
// utility to get one slice from the third dimension
|
|
// input dim: [x, y, c, b]
|
|
// output dim: [x, y, 1, b]
|
|
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
|
|
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
|
|
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
|
|
}
// Forwards `params` unchanged to llm_graph_context; nothing else is initialized here.
llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params)
    : llm_graph_context(params) {
}
// Builds the chunked delta-net graph.
//
// The token dimension is padded to a multiple of CHUNK_SIZE, the per-chunk terms are
// built for all chunks at once, and a loop over the chunks then updates the state
// chunk by chunk while writing each chunk's output back into `v`.
//
// Expected input shapes (enforced by the asserts below):
//   q, k: [S_k, H_k, n_tokens, n_seqs]
//   v:    [S_v, H_v, n_tokens, n_seqs]   with S_v == S_k and H_v % H_k == 0
//   g:    [H_v, n_tokens, n_seqs, 1]
//   b:    [H_v, 1, n_tokens, n_seqs]
//   s:    [S_v, S_v, H_v, n_seqs]
//   il:   passed through to every cb() call (presumably the layer index - confirm against cb)
//
// Returns {o, s}:
//   o: output with the padded tokens cut off, permuted to [S_v, H_v, n_tokens, n_seqs]
//   s: updated state, [S_v, S_v, H_v, n_seqs]
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * b,
        ggml_tensor * s,
        int           il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(S_k == S_v);
    GGML_ASSERT(H_v % H_k == 0);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

    // g->ne[3] and b->ne[1] must be 1: the permutes below move those axes to dim 0, and the
    // later reshapes to [CS, 1, n_chunks, H_v * n_seqs] / [1, CS, n_chunks, H_v * n_seqs]
    // only match in element count when they are 1. Checking here fails early and clearly.
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs && g->ne[3] == 1);
    GGML_ASSERT(b->ne[0] == H_v && b->ne[1] == 1 && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);

    const float scale = 1.0f / sqrtf(S_k);

    q = ggml_scale(ctx0, q, scale);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(b, "b_in", il);
    cb(g, "g_in", il);

    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
    g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [  1, n_tokens, H_v, n_seqs]
    b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [  1, n_tokens, H_v, n_seqs]

    const int CS = CHUNK_SIZE;

    // pad the token dimension (dim 1 after the permutes) up to a multiple of CS
    const int pad      = (CS - n_tokens % CS) % CS;
    const int n_chunks = (n_tokens + pad) / CS;

    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
    b = ggml_pad(ctx0, b, 0, pad, 0, 0);

    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
    ggml_tensor * k_b = ggml_mul(ctx0, k, b);

    cb(v_b, "v_b", il);
    cb(k_b, "k_b", il);

    // split the padded token dimension into n_chunks chunks of CS tokens each
    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);

    g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
    b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);

    // [CS, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
    cb(g_cs, "g_cs", il);

    ggml_tensor * g_cs_i = g_cs;
    ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);

    g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);

    // [CS, CS, n_chunks, H_v * n_seqs]
    ggml_tensor * decay_mask;
    decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
    decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
    decay_mask = ggml_exp(ctx0, decay_mask);
    cb(decay_mask, "decay_mask", il);

    // [CS, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * kb;
    kb = ggml_mul_mat(ctx0, k,  k_b);
    kb = ggml_mul    (ctx0, kb, decay_mask);

    // [CS, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * attn;
    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);

    // CS x CS identity: a CS-long vector filled with 1.0f, expanded with ggml_diag
    ggml_tensor * identity;
    identity = ggml_view_1d(ctx0, attn, CS, 0);
    identity = ggml_fill   (ctx0, identity, 1.0f);
    identity = ggml_diag   (ctx0, identity);

    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
    cb(lhs, "dnet_add_ch_lhs", il);

    attn = ggml_neg(ctx0, attn);

    // triangular solve with lhs = attn + I as the matrix and -attn as the right-hand side
    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
    attn = ggml_add(ctx0, lin_solve, identity);
    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]

    // [S_v, CS, n_chunks, H_v * n_seqs]
    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);

    // [CS, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);

    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));

    // [CS, S_k, n_chunks, H_k * n_seqs]
    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
    cb(kbg, "k_beta_g_exp", il);

    // [S_k, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
    cb(k_cd, "k_cumdecay", il);

    // [S_k, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);

    // [CS, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
    kq = ggml_mul(ctx0, kq, decay_mask);
    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
    cb(kq, "kq", il);

    // vectorized calculation of key_gdiff
    // improved from the chunked version:
    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
    //   key_gdiff = key * g_diff.unsqueeze(-1)
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew

    // get last element in g_cumsum along CS dimension (ne0)
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
    // [1, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
            g_cs->nb[1],
            g_cs->nb[2],
            g_cs->nb[3],
            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
    cb(g_last, "g_last", il);

    // TODO: remove this cont when CUDA supports non-cont unary ops
    g_last = ggml_cont(ctx0, g_last);

    // [1, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
    cb(g_last_exp, "g_last_exp", il);

    // [CS, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
    cb(g_diff, "g_diff", il);

    ggml_tensor * g_diff_exp   = ggml_exp(ctx0, g_diff);
    ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);

    // [S_k, CS, n_chunks, H_v * n_seqs]
    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
    cb(kg, "key_gdiff", il);

    // [CS, S_k, n_chunks, H_v * n_seqs]
    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
    cb(kg_t, "key_gdiff_t", il);

    // transposed state; heads and sequences are folded together into dim 3
    ggml_tensor * s_t = ggml_transpose(ctx0, s);
    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
    cb(s_t, "dnet_add_ch_state", il);

    // [CS, S_v, n_chunks, H_v * n_seqs]
    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));

    // the chunks are processed in order because each iteration reads the state (s_t)
    // written by the previous one
    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]

        // [CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
        cb(v_t_p, "v_prime", il);

        // [CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
        cb(v_t_new, "v_t_new", il);

        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
        cb(v_attn, "v_attn", il);

        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
        cb(attn_inter, "attn_inter", il);

        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
        cb(o_ch, "dnet_add_ch_attn_out", il);

        // write this chunk's output into slice `chunk` (along dim 2) of v
        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);

        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
        // TODO: head broadcast might not work here - probably will need a transpose
        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]

        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
        ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
        s_t = ggml_add(ctx0, s_t, kgv);
        cb(s_t, "dnet_add_ch_state", il);
    }

    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);

    // truncate padded tokens
    ggml_tensor * o = ggml_view_4d(ctx0, v,
            S_v, n_tokens, H_v, n_seqs,
            ggml_row_size(v->type, S_v),
            ggml_row_size(v->type, S_v * CS * n_chunks),
            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);

    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]

    return {o, s};
}
// Builds the single-token delta-net step (asserts n_tokens == 1).
//
// Expected input shapes (enforced by the asserts below):
//   q, k: [S_k, H_k, 1, n_seqs]
//   v:    [S_v, H_v, 1, n_seqs]   with S_v == S_k and H_v % H_k == 0
//   g:    [H_v, 1, n_seqs, 1]
//   b:    [H_v, 1, 1, n_seqs]
//   s:    [S_v, S_v, H_v, n_seqs]
//   il:   passed through to every cb() call (presumably the layer index - confirm against cb)
//
// Steps, as built below:
//   s   <- s * exp(g)
//   s_t <- transpose(s)
//   sk  <- sum_rows(s_t * k)
//   d   <- (v - transpose(sk)) * b
//   s_t <- s_t + repeat(k) * transpose(d)
//   o   <- sum_rows(s_t * q)
//
// Returns {o, s}:
//   o: output, permuted to [S_v, H_v, n_tokens, n_seqs]
//   s: updated state, [S_v, S_v, H_v, n_seqs]
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * b, // beta
        ggml_tensor * s, // state
        int           il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(n_tokens == 1);

    GGML_ASSERT(S_k == S_v);
    GGML_ASSERT(H_v % H_k == 0);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

    // g->ne[3] and b->ne[1] must be 1: both tensors are reshaped to [1, 1, H_v, n_seqs]
    // below, which only matches in element count when the remaining axis is 1.
    // Checking here fails early and clearly.
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs && g->ne[3] == 1);
    GGML_ASSERT(b->ne[0] == H_v && b->ne[1] == 1 && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);

    const float scale = 1.0f / sqrtf(S_k);

    q = ggml_scale(ctx0, q, scale);

    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(b, "b_in", il);
    cb(g, "g_in", il);

    // one scalar per (head, sequence), laid out so it broadcasts over the state's first two dims
    g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
    b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);

    // [S_v, S_v, H_v, n_seqs]
    g = ggml_exp(ctx0, g);
    s = ggml_mul(ctx0, s, g);

    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));

    // [1, S_v, H_v, n_seqs]
    ggml_tensor * sk;
    sk = ggml_mul     (ctx0, s_t, k);
    sk = ggml_sum_rows(ctx0, sk);

    // [S_v, 1, H_v, n_seqs]
    ggml_tensor * d;
    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
    d = ggml_mul(ctx0, d, b);

    // [1, S_v, H_v, n_seqs]
    ggml_tensor * d_t;
    d_t = ggml_transpose(ctx0, d);

    // [S_v, S_v, H_v, n_seqs]
    ggml_tensor * kd;
    k  = ggml_repeat(ctx0, k, s);
    kd = ggml_mul   (ctx0, k, d_t);

    s_t = ggml_add(ctx0, s_t, kd);

    cb(s_t, "dnet_add_ar_state", il);

    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);

    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]

    return {o, s};
}