HIP: Patch failed testcase in WMMA-MMQ kernels for RDNA 4 (#17502)
* patch failed test case MUL_MAT(type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) for enabling WMMA on RDNA4 * Quick clean up on mma.cuh to add ggml_cuda_memcpy_1 back in for half2 and bfloat162
This commit is contained in:
parent
eeb5605de2
commit
3e18dba9fd
|
|
@ -437,10 +437,15 @@ namespace ggml_cuda_mma {
|
|||
xi[0] = xs[0];
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||
|
||||
} else if constexpr (std::is_same_v<T, int>) {
|
||||
if constexpr (I == 16 && J == 4) {
|
||||
int64_t * xi = (int64_t *) t.x;
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
|
||||
}else if constexpr (I == 16 && J == 8) {
|
||||
int64_t * xi = (int64_t *) t.x;
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
||||
|
|
@ -448,6 +453,10 @@ namespace ggml_cuda_mma {
|
|||
|
||||
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
||||
xi[1] = xs1[0];
|
||||
|
||||
}else{
|
||||
NO_DEVICE_CODE;
|
||||
}
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
|
|||
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
||||
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
||||
const size_t nbs_ids = mmq_x*sizeof(int);
|
||||
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
||||
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
||||
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
||||
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue