revert lds128 for wmma loading
This commit is contained in:
parent
ae1c500aae
commit
db9ae8b6b4
|
|
@ -441,13 +441,8 @@ namespace ggml_cuda_mma {
|
|||
int64_t * xi = (int64_t *) t.x;
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
}else if constexpr (I == 16 && J == 8) {
|
||||
int64_t * xi = (int64_t *) t.x;
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
|
||||
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
||||
xi[1] = xs1[0];
|
||||
}else if constexpr (I == 16 && J == 8 && (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>)) {
|
||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||
}else{
|
||||
NO_DEVICE_CODE;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue