feat: Use local variable for state recursion
Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
a5334f911e
commit
3866f766fe
|
|
@ -1788,8 +1788,11 @@ kernel void kernel_ssm_scan_f32_group(
|
||||||
|
|
||||||
device const int32_t * ids = (device const int32_t *) src6;
|
device const int32_t * ids = (device const int32_t *) src6;
|
||||||
|
|
||||||
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
||||||
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||||
|
const int64_t i = tpitg.x + i1*nc;
|
||||||
|
float s0 = s0_buff[i];
|
||||||
|
float s = s_buff[i];
|
||||||
|
|
||||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
||||||
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
||||||
|
|
@ -1809,9 +1812,8 @@ kernel void kernel_ssm_scan_f32_group(
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
const float dA = exp(dt_soft_plus * A[0]);
|
const float dA = exp(dt_soft_plus * A[0]);
|
||||||
|
|
||||||
const int64_t i = tpitg.x + i1*nc;
|
const float state = (s0 * dA) + (B[tpitg.x] * x_dt);
|
||||||
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
|
s = state;
|
||||||
s[i] = state;
|
|
||||||
|
|
||||||
// Parallel sum: This relies on the fact that this kernel will be
|
// Parallel sum: This relies on the fact that this kernel will be
|
||||||
// dispatched with each threadgroup having (d_state, 1, 1) threads which
|
// dispatched with each threadgroup having (d_state, 1, 1) threads which
|
||||||
|
|
@ -1851,6 +1853,9 @@ kernel void kernel_ssm_scan_f32_group(
|
||||||
// recurse
|
// recurse
|
||||||
s0 = s;
|
s0 = s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assign the final state to the output buffer
|
||||||
|
s_buff[i] = s;
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_rwkv_wkv6_f32(
|
kernel void kernel_rwkv_wkv6_f32(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue