metal : Fix dimension constraint violation in matmul2d descriptor (#21048)
Updates Metal tensor API test probe to fix the dimension constraint violation in the matmul2d descriptor (at least one value must be a multiple of 16).
This commit is contained in:
parent
6861f6509a
commit
9bcb4eff4d
|
|
@ -690,7 +690,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
|||
" auto tB = B.slice((int)tgid.x, 0); \n"
|
||||
" \n"
|
||||
" matmul2d< \n"
|
||||
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
|
||||
" matmul2d_descriptor(16, 16, dynamic_extent), \n"
|
||||
" execution_simdgroups<4>> mm; \n"
|
||||
" \n"
|
||||
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
|
||||
|
|
@ -740,7 +740,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
|||
" auto tB = B.slice((int)tgid.x, 0); \n"
|
||||
" \n"
|
||||
" matmul2d< \n"
|
||||
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
|
||||
" matmul2d_descriptor(16, 16, dynamic_extent), \n"
|
||||
" execution_simdgroups<4>> mm; \n"
|
||||
" \n"
|
||||
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
|
||||
|
|
|
|||
Loading…
Reference in New Issue