When the number of columns is large, split each row across multiple workgroups. The computation runs in three phases that communicate partial results through temporary buffers: (1) compute partial maxima; (2) reduce the partial maxima to the row maximum and compute partial sums of exp(x - max); (3) sum the partial sums and compute the scaled result.
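
To make the data flow concrete, here is a minimal CPU sketch of the three phases for a single row. It is not the shader code from this change: the function name `softmax_row_split`, the even chunking, and the use of `std::vector` for the temp buffers are all assumptions made for illustration. Each loop over `wg` stands in for one dispatch in which every workgroup handles its own slice of the row.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU model of the three-phase scheme (not the actual shader).
// `num_workgroups` is the assumed number of workgroups sharing one row.
std::vector<float> softmax_row_split(const std::vector<float>& row,
                                     std::size_t num_workgroups) {
    assert(num_workgroups > 0);
    const std::size_t n = row.size();
    if (n == 0) return {};
    const std::size_t chunk = (n + num_workgroups - 1) / num_workgroups;

    // Stand-ins for the temp buffers that carry partial results between phases.
    std::vector<float> max_partials(num_workgroups,
                                    -std::numeric_limits<float>::infinity());
    std::vector<float> sum_partials(num_workgroups, 0.0f);
    std::vector<float> out(n);

    // Phase 1: each workgroup writes the max of its own slice.
    for (std::size_t wg = 0; wg < num_workgroups; ++wg) {
        const std::size_t begin = std::min(wg * chunk, n);
        const std::size_t end = std::min(begin + chunk, n);
        for (std::size_t i = begin; i < end; ++i)
            max_partials[wg] = std::max(max_partials[wg], row[i]);
    }

    // Phase 2: each workgroup reduces the max partials to the row max,
    // then writes the partial sum of exp(x - max) over its slice.
    for (std::size_t wg = 0; wg < num_workgroups; ++wg) {
        const float row_max =
            *std::max_element(max_partials.begin(), max_partials.end());
        const std::size_t begin = std::min(wg * chunk, n);
        const std::size_t end = std::min(begin + chunk, n);
        for (std::size_t i = begin; i < end; ++i)
            sum_partials[wg] += std::exp(row[i] - row_max);
    }

    // Phase 3: each workgroup sums the partial sums and scales its slice.
    for (std::size_t wg = 0; wg < num_workgroups; ++wg) {
        const float row_max =
            *std::max_element(max_partials.begin(), max_partials.end());
        float total = 0.0f;
        for (float s : sum_partials) total += s;
        const std::size_t begin = std::min(wg * chunk, n);
        const std::size_t end = std::min(begin + chunk, n);
        for (std::size_t i = begin; i < end; ++i)
            out[i] = std::exp(row[i] - row_max) / total;
    }
    return out;
}
```

Calling `softmax_row_split(x, 1)` and `softmax_row_split(x, 3)` on the same row should agree up to floating-point rounding, since the split only changes the order in which partial results are combined. Subtracting the row maximum before exponentiating is what keeps `exp` from overflowing, which is why the maximum has to be fully reduced before any sums are taken.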