vulkan: fix MMQ shader push constants and multi-dispatch (#19732)
This commit is contained in:
parent
da348c9dfb
commit
abb9f3c42b
|
|
@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
|
|||
uint nbi1;
|
||||
uint ne11;
|
||||
#else
|
||||
uint base_work_group_z;
|
||||
uint num_batches;
|
||||
uint k_split;
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
|
|
@ -108,7 +110,7 @@ void main() {
|
|||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.z;
|
||||
const uint expert_idx = gl_WorkGroupID.z;
|
||||
if (ic * BN >= data_expert_count[expert_idx]) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -118,7 +120,7 @@ void main() {
|
|||
#endif
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
const uint batch_idx = gl_GlobalInvocationID.z;
|
||||
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
|
||||
|
||||
const uint i13 = batch_idx / p.ne12;
|
||||
const uint i12 = batch_idx % p.ne12;
|
||||
|
|
@ -276,7 +278,7 @@ void main() {
|
|||
const uint dc = ic * BN + warp_c * WN;
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
|
||||
#endif
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue