ggml-cpu: add DELTA_NET backend + tests
This commit is contained in:
parent 0a192937a1
commit 128a6c2831
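Adds a CPU implementation and op dispatch for the new GGML_OP_DELTA_NET operator, declares a delta-net graph-builder hook in the llama model code, and adds test-backend-ops coverage for several shapes.

As read off the CPU kernel below (a summary sketch of the recurrence, not an authoritative spec): each (sequence, head) pair carries a head_dim x head_dim state S; q_t and k_t are L2-normalized, beta_t = sigmoid(beta_raw_t), g_t = exp(min(g_raw_t, 50)), d = head_dim, and per token

    v'_t = beta_t * v_t - beta_t * g_t * (S_{t-1} k_t)
    S_t  = g_t * S_{t-1} + v'_t * k_t^T          (entries clamped to +/- 1e6)
    o_t  = S_t q_t / sqrt(d)

Work is split across threads by (sequence, head), hence n_tasks = n_threads for this op.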
@@ -2014,6 +2014,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET:
+            {
+                ggml_compute_forward_delta_net(params, tensor);
+            } break;
         case GGML_OP_SOLVE_TRI:
             {
                 ggml_compute_forward_solve_tri(params, tensor);
@@ -2339,6 +2343,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_DELTA_NET:
             {
                 n_tasks = n_threads;
             } break;
@@ -10091,6 +10091,139 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
+
+// ggml_compute_forward_delta_net
+
+static void ggml_compute_forward_delta_net_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+    const ggml_tensor * src3 = dst->src[3];
+    const ggml_tensor * src4 = dst->src[4];
+    const ggml_tensor * src5 = dst->src[5];
+
+    const int64_t head_dim = src0->ne[0];
+    const int64_t n_tokens = src0->ne[1];
+    const int64_t n_heads = src0->ne[2];
+    const int64_t n_seqs = src0->ne[3];
+
+    const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
+
+    const float * q_data = (const float *) src0->data;
+    const float * k_data = (const float *) src1->data;
+    const float * v_data = (const float *) src2->data;
+    const float * g_data = (const float *) src3->data;
+    const float * beta_data = (const float *) src4->data;
+    const float * state_in = (const float *) src5->data;
+    float * out_data = (float *) dst->data;
+    float * state_out = out_data + output_size;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t total_heads = n_heads * n_seqs;
+    const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
+    const int64_t h_start = ith * heads_per_thread;
+    const int64_t h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
+
+    const float eps = 1e-12f;
+    const float scale = 1.0f / sqrtf((float)head_dim);
+
+    float * v_new_buf = (float *)malloc(head_dim * sizeof(float));
+    if (!v_new_buf) {
+        return;
+    }
+
+    for (int64_t h_idx = h_start; h_idx < h_end; h_idx++) {
+        const int64_t batch_idx = h_idx / n_heads;
+        const int64_t head_idx = h_idx % n_heads;
+
+        const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
+        const int64_t qkv_token_stride = head_dim;
+        const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
+        const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
+        const int64_t out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
+        const int64_t out_token_stride = head_dim * n_heads;
+
+        for (int64_t i = 0; i < head_dim * head_dim; i++) {
+            state_out[state_head_offset + i] = state_in[state_head_offset + i];
+        }
+
+        float * state = state_out + state_head_offset;
+
+        for (int64_t t = 0; t < n_tokens; t++) {
+            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
+            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
+            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
+
+            float g_val = g_data[g_head_offset + t];
+            float beta_raw = beta_data[g_head_offset + t];
+
+            float q_norm_sq = 0.0f, k_norm_sq = 0.0f;
+            for (int64_t i = 0; i < head_dim; i++) {
+                q_norm_sq += q_t[i] * q_t[i];
+                k_norm_sq += k_t[i] * k_t[i];
+            }
+            float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
+            float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
+
+            float beta_val = 1.0f / (1.0f + expf(-beta_raw));
+            float decay = expf(fminf(g_val, 50.0f));
+
+            float attn_score = 0.0f;
+            for (int64_t i = 0; i < head_dim; i++) {
+                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
+            }
+
+            float * out_t = out_data + out_head_offset + t * out_token_stride;
+
+            for (int64_t row = 0; row < head_dim; row++) {
+                float v_prime = 0.0f;
+                float out_val = 0.0f;
+
+                for (int64_t col = 0; col < head_dim; col++) {
+                    float k_col = k_t[col] * k_norm_inv;
+                    float q_col = q_t[col] * q_norm_inv * scale;
+                    float s = state[row + col * head_dim];
+
+                    v_prime += s * k_col * beta_val * decay;
+                    out_val += s * q_col * decay;
+                }
+
+                float v_new = v_t[row] * beta_val - v_prime;
+                v_new_buf[row] = v_new;
+                out_t[row] = out_val + v_new * attn_score;
+            }
+
+            for (int64_t col = 0; col < head_dim; col++) {
+                float k_col = k_t[col] * k_norm_inv;
+                for (int64_t row = 0; row < head_dim; row++) {
+                    float s = state[row + col * head_dim];
+                    s = decay * s + v_new_buf[row] * k_col;
+                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
+                }
+            }
+        }
+    }
+
+    free(v_new_buf);
+}
+
+void ggml_compute_forward_delta_net(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_compute_forward_delta_net_f32(params, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
 
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(
@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -469,6 +469,15 @@ private:
             ggml_tensor * state,
             int il);
+
+    ggml_tensor * build_delta_net_autoregressive(
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * g,
+            ggml_tensor * beta,
+            ggml_tensor * state,
+            int il);
 
     ggml_tensor * build_norm_gated(
             ggml_tensor * input,
             ggml_tensor * weights,
@@ -3550,6 +3550,34 @@ struct test_rwkv_wkv7 : public test_case {
     }
 };
+
+// GGML_OP_DELTA_NET
+struct test_delta_net : public test_case {
+    const ggml_type type;
+
+    const int64_t n_heads;
+    const int64_t head_dim;
+    const int64_t n_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, n_heads, head_dim, n_tokens, n_seqs);
+    }
+
+    test_delta_net(ggml_type type = GGML_TYPE_F32,
+            int64_t n_heads = 8, int64_t head_dim = 64, int64_t n_tokens = 32, int64_t n_seqs = 2)
+        : type(type), n_heads(n_heads), head_dim(head_dim), n_tokens(n_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
+        ggml_tensor * g = ggml_new_tensor_4d(ctx, type, n_tokens, 1, n_heads, n_seqs);
+        ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, n_tokens, n_heads, n_seqs);
+        ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_dim, head_dim * n_heads, 1, n_seqs);
+        return ggml_delta_net(ctx, q, k, v, g, beta, state);
+    }
+};
 
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -7322,6 +7350,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 1, 1));
+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 1));
+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 2));
+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 128, 2));
+
 #if 0
     // > 4GB A matrix. Too slow to be enabled by default.
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
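For quick manual experimentation outside test-backend-ops, a minimal CPU-only sketch in the spirit of test_delta_net could look like the following. This is illustrative only: the file name, helper, constant fills, and thread count are made up for this sketch, and it assumes the ggml_delta_net() signature exactly as used by the test above together with the standard ggml context/graph API.

    // delta_net_smoke.cpp (hypothetical): build one GGML_OP_DELTA_NET node and run it on the CPU backend
    #include "ggml.h"
    #include "ggml-cpu.h"   // ggml_graph_compute_with_ctx
    #include <cstdio>

    // fill an F32 tensor with a constant (helper for this sketch)
    static void fill_f32(ggml_tensor * t, float value) {
        float * d = (float *) t->data;
        for (int64_t i = 0; i < ggml_nelements(t); i++) {
            d[i] = value;
        }
    }

    int main() {
        const int64_t head_dim = 4, n_tokens = 2, n_heads = 2, n_seqs = 1;

        ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
        ggml_context * ctx = ggml_init(ip);

        // same shapes as test_delta_net::build_graph above
        ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_heads, n_seqs);
        ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_heads, n_seqs);
        ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_heads, n_seqs);
        ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_tokens, 1, n_heads, n_seqs);
        ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, n_tokens, n_heads, n_seqs);
        ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, head_dim * n_heads, 1, n_seqs);

        fill_f32(q, 0.1f); fill_f32(k, 0.1f); fill_f32(v, 1.0f);
        fill_f32(g, -0.5f); fill_f32(beta, 0.0f); fill_f32(state, 0.0f);

        ggml_tensor * out = ggml_delta_net(ctx, q, k, v, g, beta, state);

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

        printf("out[0] = %f\n", ((const float *) out->data)[0]);

        ggml_free(ctx);
        return 0;
    }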