revert lds128 for wmma loading

This commit is contained in:
zhang hui 2025-11-25 17:06:40 +08:00
parent ae1c500aae
commit db9ae8b6b4
1 changed files with 2 additions and 7 deletions

View File

@ -441,13 +441,8 @@ namespace ggml_cuda_mma {
int64_t * xi = (int64_t *) t.x; int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0]; xi[0] = xs[0];
}else if constexpr (I == 16 && J == 8) { }else if constexpr (I == 16 && J == 8 && (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>)) {
int64_t * xi = (int64_t *) t.x; ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}else{ }else{
NO_DEVICE_CODE; NO_DEVICE_CODE;
} }