metal: SSM kernel improvements (#17876)
* feat: Add a batched version of ssm_conv This was done using Claude Code. It found a number of optimizations around how the threads were organized, resulting in a huge performance boost! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Optimized SSM_SCAN kernel for metal This used Claude Code and resulted in a modest performance improvement while maintaining correctness. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * test: Add test-backend-ops perf tests for SSM_CONV Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * test: Real representitive tests for SSM_CONV Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * refactor: Use function constant for ssm_conv batch size Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * test: backend op tests for ssm_scan from granite4 1b-h Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * style: remove commented out templates Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: float4 version of ssm_conv_batched Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Add missing ggml_metal_cv_free Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
b63509262a
commit
086a63e3a5
|
|
@ -411,6 +411,38 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
|
||||||
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
const char * suffix = "";
|
||||||
|
if (op->src[1]->ne[0] % 4 == 0) {
|
||||||
|
suffix = "_4";
|
||||||
|
}
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
|
||||||
|
snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (!res.pipeline) {
|
||||||
|
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||||
|
|
||||||
|
ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||||
|
|
||||||
|
ggml_metal_cv_free(cv);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
|
|
||||||
|
|
@ -427,7 +459,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
|
||||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
res.smem = 32*sizeof(float)*nsg;
|
// Shared memory layout:
|
||||||
|
// - sgptg * NW floats for partial sums (nsg * 32)
|
||||||
|
// - sgptg floats for shared_x_dt (nsg)
|
||||||
|
// - sgptg floats for shared_dA (nsg)
|
||||||
|
// Total: nsg * (32 + 2) floats
|
||||||
|
res.smem = (32 + 2)*sizeof(float)*nsg;
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@
|
||||||
#define FC_MUL_MV 600
|
#define FC_MUL_MV 600
|
||||||
#define FC_MUL_MM 700
|
#define FC_MUL_MM 700
|
||||||
#define FC_ROPE 800
|
#define FC_ROPE 800
|
||||||
|
#define FC_SSM_CONV 900
|
||||||
|
|
||||||
// op-specific constants
|
// op-specific constants
|
||||||
#define OP_FLASH_ATTN_EXT_NQPTG 8
|
#define OP_FLASH_ATTN_EXT_NQPTG 8
|
||||||
|
|
|
||||||
|
|
@ -1365,15 +1365,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||||
/*.nb2 =*/ nb2,
|
/*.nb2 =*/ nb2,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
|
||||||
|
const bool use_batched = (ne1 > 1);
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
if (use_batched) {
|
||||||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
// Determine the smallest power of 2 that's >= ne1, but <= 256
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
int BATCH_SIZE;
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
if (ne1 > 128) BATCH_SIZE = 256;
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
else if (ne1 > 64 ) BATCH_SIZE = 128;
|
||||||
|
else if (ne1 > 32 ) BATCH_SIZE = 64;
|
||||||
|
else if (ne1 > 16 ) BATCH_SIZE = 32;
|
||||||
|
else if (ne1 > 8 ) BATCH_SIZE = 16;
|
||||||
|
else if (ne1 > 4 ) BATCH_SIZE = 8;
|
||||||
|
else BATCH_SIZE = 2;
|
||||||
|
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
|
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
||||||
|
|
||||||
|
// Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
|
||||||
|
// Each threadgroup has BATCH_SIZE threads, each handling one token
|
||||||
|
const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
|
||||||
|
} else {
|
||||||
|
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
|
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2343,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4(
|
||||||
x[0] = sumf;
|
x[0] = sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
|
||||||
|
|
||||||
|
// Batched version: each threadgroup processes multiple tokens for better efficiency
|
||||||
|
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
|
||||||
|
kernel void kernel_ssm_conv_f32_f32_batched(
|
||||||
|
constant ggml_metal_kargs_ssm_conv & args,
|
||||||
|
device const void * src0,
|
||||||
|
device const void * src1,
|
||||||
|
device float * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
// tgpig.x = row index (ir)
|
||||||
|
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
||||||
|
// tgpig.z = sequence index (i3)
|
||||||
|
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
||||||
|
const short BATCH_SIZE = FC_ssm_conv_bs;
|
||||||
|
|
||||||
|
const int64_t ir = tgpig.x;
|
||||||
|
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
||||||
|
const int64_t i3 = tgpig.z;
|
||||||
|
const int64_t i2_off = tpitg.x;
|
||||||
|
const int64_t i2 = i2_base + i2_off;
|
||||||
|
|
||||||
|
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
||||||
|
const int64_t n_t = args.ne1; // number of tokens
|
||||||
|
|
||||||
|
// Bounds check for partial batches at the end
|
||||||
|
if (i2 >= n_t) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load conv weights (shared across all tokens for this row)
|
||||||
|
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
||||||
|
|
||||||
|
// Load source for this specific token
|
||||||
|
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||||
|
|
||||||
|
// Output location for this token
|
||||||
|
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||||
|
|
||||||
|
float sumf = 0.0f;
|
||||||
|
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||||
|
sumf += s[i0] * c[i0];
|
||||||
|
}
|
||||||
|
|
||||||
|
x[0] = sumf;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_ssm_conv_f32_f32_batched_4(
|
||||||
|
constant ggml_metal_kargs_ssm_conv & args,
|
||||||
|
device const void * src0,
|
||||||
|
device const void * src1,
|
||||||
|
device float * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
// tgpig.x = row index (ir)
|
||||||
|
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
||||||
|
// tgpig.z = sequence index (i3)
|
||||||
|
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
||||||
|
const short BATCH_SIZE = FC_ssm_conv_bs;
|
||||||
|
|
||||||
|
const int64_t ir = tgpig.x;
|
||||||
|
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
||||||
|
const int64_t i3 = tgpig.z;
|
||||||
|
const int64_t i2_off = tpitg.x;
|
||||||
|
const int64_t i2 = i2_base + i2_off;
|
||||||
|
|
||||||
|
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
||||||
|
const int64_t n_t = args.ne1; // number of tokens
|
||||||
|
|
||||||
|
// Bounds check for partial batches at the end
|
||||||
|
if (i2 >= n_t) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load conv weights (shared across all tokens for this row)
|
||||||
|
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
||||||
|
|
||||||
|
// Load source for this specific token
|
||||||
|
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||||
|
|
||||||
|
// Output location for this token
|
||||||
|
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||||
|
|
||||||
|
float sumf = 0.0f;
|
||||||
|
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
||||||
|
sumf += dot(s[i0], c[i0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
x[0] = sumf;
|
||||||
|
}
|
||||||
|
|
||||||
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
||||||
|
// Optimized version: reduces redundant memory loads by having one thread load shared values
|
||||||
kernel void kernel_ssm_scan_f32(
|
kernel void kernel_ssm_scan_f32(
|
||||||
constant ggml_metal_kargs_ssm_scan & args,
|
constant ggml_metal_kargs_ssm_scan & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
|
|
@ -2363,7 +2458,15 @@ kernel void kernel_ssm_scan_f32(
|
||||||
uint3 tgpg[[threadgroups_per_grid]]) {
|
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||||
constexpr short NW = N_SIMDWIDTH;
|
constexpr short NW = N_SIMDWIDTH;
|
||||||
|
|
||||||
shared[tpitg.x] = 0.0f;
|
// Shared memory layout:
|
||||||
|
// [0..sgptg*NW-1]: partial sums for reduction (existing)
|
||||||
|
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
|
||||||
|
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
|
||||||
|
threadgroup float * shared_sums = shared;
|
||||||
|
threadgroup float * shared_x_dt = shared + sgptg * NW;
|
||||||
|
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
|
||||||
|
|
||||||
|
shared_sums[tpitg.x] = 0.0f;
|
||||||
|
|
||||||
const int32_t i0 = tpitg.x;
|
const int32_t i0 = tpitg.x;
|
||||||
const int32_t i1 = tgpig.x;
|
const int32_t i1 = tgpig.x;
|
||||||
|
|
@ -2403,32 +2506,47 @@ kernel void kernel_ssm_scan_f32(
|
||||||
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
|
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
|
// Pre-compute x_dt and dA for this batch of tokens
|
||||||
const float dt0 = dt[0];
|
// Only first sgptg threads do the loads and expensive math
|
||||||
|
if (i0 < sgptg && i2 + i0 < n_t) {
|
||||||
|
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
|
||||||
|
device const float * x_t = x + i0 * args.ns12;
|
||||||
|
device const float * dt_t = dt + i0 * args.ns21;
|
||||||
|
|
||||||
|
const float dt0 = dt_t[0];
|
||||||
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
|
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
|
||||||
const float x_dt = x[0] * dtsp;
|
shared_x_dt[i0] = x_t[0] * dtsp;
|
||||||
const float dA = exp(dtsp * A0);
|
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
|
||||||
|
const float x_dt = shared_x_dt[t];
|
||||||
|
const float dA = exp(shared_dA[t] * A0);
|
||||||
|
|
||||||
s = (s0 * dA) + (B[i0] * x_dt);
|
s = (s0 * dA) + (B[i0] * x_dt);
|
||||||
|
|
||||||
const float sumf = simd_sum(s * C[i0]);
|
const float sumf = simd_sum(s * C[i0]);
|
||||||
|
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
shared[t*NW + sgitg] = sumf;
|
shared_sums[t*NW + sgitg] = sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
// recurse
|
// recurse
|
||||||
s0 = s;
|
s0 = s;
|
||||||
|
|
||||||
x += args.ns12;
|
|
||||||
dt += args.ns21;
|
|
||||||
B += args.ns42;
|
B += args.ns42;
|
||||||
C += args.ns52;
|
C += args.ns52;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Advance pointers for next batch
|
||||||
|
x += sgptg * args.ns12;
|
||||||
|
dt += sgptg * args.ns21;
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
|
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
|
||||||
|
|
||||||
if (tiisg == 0 && i2 + sgitg < n_t) {
|
if (tiisg == 0 && i2 + sgitg < n_t) {
|
||||||
y[sgitg*nh*nr] = sumf;
|
y[sgitg*nh*nr] = sumf;
|
||||||
|
|
|
||||||
|
|
@ -8193,6 +8193,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
|
||||||
|
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
|
||||||
|
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
|
||||||
|
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
|
||||||
|
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
|
||||||
|
|
||||||
|
|
||||||
return test_cases;
|
return test_cases;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue