metal : single-user mamba2 inference works
This commit is contained in:
parent
6def5cd729
commit
791998b42d
|
|
@ -1284,7 +1284,7 @@ kernel void kernel_ssm_scan_f32(
|
||||||
const int64_t ng = args.n_group;
|
const int64_t ng = args.n_group;
|
||||||
const int64_t n_t = args.n_seq_tokens;
|
const int64_t n_t = args.n_seq_tokens;
|
||||||
|
|
||||||
const int64_t s_off = nr * nh * nt * args.n_seqs * sizeof(float);
|
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
||||||
|
|
||||||
device const int32_t * ids = (device const int32_t *) src6;
|
device const int32_t * ids = (device const int32_t *) src6;
|
||||||
|
|
||||||
|
|
@ -1292,12 +1292,12 @@ kernel void kernel_ssm_scan_f32(
|
||||||
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||||
|
|
||||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||||
device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
||||||
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
||||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
|
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
|
||||||
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
||||||
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
||||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns}
|
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
||||||
|
|
||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
|
|
@ -1354,12 +1354,12 @@ kernel void kernel_ssm_scan_f32_group(
|
||||||
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||||
|
|
||||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||||
device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
||||||
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
||||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
||||||
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
||||||
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
||||||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns}
|
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
||||||
|
|
||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
|
|
|
||||||
|
|
@ -9009,7 +9009,7 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
|
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
|
||||||
|
|
||||||
// {n_head, n_seq_tokens, n_seqs}
|
// {n_head, n_seq_tokens, n_seqs}
|
||||||
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
|
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
|
||||||
|
|
||||||
ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
|
ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
|
||||||
// TODO: use semistructured matrices to implement state-space duality
|
// TODO: use semistructured matrices to implement state-space duality
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue