Merge branch 'master' into compilade/refactor-kv-cache

Francis Couture-Harpin 2025-07-03 16:03:56 -04:00
commit 4682e21c46
89 changed files with 895 additions and 4844 deletions

@@ -40,7 +40,7 @@ body:
     attributes:
       label: GGML backends
       description: Which GGML backends do you know to be affected?
-      options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
+      options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
       multiple: true
     validations:
       required: true

@@ -42,7 +42,7 @@ body:
     attributes:
       label: GGML backends
       description: Which GGML backends do you know to be affected?
-      options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
+      options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
       multiple: true
     validations:
       required: true

.github/labeler.yml
@@ -1,10 +1,4 @@
 # https://github.com/actions/labeler
-Kompute:
-  - changed-files:
-    - any-glob-to-any-file:
-      - ggml/include/ggml-kompute.h
-      - ggml/src/ggml-kompute/**
-      - README-kompute.md
 Apple Metal:
   - changed-files:
     - any-glob-to-any-file:

@@ -740,9 +740,6 @@ jobs:
           - build: 'llvm-arm64-opencl-adreno'
             arch: 'arm64'
             defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
-          # - build: 'kompute-x64'
-          #   arch: 'x64'
-          #   defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON'

     steps:
       - name: Clone
@@ -756,12 +753,6 @@
           variant: ccache
           evict-old-files: 1d

-      - name: Clone Kompute submodule
-        id: clone_kompute
-        if: ${{ matrix.build == 'kompute-x64' }}
-        run: |
-          git submodule update --init ggml/src/ggml-kompute/kompute
-
       - name: Download OpenBLAS
         id: get_openblas
         if: ${{ matrix.build == 'openblas-x64' }}
@@ -777,7 +768,7 @@
       - name: Install Vulkan SDK
         id: get_vulkan
-        if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
+        if: ${{ matrix.build == 'vulkan-x64' }}
         run: |
           curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
           & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install

.gitmodules
@@ -1,3 +0,0 @@
-[submodule "kompute"]
-    path = ggml/src/ggml-kompute/kompute
-    url = https://github.com/nomic-ai/kompute.git

@@ -120,7 +120,6 @@ endfunction()
 llama_option_depr(FATAL_ERROR LLAMA_CUBLAS GGML_CUDA)
 llama_option_depr(WARNING LLAMA_CUDA GGML_CUDA)
-llama_option_depr(WARNING LLAMA_KOMPUTE GGML_KOMPUTE)
 llama_option_depr(WARNING LLAMA_METAL GGML_METAL)
 llama_option_depr(WARNING LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY)
 llama_option_depr(WARNING LLAMA_NATIVE GGML_NATIVE)

@@ -4408,9 +4408,6 @@ class Gemma3NModel(Gemma3Model):
     ]

     def set_vocab(self):
-        with open(self.dir_model / "chat_template.jinja") as f:
-            # quick hack to make sure chat template is added
-            self.gguf_writer.add_chat_template(f.read())
         super().set_vocab()

     def set_gguf_parameters(self):

@@ -181,7 +181,6 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou
 option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
 option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
 option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
-option(GGML_KOMPUTE "ggml: use Kompute" OFF)
 option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
 option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
 option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
@@ -266,7 +265,6 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-cann.h
     include/ggml-cpp.h
     include/ggml-cuda.h
-    include/ggml-kompute.h
     include/ggml-opt.h
     include/ggml-metal.h
     include/ggml-rpc.h

@@ -1,50 +0,0 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_KOMPUTE_MAX_DEVICES 16
struct ggml_vk_device {
int index;
int type; // same as VkPhysicalDeviceType
size_t heapSize;
const char * name;
const char * vendor;
int subgroupSize;
uint64_t bufferAlignment;
uint64_t maxAlloc;
};
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
bool ggml_vk_has_vulkan(void);
bool ggml_vk_has_device(void);
struct ggml_vk_device ggml_vk_current_device(void);
//
// backend API
//
// forward declaration
typedef struct ggml_backend * ggml_backend_t;
GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
#ifdef __cplusplus
}
#endif

@@ -1986,12 +1986,13 @@ extern "C" {
     // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
     // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
     // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, ne32,      1   ] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
     // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
     //
     // broadcast:
     //   n_head % n_head_kv == 0
-    //   ne3    % ne32      == 0
+    //   n_head % ne32      == 0
+    //   ne3    % ne33      == 0
     //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
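
Editorial note: the updated broadcast rules above amount to a simple divisibility predicate on the dimensions. A minimal sketch, not part of the header; the dimension values in main() are made up for illustration:

#include <cassert>
#include <cstdint>

// Encodes the broadcast constraints from the ggml_flash_attn_ext() comment.
static bool flash_attn_ext_broadcast_ok(int64_t n_head, int64_t n_head_kv,
                                        int64_t ne3, int64_t ne32, int64_t ne33) {
    return n_head % n_head_kv == 0   // grouped-query attention: Q heads per KV head
        && n_head % ne32      == 0   // mask may be shared across groups of heads
        && ne3    % ne33      == 0;  // mask may be shared across the outermost dim
}

int main() {
    // example: 32 query heads, 8 KV heads, one mask shared by all heads and slices
    assert(flash_attn_ext_broadcast_ok(/*n_head=*/32, /*n_head_kv=*/8,
                                       /*ne3=*/4, /*ne32=*/1, /*ne33=*/1));
    return 0;
}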

@@ -365,7 +365,6 @@ ggml_add_backend(BLAS)
 ggml_add_backend(CANN)
 ggml_add_backend(CUDA)
 ggml_add_backend(HIP)
-ggml_add_backend(Kompute)
 ggml_add_backend(METAL)
 ggml_add_backend(MUSA)
 ggml_add_backend(RPC)

@@ -61,10 +61,6 @@
 #include "ggml-cann.h"
 #endif

-#ifdef GGML_USE_KOMPUTE
-#include "ggml-kompute.h"
-#endif
-
 // disable C++17 deprecation warning for std::codecvt_utf8
 #if defined(__clang__)
 #    pragma clang diagnostic push
@@ -189,9 +185,6 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_RPC
         register_backend(ggml_backend_rpc_reg());
 #endif
-#ifdef GGML_USE_KOMPUTE
-        register_backend(ggml_backend_kompute_reg());
-#endif
 #ifdef GGML_USE_CPU
         register_backend(ggml_backend_cpu_reg());
 #endif
@@ -575,7 +568,6 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("cann", silent, dir_path);
     ggml_backend_load_best("cuda", silent, dir_path);
     ggml_backend_load_best("hip", silent, dir_path);
-    ggml_backend_load_best("kompute", silent, dir_path);
     ggml_backend_load_best("metal", silent, dir_path);
     ggml_backend_load_best("rpc", silent, dir_path);
     ggml_backend_load_best("sycl", silent, dir_path);

@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return false;
                 }
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY: {
             ggml_tensor *src = op->src[0];
             if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||

@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }

-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;

        // k indices
        const int ik3 = iq3 / rk3;
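
Editorial note: the new pointer computation broadcasts the mask along dims 2 and 3 by reducing each index modulo the mask's extent in that dimension, so a mask with extent 1 is reused for every head and every outermost slice. A standalone sketch of the same offset arithmetic, assuming ggml-style ne (extents) and nb (byte strides) arrays; the helper name is illustrative only:

#include <cstddef>
#include <cstdint>

// Byte offset of mask element (i1, i2, i3) with broadcast along dims 2 and 3.
static inline size_t mask_byte_offset(int64_t i1, int64_t i2, int64_t i3,
                                      const int64_t ne[4], const size_t nb[4]) {
    return (size_t) i1            * nb[1]
         + (size_t)(i2 % ne[2])   * nb[2]   // broadcast across heads
         + (size_t)(i3 % ne[3])   * nb[3];  // broadcast across the outermost dim
}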

@@ -175,6 +175,20 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif

+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+    do { \
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
+        const int id = ggml_cuda_get_device(); \
+        if (!shared_memory_limit_raised[id]) { \
+            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
+            shared_memory_limit_raised[id] = true; \
+        } \
+    } while (0)
+#else
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+
 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else
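
Editorial note: CUDA_SET_SHARED_MEMORY_LIMIT centralizes the one-time, per-device cudaFuncSetAttribute call that the hunks below previously open-coded. A hedged usage sketch; scale_rows and its launch configuration are hypothetical and the snippet assumes the surrounding common.cuh definitions (CUDA_CHECK, GGML_CUDA_MAX_DEVICES, ggml_cuda_get_device). Only the macro itself comes from this commit:

// The extra parentheses keep the comma inside the template-id from splitting
// the macro argument, matching how the real call sites invoke the macro.
template <typename T, bool use_smem>
static __global__ void scale_rows(T * x, int n) {
    extern __shared__ unsigned char smem[]; // dynamic shared memory, unused in this sketch
    (void) smem;
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= T(2);
    }
}

static void scale_rows_cuda(float * x, int n, size_t nbytes_shared, cudaStream_t stream) {
    // one-time, per-device opt-in to more dynamic shared memory than the default limit
    CUDA_SET_SHARED_MEMORY_LIMIT((scale_rows<float, true>), nbytes_shared);
    scale_rows<float, true><<<(n + 255)/256, 256, nbytes_shared, stream>>>(x, n);
}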

@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-        if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-            shared_memory_limit_raised[id] = true;
-        }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
         cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
     } else {
         cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-        if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-            shared_memory_limit_raised[id] = true;
-        }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
         cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
     } else {
         cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);

@@ -3390,7 +3390,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
             // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }

@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);

-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-    static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-    if (!shared_memory_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        shared_memory_limit_raised[id] = true;
-    }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);

     const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
     const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;

@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>

 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
     }
 }

+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+                                    const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    //default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
 template<typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
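
Editorial note: the (launch_kernel(std::integral_constant<int, Ns>{}) || ...) expression above is a C++17 unary fold over || that tries each compile-time column count in turn and stops at the first match, replacing the deleted switch. A minimal host-only sketch of the same dispatch pattern; process, dispatch and the candidate sizes are illustrative, not from the diff:

#include <cstdio>
#include <type_traits>

// Stand-in for a kernel that needs its size as a compile-time constant.
template <int NCOLS>
static void process(const float * data) {
    std::printf("compile-time NCOLS = %d, first value = %f\n", NCOLS, data[0]);
}

// Try each candidate N in order; return true as soon as one matches the runtime value.
template <int... Ns>
static bool dispatch(int ncols_runtime, const float * data) {
    auto try_one = [&](auto I) -> bool {
        constexpr int ncols = decltype(I)::value;
        if (ncols_runtime == ncols) {
            process<ncols>(data);
            return true;
        }
        return false;
    };
    // unary fold over ||: stops at the first candidate that returns true
    return (try_one(std::integral_constant<int, Ns>{}) || ...);
}

int main() {
    const float data[1] = {1.0f};
    if (!dispatch<32, 64, 128, 256>(64, data)) {
        process<0>(data); // fallback, mirroring the soft_max default case
    }
    return 0;
}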

@@ -1,166 +0,0 @@
find_package(Vulkan COMPONENTS glslc REQUIRED)
find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
if (NOT glslc_executable)
message(FATAL_ERROR "glslc not found")
endif()
ggml_add_backend_library(ggml-kompute
ggml-kompute.cpp
../../include/ggml-kompute.h
)
target_link_libraries(ggml-kompute PRIVATE ggml-base kompute)
target_include_directories(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
function(compile_shader)
set(options)
set(oneValueArgs)
set(multiValueArgs SOURCES)
cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
foreach(source ${compile_shader_SOURCES})
get_filename_component(filename ${source} NAME)
set(spv_file ${filename}.spv)
add_custom_command(
OUTPUT ${spv_file}
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
COMMENT "Compiling ${source} to ${spv_file}"
)
get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
set(FILE_NAME "shader${RAW_FILE_NAME}")
string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
if(CMAKE_GENERATOR MATCHES "Visual Studio")
add_custom_command(
OUTPUT ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
DEPENDS ${spv_file} xxd
COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
)
else()
add_custom_command(
OUTPUT ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
DEPENDS ${spv_file} xxd
COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
)
endif()
endforeach()
endfunction()
if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
message(STATUS "Kompute found")
set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
add_subdirectory(kompute)
# Compile our shaders
compile_shader(SOURCES
kompute-shaders/op_scale.comp
kompute-shaders/op_scale_8.comp
kompute-shaders/op_add.comp
kompute-shaders/op_addrow.comp
kompute-shaders/op_mul.comp
kompute-shaders/op_silu.comp
kompute-shaders/op_relu.comp
kompute-shaders/op_gelu.comp
kompute-shaders/op_softmax.comp
kompute-shaders/op_norm.comp
kompute-shaders/op_rmsnorm.comp
kompute-shaders/op_diagmask.comp
kompute-shaders/op_mul_mat_mat_f32.comp
kompute-shaders/op_mul_mat_f16.comp
kompute-shaders/op_mul_mat_q8_0.comp
kompute-shaders/op_mul_mat_q4_0.comp
kompute-shaders/op_mul_mat_q4_1.comp
kompute-shaders/op_mul_mat_q4_k.comp
kompute-shaders/op_mul_mat_q6_k.comp
kompute-shaders/op_getrows_f32.comp
kompute-shaders/op_getrows_f16.comp
kompute-shaders/op_getrows_q4_0.comp
kompute-shaders/op_getrows_q4_1.comp
kompute-shaders/op_getrows_q6_k.comp
kompute-shaders/op_rope_norm_f16.comp
kompute-shaders/op_rope_norm_f32.comp
kompute-shaders/op_rope_neox_f16.comp
kompute-shaders/op_rope_neox_f32.comp
kompute-shaders/op_cpy_f16_f16.comp
kompute-shaders/op_cpy_f16_f32.comp
kompute-shaders/op_cpy_f32_f16.comp
kompute-shaders/op_cpy_f32_f32.comp
)
# Create a custom target for our generated shaders
add_custom_target(generated_shaders DEPENDS
shaderop_scale.h
shaderop_scale_8.h
shaderop_add.h
shaderop_addrow.h
shaderop_mul.h
shaderop_silu.h
shaderop_relu.h
shaderop_gelu.h
shaderop_softmax.h
shaderop_norm.h
shaderop_rmsnorm.h
shaderop_diagmask.h
shaderop_mul_mat_mat_f32.h
shaderop_mul_mat_f16.h
shaderop_mul_mat_q8_0.h
shaderop_mul_mat_q4_0.h
shaderop_mul_mat_q4_1.h
shaderop_mul_mat_q4_k.h
shaderop_mul_mat_q6_k.h
shaderop_getrows_f32.h
shaderop_getrows_f16.h
shaderop_getrows_q4_0.h
shaderop_getrows_q4_1.h
shaderop_getrows_q6_k.h
shaderop_rope_norm_f16.h
shaderop_rope_norm_f32.h
shaderop_rope_neox_f16.h
shaderop_rope_neox_f32.h
shaderop_cpy_f16_f16.h
shaderop_cpy_f16_f32.h
shaderop_cpy_f32_f16.h
shaderop_cpy_f32_f32.h
)
# Create a custom command that depends on the generated_shaders
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
DEPENDS generated_shaders
COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
)
# Add the stamp to the main sources to ensure dependency tracking
target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
else()
message(WARNING "Kompute not found")
endif()

File diff suppressed because it is too large.

@@ -1 +0,0 @@
Subproject commit 4565194ed7c32d1d2efa32ceab4d3c6cae006306

@@ -1,112 +0,0 @@
#extension GL_EXT_shader_16bit_storage: require
#extension GL_EXT_shader_8bit_storage: require
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
#extension GL_EXT_control_flow_attributes: enable
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable
#define QK4_0 32
#define QK4_1 32
#define GELU_COEF_A 0.044715
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
#define TWOPI_F 6.283185307179586f
#define QK_K 256
#define K_SCALE_SIZE 12
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
#define sizeof_block_q4_0 0x12
struct block_q4_0 {
float16_t d;
uint8_t qs[QK4_0 / 2];
};
mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
const float d2 = d1 / 256.f;
const float md = -8.f * xb.d;
const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
const uint16_t mask1 = mask0 << 8;
mat4 reg;
for (int i=0;i<8;i++) {
uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
}
return reg;
}
#define sizeof_block_q4_1 0x14
struct block_q4_1 {
float16_t d;
float16_t m;
uint8_t qs[QK4_1 / 2];
};
mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
const float d2 = d1 / 256.f;
const float m = xb.m;
const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
const uint16_t mask1 = mask0 << 8;
mat4 reg;
for (int i=0;i<8;i++) {
uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
}
return reg;
}
#define sizeof_block_q4_k 144
struct block_q4_k {
float16_t d;
float16_t dmin;
uint8_t scales[K_SCALE_SIZE];
uint8_t qs[QK_K/2];
};
#define sizeof_block_q6_k 210
struct block_q6_k {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
float16_t d; // super-block scale
};
mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
const float16_t d_all = xb.d;
const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
const uint qhIndex = 32*(il/8) + 16*(il&1);
float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F);
const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f);
const float16_t ml = float16_t(d_all * sc * 32.f);
const float16_t dl = float16_t(d_all * sc * coef);
mat4 reg;
for (int i = 0; i < 16; ++i) {
const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
: ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
reg[i/4][i%4] = dl * q - ml;
}
return reg;
}
#define QK8_0 32
// struct block_q8_0 {
// float16_t d; // delta
// int8_t qs[QK8_0]; // quants
// };
#define sizeof_block_q8_0 34

@@ -1,58 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1024) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb00;
int nb01;
int nb02;
int nb03;
int ne10;
int ne11;
int ne12;
int ne13;
int nb10;
int nb11;
int nb12;
int nb13;
int ne0;
int nb0;
int nb1;
int nb2;
int nb3;
//int offs; // TODO: needed for GGML_OP_ACC, see metal code
} pcs;
// general-purpose kernel for addition of two tensors
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
// cons: not very efficient
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const uint i13 = i03 % pcs.ne13;
const uint i12 = i02 % pcs.ne12;
const uint i11 = i01 % pcs.ne11;
int offs = 0; // TMP (see above)
uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4);
uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4);
for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
const uint i10 = i0 % pcs.ne10;
out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
}
}

@@ -1,25 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
uint row;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
}
}

@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float16_t
#define IN_TYPE_SIZE 2
#define OUT_TYPE float16_t
#define OUT_TYPE_SIZE 2
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}

@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float16_t
#define IN_TYPE_SIZE 2
#define OUT_TYPE float
#define OUT_TYPE_SIZE 4
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}

@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float
#define IN_TYPE_SIZE 4
#define OUT_TYPE float16_t
#define OUT_TYPE_SIZE 2
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}

@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float
#define IN_TYPE_SIZE 4
#define OUT_TYPE float
#define OUT_TYPE_SIZE 4
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}

@@ -1,30 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
uint n_past;
int ne00;
int ne01;
} pcs;
void main() {
const uint i02 = gl_WorkGroupID.z;
const uint i01 = gl_WorkGroupID.y;
const uint i00 = gl_WorkGroupID.x;
const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
if (i00 > pcs.n_past + i01) {
out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
} else {
out_[index + pcs.outOff] = in_[index + pcs.inOff];
}
}

@@ -1,22 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
}
}

@@ -1,17 +0,0 @@
void main() {
const uint i = gl_WorkGroupID.x;
const int r = inB[i + pcs.inBOff];
int z = 0;
for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
const mat4 result = dequantize_block(inIndex, ind%NL);
for (uint j = 0; j < 4; ++j) {
for (uint k = 0; k < 4; ++k) {
const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
out_[outIndex] = result[j][k];
++z;
}
}
}
}

@@ -1,31 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
for (int j = 0; j < k; j++) {
out_[y + j] = inA[x + j];
}
}
void main() {
const uint i = gl_WorkGroupID.x;
const int r = inB[i + pcs.inBOff];
dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
}

@@ -1,31 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { float inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
for (int j = 0; j < k; j++) {
out_[y + j] = inA[x + j];
}
}
void main() {
const uint i = gl_WorkGroupID.x;
const int r = inB[i + pcs.inBOff];
dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
}

@@ -1,38 +0,0 @@
#version 450
#include "common.comp"
#define NL 2
#define BYTES_FOR_TYPE 4 /*bytes for float*/
#define SIZE_OF_BLOCK sizeof_block_q4_0
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
block_q4_0 get_unaligned_block_q4_0(uint index) {
block_q4_0 fres;
fres.d = u8BufToFloat16(inA, index);
[[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
fres.qs[it] = inA[index+2+it];
}
return fres;
}
mat4 dequantize_block(uint index, uint il) {
const block_q4_0 block = get_unaligned_block_q4_0(index);
return dequantize_q4_0(block, il);
}
#include "op_getrows.comp"

@@ -1,39 +0,0 @@
#version 450
#include "common.comp"
#define NL 2
#define BYTES_FOR_TYPE 4 /*bytes for float*/
#define SIZE_OF_BLOCK sizeof_block_q4_1
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
block_q4_1 get_unaligned_block_q4_1(uint index) {
block_q4_1 fres;
fres.d = u8BufToFloat16(inA, index);
fres.m = u8BufToFloat16(inA, index+2);
[[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
fres.qs[it] = inA[index+4+it];
}
return fres;
}
mat4 dequantize_block(uint index, uint il) {
const block_q4_1 block = get_unaligned_block_q4_1(index);
return dequantize_q4_1(block, il);
}
#include "op_getrows.comp"

@@ -1,44 +0,0 @@
#version 450
#include "common.comp"
#define NL 16
#define BYTES_FOR_TYPE 4 /*bytes for float*/
#define SIZE_OF_BLOCK sizeof_block_q6_k
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
block_q6_k get_unaligned_block_q6_k(uint index) {
block_q6_k fres;
[[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
fres.ql[it] = inA[index + it];
}
[[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
fres.qh[it] = inA[index + QK_K/2 + it];
}
[[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
}
fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
return fres;
}
mat4 dequantize_block(uint index, uint il) {
const block_q6_k block = get_unaligned_block_q6_k(index);
return dequantize_q6_k(block, il);
}
#include "op_getrows.comp"

@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1024) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb00;
int nb01;
int nb02;
int nb03;
int ne10;
int ne11;
int ne12;
int ne13;
int nb10;
int nb11;
int nb12;
int nb13;
int ne0;
int nb0;
int nb1;
int nb2;
int nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const uint i13 = i03 % pcs.ne13;
const uint i12 = i02 % pcs.ne12;
const uint i11 = i01 % pcs.ne11;
uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1) / 4);
for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
const uint i10 = i0 % pcs.ne10;
out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
}
}

@@ -1,69 +0,0 @@
#version 450
#include "common.comp"
#extension GL_KHR_shader_subgroup_arithmetic : require
layout(local_size_x_id = 0) in;
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne10;
int ne11;
int ne12;
uint nb10;
uint nb11;
uint nb12;
uint nb13;
int ne0;
int ne1;
uint r2;
uint r3;
} pcs;
#define N_F16_F32 4
void main() {
const uint r0 = gl_WorkGroupID.x;
const uint rb = gl_WorkGroupID.y*N_F16_F32;
const uint im = gl_WorkGroupID.z;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
for (uint row = 0; row < N_F16_F32; ++row) {
uint r1 = rb + row;
if (r1 >= pcs.ne11) {
break;
}
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf = 0;
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
sumf += float(inA[x+i]) * float(inB[y+i]);
}
const float all_sum = subgroupAdd(sumf);
if (subgroupElect()) {
out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
}
}
}

@@ -1,51 +0,0 @@
#version 450
#include "common.comp"
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable
// device subgroup size
layout (local_size_x_id = 0) in;
layout(binding = 0) readonly buffer tensorInA { float inA[]; };
layout(binding = 1) readonly buffer tensorInB { float inB[]; };
layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
layout(push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
int ne11;
int ne12;
uint nb01;
uint nb02;
uint nb11;
uint nb12;
uint nb1;
uint nb2;
}
pcs;
void main() {
uvec3 gid = gl_WorkGroupID;
uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
float sum = 0.0f;
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
sum += float(inA[x+i]) * float(inB[y+i]);
}
const float all_sum = subgroupAdd(sum);
if (subgroupElect()) {
out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
}
}

@@ -1,33 +0,0 @@
#version 450
#include "common.comp"
#define BLOCKS_IN_QUANT QK4_0
#define SIZE_OF_BLOCK sizeof_block_q4_0
#define N_ROWS 4
#include "op_mul_mv_q_n_pre.comp"
// The q4_0 version of this function
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
vec2 acc = vec2(0.0, 0.0);
const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
float d = float(u8BufToFloat16(inA, index));
float sumy = 0.0f;
for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
const float yl0 = inB[yb + i];
const float yl1 = inB[yb + i + 1];
const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
sumy += yl0 + yl1 + yl8 + yl9;
acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
}
return d * (sumy * -8.f + acc[0] + acc[1]);
}
#include "op_mul_mv_q_n.comp"

@@ -1,35 +0,0 @@
#version 450
#include "common.comp"
#define BLOCKS_IN_QUANT QK4_1
#define SIZE_OF_BLOCK sizeof_block_q4_1
#define N_ROWS 4
#include "op_mul_mv_q_n_pre.comp"
// The q4_1 version of this function
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
vec2 acc = vec2(0.0, 0.0);
const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
float d = float(u8BufToFloat16(inA, index));
float m = float(u8BufToFloat16(inA, index+2));
float sumy = 0.0f;
for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
const uint16_t b = u8BufToU16(inA, index + 4 + il + i);
const float yl0 = inB[yb + i];
const float yl1 = inB[yb + i + 1];
const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
sumy += yl0 + yl1 + yl8 + yl9;
acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
}
return d * (acc[0] + acc[1]) + sumy * m;
}
#include "op_mul_mv_q_n.comp"

@@ -1,140 +0,0 @@
#version 450
#include "common.comp"
#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k
layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne10;
int ne0;
int ne1;
int ne01;
int ne02;
int ne12;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
void main() {
const uint16_t kmask1 = uint16_t(0x3f3f);
const uint16_t kmask2 = uint16_t(0x0f0f);
const uint16_t kmask3 = uint16_t(0xc0c0);
const uint ix = gl_SubgroupInvocationID/8; // 0...3
const uint it = gl_SubgroupInvocationID%8; // 0...7
const uint iq = it/4; // 0 or 1
const uint ir = it%4; // 0...3
const uint nb = pcs.ne00/QK_K;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = r0 * N_DST;
const uint ib_row = first_row * nb;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13;
const uint xblk = offset0 + pcs.inAOff;
const uint y = (offset1 / 4) + pcs.inBOff;
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
float all_sum = 0.f;
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
for (uint ib = ix; ib < nb; ib += 4) {
const uint blk_idx = ib + xblk;
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
}
for (int row = 0; row < N_DST; row++) {
uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
uint16_t sc16[4];
sc16[0] = sc_0 & kmask1;
sc16[1] = sc_2 & kmask1;
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
acc1[0] += yl[i+0] * (q1 & 0x000F);
acc1[1] += yl[i+1] * (q1 & 0x0F00);
acc1[2] += yl[i+8] * (q1 & 0x00F0);
acc1[3] += yl[i+9] * (q1 & 0xF000);
acc2[0] += yh[i+0] * (q2 & 0x000F);
acc2[1] += yh[i+1] * (q2 & 0x0F00);
acc2[2] += yh[i+8] * (q2 & 0x00F0);
acc2[3] += yh[i+9] * (q2 & 0xF000);
}
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
float dall = float(inA[blk_idx + row_idx].d);
float dmin = float(inA[blk_idx + row_idx].dmin);
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
}
y4 += 4 * QK_K;
}
for (int row = 0; row < N_DST; ++row) {
all_sum = subgroupAdd(sumf[row]);
if (subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
}
}
}

@@ -1,106 +0,0 @@
#version 450
#include "common.comp"
#define SIZE_OF_BLOCK sizeof_block_q6_k
layout(local_size_x_id = 0) in;
layout(local_size_y_id = 1) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne10;
int ne0;
int ne1;
int ne01;
int ne02;
int ne12;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
void main() {
const uint8_t kmask1 = uint8_t(0x03);
const uint8_t kmask2 = uint8_t(0x0C);
const uint8_t kmask3 = uint8_t(0x30);
const uint8_t kmask4 = uint8_t(0xC0);
const uint nb = pcs.ne00/QK_K;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf = 0;
// bits of invocation ID for gl_SubgroupSize=32:
// x x x x x
// 4 3 2 1 0
// ( tid ) ix
// ip ( il )
const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes
const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
const uint ip = tid/8; // first or second half of block (0 or 1)
const uint il = tid%8; // each half has 8 parts, one per scale
const uint n = 4; // 4 scales at a time (and 4 sums)
const uint l0 = n*il; // offset into half-block, 0..28
const uint is = 8*ip + l0/16; // 0, 1, 8, 9
const uint y_offset = 128*ip + l0;
const uint q_offset_l = 64*ip + l0;
const uint q_offset_h = 32*ip + l0;
for (uint i = ix; i < nb; i += block_stride) {
const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
const uint qlIndex = q_offset_l;
const uint q2Index = qlIndex + QK_K/8;
const uint qhIndex = q_offset_h;
const uint y = yy + i * QK_K + y_offset;
float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
for (uint l = 0; l < n; ++l) {
const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32);
sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32);
}
float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
}
const float tot = subgroupAdd(sumf);
if (subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
}
}
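
For readers following the comment block about the invocation-ID bit layout in the q6_K shader above, here is a small host-side C++ sketch that reproduces the per-lane index decomposition, assuming gl_SubgroupSize == 32; the program and its printed names are illustrative only, not part of any backend.

// Host-side sketch of the per-lane index decomposition used by the q6_K
// mul_mv shader above, assuming a subgroup size of 32 (illustrative only).
#include <cstdio>

int main() {
    const unsigned subgroup_size = 32;
    const unsigned block_stride  = subgroup_size / 16; // blocks processed per subgroup iteration (= 2)
    for (unsigned lane = 0; lane < subgroup_size; ++lane) {
        const unsigned tid = lane / block_stride;  // 0..15, position inside a block
        const unsigned ix  = lane % block_stride;  // which of the two blocks this lane starts on
        const unsigned ip  = tid / 8;              // first or second half of the block
        const unsigned il  = tid % 8;              // 8 slices per half, one per scale group
        const unsigned n   = 4;                    // 4 scales (and 4 partial sums) at a time
        const unsigned l0  = n * il;               // 0..28, offset into the half-block
        const unsigned is  = 8 * ip + l0 / 16;     // scale index: 0, 1, 8 or 9
        printf("lane=%2u tid=%2u ix=%u ip=%u il=%u is=%u y_off=%3u ql_off=%2u qh_off=%2u\n",
               lane, tid, ix, ip, il, is, 128 * ip + l0, 64 * ip + l0, 32 * ip + l0);
    }
    return 0;
}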

View File

@@ -1,73 +0,0 @@
#version 450
#include "common.comp"
#include "op_mul_mv_q_n_pre.comp"
#define SIZE_OF_D 2
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
#define NB_Q8_0 8
void main() {
// NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
if (gl_SubgroupInvocationID > 31)
return;
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
const int nb = pcs.ne00/QK8_0;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB
float yl[NB_Q8_0];
float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};
const uint ix = gl_SubgroupInvocationID.x/4;
const uint il = gl_SubgroupInvocationID.x%4;
uint yb = y + ix * QK8_0 + NB_Q8_0*il;
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
for (uint ib = ix; ib < nb; ib += nw/4) {
for (int i = 0; i < NB_Q8_0; ++i) {
yl[i] = inB[yb + i];
}
for (int row = 0; row < nr; row++) {
const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
sumq += qs_iq * yl[iq];
}
const float16_t d = u8BufToFloat16(inA, x + block_offset);
sumf[row] += sumq*d;
}
yb += NB_Q8_0 * nw;
}
for (int row = 0; row < nr; ++row) {
const float tot = subgroupAdd(sumf[row]);
if (subgroupElect() && first_row + row < pcs.ne01) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
}
}
}
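
As a plain scalar reference for the q8_0 shader above (a sketch only): it assumes the usual ggml q8_0 block layout of a 2-byte fp16 scale followed by QK8_0 = 32 signed 8-bit quants, which is why the shader uses SIZE_OF_D == 2. The scale is kept as a plain float here so the sketch stays self-contained.

// Scalar reference for the per-row dot product the q8_0 shader above parallelizes.
#include <cstdint>
#include <cstddef>

constexpr int QK8_0 = 32;

struct block_q8_0_ref {      // illustrative stand-in, not the real ggml struct
    float  d;                // per-block scale (stored as fp16 in the real layout)
    int8_t qs[QK8_0];        // quantized values
};

float dot_q8_0_row(const block_q8_0_ref * x, const float * y, size_t nblocks) {
    float sum = 0.0f;
    for (size_t ib = 0; ib < nblocks; ++ib) {
        float partial = 0.0f;
        for (int i = 0; i < QK8_0; ++i) {
            partial += float(x[ib].qs[i]) * y[ib * QK8_0 + i];
        }
        sum += x[ib].d * partial;   // apply the block scale once, as the shader does
    }
    return sum;
}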

View File

@@ -1,52 +0,0 @@
void main() {
// NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
if (gl_SubgroupInvocationID > 31)
return;
const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
// pointers to src0 rows
uint ax[N_ROWS];
for (int row = 0; row < N_ROWS; ++row) {
const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
ax[row] = offset0 + pcs.inAOff;
}
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
const uint ix = gl_SubgroupInvocationID/2;
const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);
uint yb = y + ix * BLOCKS_IN_QUANT + il;
//debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
// gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
// gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
for (uint ib = ix; ib < nb; ib += 16) {
for (int row = 0; row < N_ROWS; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
}
yb += BLOCKS_IN_QUANT * 16;
}
for (int row = 0; row < N_ROWS; ++row) {
const float tot = subgroupAdd(sumf[row]);
if (first_row + row < pcs.ne01 && subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
}
}
}

View File

@@ -1,28 +0,0 @@
layout(local_size_x_id = 0) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
int ne10;
int ne12;
int ne0;
int ne1;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;

View File

@@ -1,84 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 256) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
uint ne00;
uint nb01;
float eps;
} pcs;
shared float sum[gl_WorkGroupSize.x];
void main() {
const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
// MEAN
// parallel sum
sum[gl_LocalInvocationID.x] = 0.0;
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
sum[gl_LocalInvocationID.x] += in_[x+i00];
}
// reduce
barrier();
memoryBarrierShared();
[[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
}
barrier();
memoryBarrierShared();
}
// broadcast
if (gl_LocalInvocationID.x == 0) {
sum[0] /= float(pcs.ne00);
}
barrier();
memoryBarrierShared();
const float mean = sum[0];
// recenter
const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
out_[y+i00] = in_[x+i00] - mean;
}
// VARIANCE
// parallel sum
sum[gl_LocalInvocationID.x] = 0.0;
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
}
// reduce
barrier();
memoryBarrierShared();
[[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
}
barrier();
memoryBarrierShared();
}
// broadcast
if (gl_LocalInvocationID.x == 0) {
sum[0] /= float(pcs.ne00);
}
barrier();
memoryBarrierShared();
const float variance = sum[0];
const float scale = 1.0f/sqrt(variance + pcs.eps);
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
out_[y+i00] *= scale;
}
}
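
A scalar C++ reference for what one workgroup computes in the norm shader above (a sketch; the shader does the same mean/variance passes with a shared-memory tree reduction across gl_WorkGroupSize.x threads, with eps coming from the push constants).

#include <cmath>

void norm_row(const float * x, float * y, int n, float eps) {
    float mean = 0.0f;
    for (int i = 0; i < n; ++i) mean += x[i];
    mean /= n;

    float var = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] - mean;          // recenter, as the shader writes into out_
        var += y[i] * y[i];
    }
    var /= n;

    const float scale = 1.0f / std::sqrt(var + eps);
    for (int i = 0; i < n; ++i) y[i] *= scale;
}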

View File

@@ -1,21 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
}
}

View File

@@ -1,53 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 512) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
uint ne00;
uint nb01;
float eps;
} pcs;
shared float sum[gl_WorkGroupSize.x];
void main() {
const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
// parallel sum
sum[gl_LocalInvocationID.x] = 0.0;
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
}
// reduce
barrier();
memoryBarrierShared();
[[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
}
barrier();
memoryBarrierShared();
}
// broadcast
if (gl_LocalInvocationID.x == 0) {
sum[0] /= float(pcs.ne00);
}
barrier();
memoryBarrierShared();
const float scale = 1.0f/sqrt(sum[0] + pcs.eps);
const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
out_[y+i00] = in_[x+i00] * scale;
}
}
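
And the RMS-norm counterpart, again as a scalar C++ sketch of what the parallel reduction above computes per row: scale by 1/sqrt(mean(x^2) + eps), with no mean subtraction.

#include <cmath>

void rms_norm_row(const float * x, float * y, int n, float eps) {
    float sumsq = 0.0f;
    for (int i = 0; i < n; ++i) sumsq += x[i] * x[i];
    const float scale = 1.0f / std::sqrt(sumsq / n + eps);
    for (int i = 0; i < n; ++i) y[i] = x[i] * scale;
}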

View File

@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+pcs.n_dims/2]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View File

@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+pcs.n_dims/2];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View File

@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+1]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View File

@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+1];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}
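
Summarizing the four RoPE kernels above (a sketch of what their indexing implies, not new behaviour): within the first n_dims elements, each selected pair (x_0, x_1) is rotated as

\[
\begin{pmatrix} y_0 \\ y_1 \end{pmatrix}
=
\begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x_0 \\ x_1 \end{pmatrix},
\qquad
\theta \propto p \cdot \mathrm{freq\_base}^{-i_0/n_{\mathrm{dims}}},
\]

where p is the position read from inB, and the exact angle is further adjusted by freq_factor, freq_scale and the YaRN correction defined in rope_common.comp further below. The first two variants (NeoX-style) pair element i0 with element i0 + n_dims/2, while the last two pair adjacent elements (i0, i0 + 1); elements past n_dims are copied through unchanged.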

View File

@@ -1,19 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
float scale;
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}

View File

@@ -1,23 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
float scale;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
}

View File

@@ -1,22 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = y / (1.0 + exp(-y));
}
}

View File

@@ -1,72 +0,0 @@
// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
#version 450
#include "common.comp"
layout(local_size_x_id = 0) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
float scale;
float max_bias;
float m0;
float m1;
uint n_head_log2;
int mask;
} pcs;
void main() {
if (gl_SubgroupInvocationID > 31)
return;
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
const uint pdst = extra_off + pcs.outOff; // Based from out_
float slope = 1.0f;
// ALiBi
if (pcs.max_bias > 0.0f) {
int64_t h = i02;
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
slope = pow(base, float(exp));
}
// parallel max
float localMax = uintBitsToFloat(0xFF800000);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
}
float max_ = subgroupMax(localMax);
// parallel sum
float localSum = 0.0f;
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
localSum += exp_psrc0;
out_[pdst + i00] = exp_psrc0;
}
const float sum = subgroupAdd(localSum);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
out_[pdst + i00] /= sum;
}
}
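
As a plain-C++ reference for the soft_max shader above (a sketch only): it computes a scaled softmax per row with an optional additive mask, where the mask contribution is weighted by an ALiBi slope. The m0/m1 parameters are assumed to be precomputed on the host from max_bias and n_head_log2, as in the OpenCL changes later in this diff; all names here are illustrative.

#include <cmath>
#include <algorithm>

void soft_max_row(const float * x, const float * mask, float * y, int n,
                  float scale, float max_bias, float m0, float m1,
                  int head, int n_head_log2) {
    float slope = 1.0f;
    if (max_bias > 0.0f) {
        // ALiBi: earlier heads get slope m0^(h+1), later heads m1^(2*(h-n_head_log2)+1)
        slope = head < n_head_log2 ? std::pow(m0, float(head + 1))
                                   : std::pow(m1, float(2*(head - n_head_log2) + 1));
    }
    float maxv = -INFINITY;
    for (int i = 0; i < n; ++i) {
        maxv = std::max(maxv, x[i]*scale + (mask ? slope*mask[i] : 0.0f));
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = std::exp(x[i]*scale + (mask ? slope*mask[i] : 0.0f) - maxv);
        sum += y[i];
    }
    for (int i = 0; i < n; ++i) y[i] /= sum;
}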

View File

@@ -1,71 +0,0 @@
#include "common.comp"
#define GGML_ROPE_TYPE_NEOX 2
// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint inCOff;
uint outOff;
int n_dims;
int mode;
int n_ctx_orig;
float freq_base;
float freq_scale;
bool has_freq_factors;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
float rope_yarn_ramp(const float low, const float high, const float i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
out float cos_theta, out float sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
}
void rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
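
The two comment lines above compress a short derivation; spelled out (a sketch, writing k = i0/2 for the rotary pair index that rope_yarn_ramp works with): pair k rotates with angle \(\theta_k(p) = p \cdot \mathrm{base}^{-2k/n_{\mathrm{dims}}}\), so over a training context of n_ctx_orig positions it completes

\[
n_{\mathrm{rot}} \;=\; \frac{n_{\mathrm{ctx\_orig}}}{2\pi}\,\mathrm{base}^{-2k/n_{\mathrm{dims}}}
\]

full turns. Solving for k gives

\[
k \;=\; \frac{n_{\mathrm{dims}}}{2\ln(\mathrm{base})}\,
        \ln\!\left(\frac{n_{\mathrm{ctx\_orig}}}{2\pi\,n_{\mathrm{rot}}}\right),
\]

which is what rope_yarn_corr_factor returns; rope_yarn_corr_dims then evaluates it at beta_fast and beta_slow and clamps the resulting pair to \([0, n_{\mathrm{dims}}-1]\).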

View File

@@ -230,8 +230,10 @@ typedef struct {
uint64_t nb22; uint64_t nb22;
uint64_t nb23; uint64_t nb23;
int32_t ne32; int32_t ne32;
int32_t ne33;
uint64_t nb31; uint64_t nb31;
uint64_t nb32; uint64_t nb32;
uint64_t nb33;
int32_t ne1; int32_t ne1;
int32_t ne2; int32_t ne2;
float scale; float scale;

View File

@@ -5018,8 +5018,10 @@ static bool ggml_metal_encode_node(
/*.nb22 =*/ nb22, /*.nb22 =*/ nb22,
/*.nb23 =*/ nb23, /*.nb23 =*/ nb23,
/*.ne32 =*/ ne32, /*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31, /*.nb31 =*/ nb31,
/*.nb32 =*/ nb32, /*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1, /*.ne1 =*/ ne1,
/*.ne2 =*/ ne2, /*.ne2 =*/ ne2,
/*.scale =*/ scale, /*.scale =*/ scale,

View File

@@ -3857,7 +3857,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory // load the mask in shared memory
#pragma unroll(Q) #pragma unroll(Q)
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32); device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
const float m = pm[ic + tiisg]; const float m = pm[ic + tiisg];
@@ -4343,7 +4343,7 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q; const bool has_mask = mask != q;
// pointer to the mask // pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32); device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
float slope = 1.0f; float slope = 1.0f;

View File

@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
default: default:
return false; return false;
} }
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_CONT: case GGML_OP_CONT:
@@ -5757,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
const int ne00 = src0 ? src0->ne[0] : 0; const int ne00 = src0->ne[0];
const int ne01 = src0 ? src0->ne[1] : 0; const int ne01 = src0->ne[1];
const int ne02 = src0 ? src0->ne[2] : 0; const int ne02 = src0->ne[2];
const int ne03 = src0 ? src0->ne[3] : 0; const int ne03 = src0->ne[3];
const cl_long nb01 = src0->nb[1];
const cl_long nb02 = src0->nb[2];
const cl_long nb03 = src0->nb[3];
const int ne12 = src1 ? src1->ne[2] : 0;
const int ne13 = src1 ? src1->ne[3] : 0;
const cl_long nb11 = src1 ? src1->nb[1] : 0;
const cl_long nb12 = src1 ? src1->nb[2] : 0;
const cl_long nb13 = src1 ? src1->nb[3] : 0;
const cl_long nb1 = dst->nb[1];
const cl_long nb2 = dst->nb[2];
const cl_long nb3 = dst->nb[3];
float scale, max_bias; float scale, max_bias;
memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float));
const int nrows_x = ggml_nrows(src0); const int n_head = src0->ne[2];
const int nrows_y = src0->ne[1];
const int n_head = nrows_x/nrows_y;
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -5814,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale)); CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias)); CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0)); CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1}; size_t local_work_size[] = {(size_t)nth, 1, 1};

View File

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max_4_f16( kernel void kernel_soft_max_4_f16(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global half * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float *)((global char *)src0 + offset0); src0 = src0 + offset0;
src1 = (global half *)((global char *)src1 + offset1); src1 = src1 + offset1;
dst = (global float *)((global char *)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i13 = i03%ne13;
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; int i12 = i02%ne12;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i11 = i01;
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;

View File

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max_4( kernel void kernel_soft_max_4(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global float * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float*)((global char*)src0 + offset0); src0 = src0 + offset0;
src1 = (global float*)((global char*)src1 + offset1); src1 = src1 + offset1;
dst = (global float*)((global char*)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i13 = i03%ne13;
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; int i12 = i02%ne12;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i11 = i01;
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;

View File

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max_f16( kernel void kernel_soft_max_f16(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global half * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float *)((global char *)src0 + offset0); src0 = src0 + offset0;
src1 = (global half *)((global char *)src1 + offset1); src1 = src1 + offset1;
dst = (global float *)((global char *)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i13 = i03%ne13;
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; int i12 = i02%ne12;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i11 = i01;
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;

View File

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max( kernel void kernel_soft_max(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global float * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float*)((global char*)src0 + offset0); src0 = src0 + offset0;
src1 = (global float*)((global char*)src1 + offset1); src1 = src1 + offset1;
dst = (global float*)((global char*)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i13 = i03%ne13;
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; int i12 = i02%ne12;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i11 = i01;
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;

View File

@@ -83,7 +83,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].cc = info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version(); 100 * prop.get_major_version() + 10 * prop.get_minor_version();
info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.max_work_group_sizes[i] = prop.get_max_work_group_size(); info.max_work_group_sizes[i] = prop.get_max_work_group_size();
} }
@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false; return false;
} }
} }
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY: case GGML_OP_CPY:
{ {
ggml_type src0_type = op->src[0]->type; ggml_type src0_type = op->src[0]->type;

View File

@@ -224,6 +224,21 @@ enum vk_device_architecture {
INTEL_XE2, INTEL_XE2,
}; };
// HSK x HSV
enum FaHeadSizes {
FA_HEAD_SIZE_64,
FA_HEAD_SIZE_80,
FA_HEAD_SIZE_96,
FA_HEAD_SIZE_112,
FA_HEAD_SIZE_128,
FA_HEAD_SIZE_192,
FA_HEAD_SIZE_192_128,
FA_HEAD_SIZE_256,
FA_HEAD_SIZE_576_512,
FA_HEAD_SIZE_UNSUPPORTED,
FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
};
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties(); vk::PhysicalDeviceProperties props = device.getProperties();
@@ -467,26 +482,11 @@ struct vk_device_struct {
vk_pipeline pipeline_conv2d_dw_cwhn_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_split_k_reduce; vk_pipeline pipeline_flash_attn_split_k_reduce;
@@ -1003,7 +1003,7 @@ struct ggml_backend_vk_context {
// number of additional consecutive nodes that are being fused with the // number of additional consecutive nodes that are being fused with the
// node currently being processed // node currently being processed
uint32_t num_additional_fused_ops {}; int num_additional_fused_ops {};
}; };
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -1699,6 +1699,35 @@ enum FaCodePath {
FA_COOPMAT2, FA_COOPMAT2,
}; };
static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
if (hsk != 192 && hsk != 576 && hsk != hsv) {
return FA_HEAD_SIZE_UNSUPPORTED;
}
switch (hsk) {
case 64: return FA_HEAD_SIZE_64;
case 80: return FA_HEAD_SIZE_80;
case 96: return FA_HEAD_SIZE_96;
case 112: return FA_HEAD_SIZE_112;
case 128: return FA_HEAD_SIZE_128;
case 192:
if (hsv == 192) {
return FA_HEAD_SIZE_192;
} else if (hsv == 128) {
return FA_HEAD_SIZE_192_128;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
case 256: return FA_HEAD_SIZE_256;
case 576:
if (hsv == 512) {
return FA_HEAD_SIZE_576_512;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
default: return FA_HEAD_SIZE_UNSUPPORTED;
}
}
// number of rows/cols for flash attention shader // number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
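A small usage sketch for the new head-size lookup added above (illustrative only; fa_head_sizes_supported is a hypothetical helper, while FaHeadSizes and fa_get_head_sizes are the ones introduced in this diff):

// Illustrative helper, not part of the diff: assumes the FaHeadSizes enum and
// fa_get_head_sizes() introduced above are visible in this translation unit.
static bool fa_head_sizes_supported(uint32_t hsk, uint32_t hsv) {
    // e.g. (64,64) -> FA_HEAD_SIZE_64, (192,128) -> FA_HEAD_SIZE_192_128,
    //      (576,512) -> FA_HEAD_SIZE_576_512, (100,100) -> FA_HEAD_SIZE_UNSUPPORTED.
    // A supported value also indexes the [FA_HEAD_SIZE_COUNT] dimension of the
    // pipeline_flash_attn_f32_f16* arrays selected later in ggml_vk_flash_attn().
    return fa_get_head_sizes(hsk, hsv) != FA_HEAD_SIZE_UNSUPPORTED;
}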
@@ -1719,8 +1748,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
} }
} }
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp); GGML_UNUSED(clamp);
GGML_UNUSED(hsv);
if (path == FA_SCALAR) { if (path == FA_SCALAR) {
if (small_rows) { if (small_rows) {
@@ -1744,7 +1774,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
} }
// small cols to reduce register count // small cols to reduce register count
if (ggml_is_quantized(type) || D == 256) { if (ggml_is_quantized(type) || hsk >= 256) {
return {64, 32}; return {64, 32};
} }
return {64, 64}; return {64, 64};
@@ -2037,19 +2067,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
}; };
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> { auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
}; };
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> { auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best. // For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80. // can't use 256 for D==80.
// For scalar, use 128 (arbitrary) // For scalar, use 128 (arbitrary)
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
const uint32_t D = (hsk|hsv);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size ? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128); : ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2058,26 +2090,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
}; };
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -3688,7 +3723,6 @@ static void ggml_vk_instance_init() {
} }
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -6002,24 +6036,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
} }
} }
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
// Needs to be kept up to date on shader changes // Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size; const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows; const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
const uint32_t masksh = Bc * Br * sizeof(float);
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t acctype = f32acc ? 4 : 2; const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8; const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype; const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype; const uint32_t sfsh = Bc * sfshstride * acctype;
const uint32_t kshstride = D / 4 + 2; const uint32_t kshstride = hsk / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4; const uint32_t ksh = Bc * kshstride * f16vec4;
const uint32_t slope = Br * sizeof(float); const uint32_t slope = Br * sizeof(float);
@@ -6027,7 +6084,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
return supported; return supported;
} }
@@ -6051,11 +6108,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t nem1 = mask ? mask->ne[1] : 0; const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0; const uint32_t nem2 = mask ? mask->ne[2] : 0;
const uint32_t D = neq0; const uint32_t HSK = nek0;
const uint32_t HSV = nev0;
uint32_t N = neq1; uint32_t N = neq1;
const uint32_t KV = nek1; const uint32_t KV = nek1;
GGML_ASSERT(ne0 == D); GGML_ASSERT(ne0 == HSV);
GGML_ASSERT(ne2 == N); GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous // input tensor rows must be contiguous
@@ -6063,12 +6121,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == D); GGML_ASSERT(neq0 == HSK);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq1 == N); GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(nev1 == nek1); GGML_ASSERT(nev1 == nek1);
@ -6089,7 +6144,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
if (!coopmat_shape_supported || !coopmat_shmem_supported) { if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR; path = FA_SCALAR;
@ -6142,47 +6197,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
path = FA_SCALAR; path = FA_SCALAR;
} }
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
if (path == FA_SCALAR &&
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
small_rows = true;
}
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
switch (path) { switch (path) {
case FA_SCALAR: case FA_SCALAR:
switch (D) { pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break; break;
case FA_COOPMAT1: case FA_COOPMAT1:
switch (D) { pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break; break;
case FA_COOPMAT2: case FA_COOPMAT2:
switch (D) { pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break; break;
default: default:
GGML_ASSERT(0); GGML_ASSERT(0);
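The per-head-size pipeline arrays (D64, D80, ..., D256) above are replaced by a single table indexed by a head-size pair computed from K and V. A minimal sketch of such a mapping, with hypothetical names (the real fa_get_head_sizes() is defined elsewhere in this backend and its supported set may differ):

#include <cstdio>

enum FaHeadSizesExample { FA_EX_64, FA_EX_96, FA_EX_128, FA_EX_192_128, FA_EX_UNSUPPORTED };

static FaHeadSizesExample fa_example_head_sizes(unsigned hsk, unsigned hsv) {
    if (hsk ==  64 && hsv ==  64) return FA_EX_64;
    if (hsk ==  96 && hsv ==  96) return FA_EX_96;
    if (hsk == 128 && hsv == 128) return FA_EX_128;
    if (hsk == 192 && hsv == 128) return FA_EX_192_128; // mixed K/V head sizes, e.g. MLA-style models
    return FA_EX_UNSUPPORTED;
}

int main() {
    // one pipeline family per supported (HSK, HSV) pair instead of one per D
    std::printf("(192,128) -> %d\n", (int) fa_example_head_sizes(192, 128));
}

This keeps the scalar, coopmat1 and coopmat2 cases down to a single lookup each instead of a switch per path.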
@ -6212,7 +6245,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Try to use split_k when KV is large enough to be worth the overhead // Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
// Try to run two workgroups per SM. // Try to run two workgroups per SM.
split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z); split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) { if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple // Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that. // of "align", so recompute split_k based on that.
@ -6224,7 +6257,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
if (split_k_size > ctx->device->max_memory_allocation_size) { if (split_k_size > ctx->device->max_memory_allocation_size) {
GGML_ABORT("Requested preallocation size is too large"); GGML_ABORT("Requested preallocation size is too large");
} }
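To make the size formula concrete (example values only): with HSV = 128, ne1 = 64, split_k = 2 and ne3 = 1, each split stores a 128 x 64 float O matrix (32768 bytes) plus 64 L values and 64 m values (512 bytes), so the scratch buffer is (32768 + 512) * 2 = 66560 bytes, laid out with all O matrices first and the per-row L/m pairs after them.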
@ -6342,7 +6375,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k }; const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{ {
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
@ -10241,19 +10274,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device); auto device = ggml_vk_get_device(ctx->device);
bool coopmat2 = device->coopmat2; bool coopmat2 = device->coopmat2;
switch (op->src[0]->ne[0]) { FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
case 64: if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
case 80:
case 96:
case 112:
case 128:
case 256:
break;
default:
return false;
}
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false; return false;
} }
if (op->src[0]->type != GGML_TYPE_F32) { if (op->src[0]->type != GGML_TYPE_F32) {
@ -10265,6 +10287,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false; return false;
} }
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
return false;
}
// It's straightforward to support different K/V dequant, but would // It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines // significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) { if (op->src[1]->type != op->src[2]->type) {
@ -10333,6 +10361,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false; return false;
} }
} break; } break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CONT: case GGML_OP_CONT:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_DUP: case GGML_OP_DUP:

View File

@ -11,7 +11,8 @@
#include "types.comp" #include "types.comp"
#include "flash_attn_base.comp" #include "flash_attn_base.comp"
const uint32_t D_per_thread = D / D_split; const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter; const uint32_t cols_per_thread = Bc / cols_per_iter;
@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{ {
uint32_t offset = (iq2 + r) * D + c; uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem); data_o[o_offset + offset] = D_TYPE(elem);
return elem; return elem;
} }
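For example, with HSV = 128, iq2 = 4, r = 1 and c = 20, the store lands at offset (4 + 1) * 128 + 20 = 660 floats past o_offset: the destination row is iq2 + r along Q's second dimension, and as the comment above notes only the first N such rows are valid.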
@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br]; shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4]; shared vec4 Qf[Br][HSK / 4];
void main() { void main() {
#ifdef NEEDS_INIT_IQ_SHMEM #ifdef NEEDS_INIT_IQ_SHMEM
@ -53,18 +54,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4); uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (D / 4); uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < D / 4 && if (r < Br && d < HSK / 4 &&
i * Br + r < N) { i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
} }
} }
barrier(); barrier();
vec4 Of[Br][D_per_thread / 4]; vec4 Of[Br][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0); Of[r][d] = vec4(0.0);
} }
@ -116,7 +117,7 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -195,14 +196,14 @@ void main() {
Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d]; Of[r][d] = eMf[r] * Of[r][d];
} }
} }
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -259,7 +260,7 @@ void main() {
Lf[r] = tmpsh[d_tid]; Lf[r] = tmpsh[d_tid];
barrier(); barrier();
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d]; Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d]; tmpshv4[tid] = Of[r][d];
@ -281,11 +282,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final // If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values. // division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) { if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num); uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) { if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
} }
@ -293,7 +294,7 @@ void main() {
} }
} }
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) { if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@ -309,18 +310,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = 1.0 / Lf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r]; Of[r][d] *= Lfrcp[r];
} }
} }
uint32_t o_offset = iq3*p.ne2*p.ne1*D; uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) { if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) { if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
} }
@ -330,9 +331,9 @@ void main() {
} else { } else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) { if (i * Br + r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
} }
} }
} }

View File

@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128; layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1; layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32; layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32; layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t Clamp = 0; layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t D_split = 16; layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint32_t N; uint32_t N;

View File

@ -13,7 +13,9 @@
#include "types.comp" #include "types.comp"
#include "flash_attn_base.comp" #include "flash_attn_base.comp"
const uint32_t D_per_thread = D / D_split; const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = 4; const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split; const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{ {
uint32_t offset = (iq2 + r) * D + c; uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem); data_o[o_offset + offset] = D_TYPE(elem);
return elem; return elem;
} }
@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = D / 4 + 2; // in units of f16vec4 const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride]; shared f16vec4 Qf[Br * qstride];
// Avoid padding for D==256 to make it fit in 48KB shmem. // Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride]; shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride]; shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br]; shared float slope[Br];
@ -74,18 +76,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4); uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (D / 4); uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < D / 4 && if (r < Br && d < HSK / 4 &&
i * Br + r < N) { i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
} }
} }
barrier(); barrier();
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0); Of[r][d] = ACC_TYPEV4(0.0);
} }
@ -131,10 +133,10 @@ void main() {
[[dont_unroll]] [[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) { for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4); uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (D / 4); uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < D / 4) { if (c < Bc && d < HSK / 4) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -149,14 +151,14 @@ void main() {
} }
barrier(); barrier();
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it // This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < D / 16; ++d) { for (uint32_t d = 0; d < HSK / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@ -206,7 +208,7 @@ void main() {
eMf[r] = exp(Moldf - Mf[r]); eMf[r] = exp(Moldf - Mf[r]);
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d]; Of[r][d] = float16_t(eMf[r]) * Of[r][d];
} }
@ -221,7 +223,7 @@ void main() {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r]; Lf[r] += Pf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -284,7 +286,7 @@ void main() {
} }
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d]; Of[r][d] = float16_t(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d]; tmpshv4[tid] = Of[r][d];
@ -304,11 +306,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final // If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values. // division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) { if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num); uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) { if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
} }
@ -316,7 +318,7 @@ void main() {
} }
} }
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) { if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@ -332,18 +334,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = 1.0 / Lf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]); Of[r][d] *= float16_t(Lfrcp[r]);
} }
} }
uint32_t o_offset = iq3*p.ne2*p.ne1*D; uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) { if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) { if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
} }
@ -353,9 +355,9 @@ void main() {
} else { } else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) { if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
} }
} }
} }

View File

@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{ {
if (r < N && c < D) { if (r < N && c < HSV) {
uint32_t offset = (iq2 + r) * D + c; uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem); data_o[o_offset + offset] = D_TYPE(elem);
} }
return elem; return elem;
@ -86,9 +86,9 @@ void main() {
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif #endif
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
// hint to the compiler that strides are aligned for the aligned variant of the shader // hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV) if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@ -104,16 +104,16 @@ void main() {
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q; coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16; coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q); Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale); Qf16 *= float16_t(p.scale);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
@ -140,10 +140,10 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T; coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S); S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) { if (p.logit_softcap != 0.0f) {
@ -208,42 +208,42 @@ void main() {
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum); rowsum = coopMatMulAdd(P_A, One, rowsum);
coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V; coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
L = eM*L + rowsum; L = eM*L + rowsum;
// This is the "diagonal" matrix in the paper, but since we do componentwise // This is the "diagonal" matrix in the paper, but since we do componentwise
// multiply rather than matrix multiply it has the diagonal element smeared // multiply rather than matrix multiply it has the diagonal element smeared
// across the row // across the row
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
// resize eM by using smear/reduce // resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
// multiply with fp16 accumulation, then add to O. // multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0); coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV); PV = coopMatMulAdd(P_A, V, PV);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV); O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
} }
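This block is the usual online-softmax update: when a new KV tile raises a row's running maximum from M_old to M_new, the correction factor eM = exp(M_old - M_new) rescales both the running sum (L = eM * L + rowsum(P)) and the accumulated output (O = diag(eM) * O + P * V) before the new tile's softmax numerators P = exp(S - M_new) are folded in. For instance, if a row's maximum jumps from 1.0 to 3.0, eM = exp(-2), roughly 0.135, so the previously accumulated O and L are scaled down by about 7.4x before the new tile's contribution is added.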
// If there is split_k, then the split_k resolve shader does the final // If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values. // division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) { if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O); coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num); uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return; return;
} }
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce // resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@ -255,18 +255,18 @@ void main() {
O = Ldiag*O; O = Ldiag*O;
uint32_t o_offset = iq3*p.ne2*p.ne1*D; uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O); coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) { if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else { } else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
// permute dimensions // permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
} }
} }

View File

@ -3674,7 +3674,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
if (mask) { if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_3d(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]); GGML_ASSERT(mask->ne[1] >= a->ne[1]);
GGML_ASSERT(a->ne[2]%mask->ne[2] == 0); GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@ -4704,12 +4703,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
if (mask) { if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
GGML_ASSERT(q->ne[3] % mask->ne[2] == 0); GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
} }
if (max_bias > 0.0f) { if (max_bias > 0.0f) {
@ -6051,13 +6050,28 @@ static void ggml_compute_backward(
} }
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break; } break;
case GGML_OP_GLU: {
switch (ggml_get_glu_op(tensor)) {
case GGML_GLU_OP_SWIGLU: {
if (src0_needs_grads) {
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
}
if (src1_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
}
} break;
default: {
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
} //break;
}
} break;
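For the split swiglu case the forward op is out = silu(x) * g elementwise, with x = src0 and g = src1, so the gradients built here are d/dg = silu(x) * grad (the ggml_mul of ggml_silu(src0) with grad) and d/dx = (grad * g) * silu'(x), which is what ggml_silu_back computes from grad * src1 and src0, using silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).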
case GGML_OP_NONE: { case GGML_OP_NONE: {
// noop // noop
} break; } break;
case GGML_OP_COUNT: case GGML_OP_COUNT:
default: { default: {
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
GGML_ABORT("fatal error");
} //break; } //break;
} }

View File

@ -714,8 +714,8 @@ class GGUFWriter:
def add_clamp_kqv(self, value: float) -> None: def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
def add_shared_kv_layers(self, value: float) -> None: def add_shared_kv_layers(self, value: int) -> None:
self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value) self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)

View File

@ -245,9 +245,18 @@ class SpecialVocab:
if not tokenizer_config: if not tokenizer_config:
return True return True
chat_template_alt = None chat_template_alt = None
chat_template_file = path / 'chat_template.json' chat_template_json = path / 'chat_template.json'
if chat_template_file.is_file(): chat_template_jinja = path / 'chat_template.jinja'
with open(chat_template_file, encoding = 'utf-8') as f: if chat_template_jinja.is_file():
with open(chat_template_jinja, encoding = 'utf-8') as f:
chat_template_alt = f.read()
if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
for template_path in additional_templates:
with open(template_path, encoding = 'utf-8') as fp:
chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
elif chat_template_json.is_file():
with open(chat_template_json, encoding = 'utf-8') as f:
chat_template_alt = json.load(f).get('chat_template') chat_template_alt = json.load(f).get('chat_template')
chat_template = tokenizer_config.get('chat_template', chat_template_alt) chat_template = tokenizer_config.get('chat_template', chat_template_alt)
if chat_template is None or isinstance(chat_template, (str, list)): if chat_template is None or isinstance(chat_template, (str, list)):
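In practice this gives the following precedence, for example: if a model directory ships both chat_template.jinja and chat_template.json, the .jinja file is used; any extra *.jinja files under additional_chat_templates/ are appended as named templates next to the 'default' one, each named after its file stem; and a chat_template entry in tokenizer_config.json, when present, still overrides whatever was read from those files.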

View File

@ -83,7 +83,6 @@ while read c; do
src/ggml-cpu/* \ src/ggml-cpu/* \
src/ggml-cuda/* \ src/ggml-cuda/* \
src/ggml-hip/* \ src/ggml-hip/* \
src/ggml-kompute/* \
src/ggml-metal/* \ src/ggml-metal/* \
src/ggml-musa/* \ src/ggml-musa/* \
src/ggml-opencl/* \ src/ggml-opencl/* \
@ -141,7 +140,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# src/ggml-cpu/* -> ggml/src/ggml-cpu/* # src/ggml-cpu/* -> ggml/src/ggml-cpu/*
# src/ggml-cuda/* -> ggml/src/ggml-cuda/* # src/ggml-cuda/* -> ggml/src/ggml-cuda/*
# src/ggml-hip/* -> ggml/src/ggml-hip/* # src/ggml-hip/* -> ggml/src/ggml-hip/*
# src/ggml-kompute/* -> ggml/src/ggml-kompute/*
# src/ggml-metal/* -> ggml/src/ggml-metal/* # src/ggml-metal/* -> ggml/src/ggml-metal/*
# src/ggml-musa/* -> ggml/src/ggml-musa/* # src/ggml-musa/* -> ggml/src/ggml-musa/*
# src/ggml-opencl/* -> ggml/src/ggml-opencl/* # src/ggml-opencl/* -> ggml/src/ggml-opencl/*
@ -174,7 +172,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \

View File

@ -15,7 +15,6 @@ cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/ cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/
cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/ cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/
cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/ cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/
cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/
cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/ cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/
cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/ cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/
cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/ cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/

View File

@ -281,20 +281,23 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
} }
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
} }
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
if (self_kq_mask_swa) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
} }
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(cross_kq_mask); GGML_ASSERT(cross_kq_mask);
@ -332,7 +335,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
} }
} }
void llm_graph_input_one::set_input(const llama_ubatch *) { void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
GGML_ASSERT(one && ggml_nelements(one) == 1); GGML_ASSERT(one && ggml_nelements(one) == 1);
float f_one = 1.0f; float f_one = 1.0f;
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float)); ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@ -1155,8 +1159,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
const auto n_kv = mctx_cur->get_n_kv(); const auto n_kv = mctx_cur->get_n_kv();
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask); ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1187,8 +1193,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache // store to KV cache
{ {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); const auto & k_idxs = inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
} }
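Here the destination of the cache write is no longer implicit: k_idxs and v_idxs are per-ubatch tensors of destination cell indices (filled on the host in set_input_k_idxs/set_input_v_idxs above), and cpy_k/cpy_v take them so the K/V rows of a ubatch can be scattered to arbitrary cells rather than having to land in one contiguous slot range.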
const auto & kq_mask = inp->get_kq_mask(); const auto & kq_mask = inp->get_kq_mask();
@ -1247,11 +1256,15 @@ ggml_tensor * llm_graph_context::build_attn(
// optionally store to KV cache // optionally store to KV cache
if (k_cur) { if (k_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
} }
if (v_cur) { if (v_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
} }
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@ -1343,8 +1356,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{ {
const auto n_kv = mctx_cur->get_base()->get_n_kv(); const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask); ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1355,8 +1370,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
const auto n_kv = mctx_cur->get_swa()->get_n_kv(); const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa); ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

View File

@ -249,8 +249,14 @@ public:
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@ -274,9 +280,19 @@ public:
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
@ -309,7 +325,7 @@ public:
llm_graph_input_one() {} llm_graph_input_one() {}
virtual ~llm_graph_input_one() = default; virtual ~llm_graph_input_one() = default;
void set_input(const llama_ubatch *) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * one = nullptr; // F32 ggml_tensor * one = nullptr; // F32
}; };

View File

@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads_base = kv_base->prepare(ubatches); auto sinfos_base = kv_base->prepare(ubatches);
if (heads_base.empty()) { if (sinfos_base.empty()) {
break; break;
} }
auto heads_swa = kv_swa->prepare(ubatches); auto sinfos_swa = kv_swa->prepare(ubatches);
if (heads_swa.empty()) { if (sinfos_swa.empty()) {
break; break;
} }
assert(heads_base.size() == heads_swa.size()); assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>( return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false); } while (false);
// if it fails, try equal split // if it fails, try equal split
@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads_base = kv_base->prepare(ubatches); auto sinfos_base = kv_base->prepare(ubatches);
if (heads_base.empty()) { if (sinfos_base.empty()) {
break; break;
} }
auto heads_swa = kv_swa->prepare(ubatches); auto sinfos_swa = kv_swa->prepare(ubatches);
if (heads_swa.empty()) { if (sinfos_swa.empty()) {
break; break;
} }
assert(heads_base.size() == heads_swa.size()); assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>( return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false); } while (false);
// TODO: if we fail again, we should attempt different splitting strategies // TODO: if we fail again, we should attempt different splitting strategies
@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base, slot_info_vec_t sinfos_base,
std::vector<uint32_t> heads_swa, slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)), ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)), ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }


@ -74,6 +74,8 @@ private:
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public: public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
// used for errors // used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status); llama_kv_cache_unified_iswa_context(llama_memory_status status);
@ -90,8 +92,8 @@ public:
// used to create a batch processing context from a batch // used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base, slot_info_vec_t sinfos_base,
std::vector<uint32_t> heads_swa, slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_context(); virtual ~llama_kv_cache_unified_iswa_context();


@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
if (!supports_set_rows) {
LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
}
} }
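
As a usage note: a minimal sketch of how a caller could opt in to the new code path before the cache is constructed, assuming a POSIX environment. Only the LLAMA_SET_ROWS variable name comes from the change above; the rest is illustrative.

#include <cstdio>
#include <cstdlib>

int main() {
    // the flag is read once in the llama_kv_cache_unified constructor (see above),
    // so it must be set before the model/context is created
    setenv("LLAMA_SET_ROWS", "1", /*overwrite=*/1); // POSIX; _putenv_s would be the Windows analogue

    const char * v = std::getenv("LLAMA_SET_ROWS");
    std::printf("LLAMA_SET_ROWS=%s -> supports_set_rows=%d\n", v, v ? std::atoi(v) : 0);

    // ... create the llama.cpp context here ...
    return 0;
}
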
void llama_kv_cache_unified::clear(bool data) { void llama_kv_cache_unified::clear(bool data) {
@ -353,13 +360,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
ubatches.push_back(std::move(ubatch)); // NOLINT ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads = prepare(ubatches); auto sinfos = prepare(ubatches);
if (heads.empty()) { if (sinfos.empty()) {
break; break;
} }
return std::make_unique<llama_kv_cache_unified_context>( return std::make_unique<llama_kv_cache_unified_context>(
this, std::move(heads), std::move(ubatches)); this, std::move(sinfos), std::move(ubatches));
} while (false); } while (false);
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@ -402,12 +409,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo)); return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
} }
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) { llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache_unified::ubatch_heads res; llama_kv_cache_unified::slot_info_vec_t res;
struct state { struct state {
uint32_t head_old; // old position of the head, before placing the ubatch uint32_t head_old; // old position of the head, before placing the ubatch
uint32_t head_new; // new position of the head, after placing the ubatch
slot_info sinfo; // slot info for the ubatch
llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
}; };
@ -418,26 +426,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
bool success = true; bool success = true;
for (const auto & ubatch : ubatches) { for (const auto & ubatch : ubatches) {
// non-contiguous slots require support for ggml_set_rows()
const bool cont = supports_set_rows ? false : true;
// only find a suitable slot for the ubatch. don't modify the cells yet // only find a suitable slot for the ubatch. don't modify the cells yet
const int32_t head_new = find_slot(ubatch); const auto sinfo_new = find_slot(ubatch, cont);
if (head_new < 0) { if (sinfo_new.empty()) {
success = false; success = false;
break; break;
} }
// remember the position that we found // remember the position that we found
res.push_back(head_new); res.push_back(sinfo_new);
// store the old state of the cells in the recovery stack // store the old state of the cells in the recovery stack
states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
// now emplace the ubatch // now emplace the ubatch
apply_ubatch(head_new, ubatch); apply_ubatch(sinfo_new, ubatch);
} }
// iterate backwards and restore the cells to their original state // iterate backwards and restore the cells to their original state
for (auto it = states.rbegin(); it != states.rend(); ++it) { for (auto it = states.rbegin(); it != states.rend(); ++it) {
cells.set(it->head_new, it->cells); cells.set(it->sinfo.idxs, it->cells);
head = it->head_old; head = it->head_old;
} }
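
The loop above speculatively applies each ubatch and then rolls every placement back through the recovery stack. A stripped-down sketch of the same checkpoint/rollback idea on a plain array of cells (the checkpoint type and the values are hypothetical, with no llama.cpp dependencies):

#include <cstdint>
#include <vector>

struct checkpoint {
    std::vector<uint32_t> idxs; // cells that were touched
    std::vector<int32_t>  old;  // their previous contents
};

int main() {
    std::vector<int32_t> cells(8, -1); // -1 == empty
    std::vector<checkpoint> stack;

    // "place" one ubatch into cells 2 and 5, remembering the old state first
    checkpoint cp;
    for (uint32_t idx : {2u, 5u}) {
        cp.idxs.push_back(idx);
        cp.old.push_back(cells[idx]);
        cells[idx] = 42; // pretend a token was stored here
    }
    stack.push_back(cp);

    // rollback: iterate backwards, like the recovery loop in prepare()
    for (auto it = stack.rbegin(); it != stack.rend(); ++it) {
        for (size_t j = 0; j < it->idxs.size(); ++j) {
            cells[it->idxs[j]] = it->old[j];
        }
    }
    // all touched cells are empty (-1) again
    return 0;
}
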
@ -539,7 +550,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
return updated; return updated;
} }
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
uint32_t head_cur = this->head; uint32_t head_cur = this->head;
@ -552,7 +563,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
if (n_tokens > cells.size()) { if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return -1; return { };
} }
if (debug > 0) { if (debug > 0) {
@ -615,15 +626,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
uint32_t n_tested = 0; uint32_t n_tested = 0;
// for contiguous slots, we test that all tokens in the ubatch fit, starting from the current head
// for non-contiguous slots, we test the tokens one by one
const uint32_t n_test = cont ? n_tokens : 1;
slot_info res;
auto & idxs = res.idxs;
idxs.reserve(n_tokens);
while (true) { while (true) {
if (head_cur + n_tokens > cells.size()) { if (head_cur + n_test > cells.size()) {
n_tested += cells.size() - head_cur; n_tested += cells.size() - head_cur;
head_cur = 0; head_cur = 0;
continue; continue;
} }
bool found = true; for (uint32_t i = 0; i < n_test; i++) {
for (uint32_t i = 0; i < n_tokens; i++) { const auto idx = head_cur;
//const llama_pos pos = ubatch.pos[i]; //const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0]; //const llama_seq_id seq_id = ubatch.seq_id[i][0];
@ -633,19 +655,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
// - (disabled) mask causally, if the sequence is the same as the one we are inserting // - (disabled) mask causally, if the sequence is the same as the one we are inserting
// - mask SWA, using current max pos for that sequence in the cache // - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos // always insert in the cell with minimum pos
bool can_use = cells.is_empty(head_cur + i); bool can_use = cells.is_empty(idx);
if (!can_use && cells.seq_count(head_cur + i) == 1) { if (!can_use && cells.seq_count(idx) == 1) {
const llama_pos pos_cell = cells.pos_get(head_cur + i); const llama_pos pos_cell = cells.pos_get(idx);
// (disabled) causal mask // (disabled) causal mask
// note: it's better to purge any "future" tokens beforehand // note: it's better to purge any "future" tokens beforehand
//if (cells.seq_has(head_cur + i, seq_id)) { //if (cells.seq_has(idx, seq_id)) {
// can_use = pos_cell >= pos; // can_use = pos_cell >= pos;
//} //}
if (!can_use) { if (!can_use) {
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); const llama_seq_id seq_id_cell = cells.seq_get(idx);
// SWA mask // SWA mask
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@ -654,28 +676,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
} }
} }
if (!can_use) { head_cur++;
found = false; n_tested++;
head_cur += i + 1;
n_tested += i + 1; if (can_use) {
idxs.push_back(idx);
} else {
break; break;
} }
} }
if (found) { if (idxs.size() == n_tokens) {
break; break;
} }
if (cont) {
idxs.clear();
}
if (n_tested >= cells.size()) { if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return -1; return { };
} }
} }
return head_cur; // we didn't find a suitable slot - return empty result
if (idxs.size() < n_tokens) {
res.clear();
} }
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { return res;
}
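
A self-contained sketch of the search above, reduced to a vector of per-cell usability flags: with cont == true a failing cell discards the whole candidate run, while with cont == false usable cells are collected one at a time. The SWA and sequence checks are collapsed into a precomputed can_use flag, and the find_cells name and the values are hypothetical.

#include <cstdint>
#include <vector>

// returns the cell indices for n_tokens tokens, or an empty vector on failure
static std::vector<uint32_t> find_cells(const std::vector<bool> & can_use, uint32_t n_tokens, bool cont) {
    const uint32_t n_cells = (uint32_t) can_use.size();
    if (n_tokens == 0 || n_tokens > n_cells) {
        return {};
    }

    const uint32_t n_test = cont ? n_tokens : 1; // cells that must fit starting at the current head

    std::vector<uint32_t> idxs;
    uint32_t head = 0, n_tested = 0;

    while (true) {
        if (head + n_test > n_cells) {
            n_tested += n_cells - head;
            head = 0;
            continue;
        }

        for (uint32_t i = 0; i < n_test; i++) {
            const uint32_t idx = head++;
            n_tested++;
            if (can_use[idx]) {
                idxs.push_back(idx);
            } else {
                break; // with cont == true the partial run is discarded below
            }
        }

        if (idxs.size() == n_tokens) {
            return idxs;
        }
        if (cont) {
            idxs.clear();
        }
        if (n_tested >= n_cells) {
            return {};
        }
    }
}

int main() {
    const std::vector<bool> can_use = {false, true, true, false, true, true, true, false};

    const auto scattered  = find_cells(can_use, 4, /*cont=*/false); // {1, 2, 4, 5}
    const auto contiguous = find_cells(can_use, 3, /*cont=*/true);  // {4, 5, 6}

    return (scattered.empty() || contiguous.empty()) ? 1 : 0;
}
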
void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// keep track of the max sequence position that we would overwrite with this ubatch // keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty // for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@ -683,22 +716,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
seq_pos_max_rm[s] = -1; seq_pos_max_rm[s] = -1;
} }
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { assert(ubatch.n_tokens == sinfo.idxs.size());
if (!cells.is_empty(head_cur + i)) {
assert(cells.seq_count(head_cur + i) == 1);
const llama_seq_id seq_id = cells.seq_get(head_cur + i); for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const llama_pos pos = cells.pos_get(head_cur + i); const auto idx = sinfo.idxs.at(i);
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(head_cur + i); cells.rm(idx);
} }
cells.pos_set(head_cur + i, ubatch.pos[i]); cells.pos_set(idx, ubatch.pos[i]);
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]); cells.seq_add(idx, ubatch.seq_id[i][s]);
} }
} }
@ -719,7 +756,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
} }
// move the head at the end of the slot // move the head at the end of the slot
head = head_cur + ubatch.n_tokens; head = sinfo.idxs.back() + 1;
} }
bool llama_kv_cache_unified::get_can_shift() const { bool llama_kv_cache_unified::get_can_shift() const {
@ -772,47 +809,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
0); 0);
} }
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il); const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k; auto * k = layers[ikv].k;
const int64_t n_embd_k_gqa = k->ne[0];
const int64_t n_tokens = k_cur->ne[2]; const int64_t n_tokens = k_cur->ne[2];
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
if (k_idxs && supports_set_rows) {
return ggml_set_rows(ctx, k, k_cur, k_idxs);
}
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
ggml_tensor * k_view = ggml_view_1d(ctx, k, ggml_tensor * k_view = ggml_view_1d(ctx, k,
n_tokens*hparams.n_embd_k_gqa(il), n_tokens*n_embd_k_gqa,
ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
return ggml_cpy(ctx, k_cur, k_view); return ggml_cpy(ctx, k_cur, k_view);
} }
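
The effect of the ggml_set_rows() call above can be pictured without any ggml machinery: row i of k_cur goes to row k_idxs[i] of the cache tensor, which is what allows the slot to be non-contiguous. A plain-vector sketch of that scatter (scatter_rows is a hypothetical helper, not the actual ggml kernel):

#include <algorithm>
#include <cstdint>
#include <vector>

// dst holds n_kv rows of n_embd floats, src holds n_tokens rows;
// row i of src is scattered to row idxs[i] of dst
static void scatter_rows(std::vector<float> & dst, const std::vector<float> & src,
                         const std::vector<int64_t> & idxs, size_t n_embd) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        std::copy(src.begin() + i * n_embd,
                  src.begin() + (i + 1) * n_embd,
                  dst.begin() + (size_t) idxs[i] * n_embd);
    }
}

int main() {
    const size_t n_embd = 4;
    std::vector<float> cache(8 * n_embd, 0.0f); // 8 KV cells
    std::vector<float> k_cur(3 * n_embd, 1.0f); // 3 tokens
    scatter_rows(cache, k_cur, {2, 5, 6}, n_embd); // a non-contiguous slot
    return 0;
}

The ggml_cpy() fallback, in contrast, needs all destination rows to form one contiguous view starting at sinfo.head(), which is why the non-set_rows path keeps requiring a contiguous slot.
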
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il); const int32_t ikv = map_layer_ids.at(il);
auto * v = layers[ikv].v; auto * v = layers[ikv].v;
const int64_t n_embd_v_gqa = v->ne[0];
const int64_t n_tokens = v_cur->ne[2]; const int64_t n_tokens = v_cur->ne[2];
v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
if (v_idxs && supports_set_rows) {
if (!v_trans) {
return ggml_set_rows(ctx, v, v_cur, v_idxs);
}
// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
// note: the V cache is transposed when not using flash attention
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
// note: we can be more explicit here at the cost of extra cont
// however, above we take advantage of the fact that a row of a single element is always contiguous regardless of the row stride
//v_cur = ggml_transpose(ctx, v_cur);
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
// we broadcast the KV indices n_embd_v_gqa times
// v [1, n_kv, n_embd_v_gqa]
// v_cur [1, n_tokens, n_embd_v_gqa]
// v_idxs [n_tokens, 1, 1]
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
ggml_tensor * v_view = nullptr; ggml_tensor * v_view = nullptr;
if (!v_trans) { if (!v_trans) {
v_view = ggml_view_1d(ctx, v, v_view = ggml_view_1d(ctx, v,
n_tokens*hparams.n_embd_v_gqa(il), n_tokens*n_embd_v_gqa,
ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
} else { } else {
// note: the V cache is transposed when not using flash attention
v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
(v->ne[1])*ggml_element_size(v),
(head_cur)*ggml_element_size(v));
v_cur = ggml_transpose(ctx, v_cur); v_cur = ggml_transpose(ctx, v_cur);
v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
(v->ne[1] )*ggml_element_size(v),
(sinfo.head())*ggml_element_size(v));
} }
return ggml_cpy(ctx, v_cur, v_view); return ggml_cpy(ctx, v_cur, v_view);
} }
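
The transposed branch above is the subtle part: without flash attention, V is stored transposed, so one token occupies a column rather than a row. Reshaping the cache to [1, n_kv, n_embd_v_gqa] turns every element into its own single-element row, and the n_tokens indices are then broadcast across the n_embd_v_gqa dimension. A loop-level sketch of the resulting write pattern (store_v_transposed and the values are hypothetical):

#include <cstdint>
#include <vector>

// v_t   : transposed V cache, v_t[e][cell] (n_embd rows of n_kv elements)
// v_cur : current tokens,     v_cur[i][e]  (n_tokens rows of n_embd elements)
// idxs  : destination cell of each token   (n_tokens entries, reused for every e)
static void store_v_transposed(std::vector<std::vector<float>> & v_t,
                               const std::vector<std::vector<float>> & v_cur,
                               const std::vector<int64_t> & idxs) {
    for (size_t e = 0; e < v_t.size(); ++e) {      // embedding dimension
        for (size_t i = 0; i < idxs.size(); ++i) { // token
            v_t[e][(size_t) idxs[i]] = v_cur[i][e];
        }
    }
}

int main() {
    std::vector<std::vector<float>> v_t(4, std::vector<float>(8, 0.0f));                  // n_embd = 4, n_kv = 8
    std::vector<std::vector<float>> v_cur = {{1.f, 2.f, 3.f, 4.f}, {5.f, 6.f, 7.f, 8.f}}; // 2 tokens
    store_v_transposed(v_t, v_cur, {3, 6}); // token 0 -> cell 3, token 1 -> cell 6
    return 0;
}
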
ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
ggml_set_input(k_idxs);
return k_idxs;
}
ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
ggml_set_input(v_idxs);
return v_idxs;
}
void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;
for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs.at(i);
}
}
void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;
for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs.at(i);
}
}
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const uint32_t n_tokens = ubatch->n_tokens; const uint32_t n_tokens = ubatch->n_tokens;
@ -1552,13 +1675,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
ubatch.seq_id[i] = &dest_seq_id; ubatch.seq_id[i] = &dest_seq_id;
} }
const auto head_cur = find_slot(ubatch); const auto sinfo = find_slot(ubatch, true);
if (head_cur < 0) { if (sinfo.empty()) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false; return false;
} }
apply_ubatch(head_cur, ubatch); apply_ubatch(sinfo, ubatch);
const auto head_cur = sinfo.head();
// keep the head at the old position because we will read the KV data into it in state_read_data() // keep the head at the old position because we will read the KV data into it in state_read_data()
head = head_cur; head = head_cur;
@ -1744,7 +1869,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size(); n_kv = kv->get_size();
head = 0;
// create a dummy slot info - the actual data is irrelevant. we just need to build the graph
sinfos.resize(1);
sinfos[0].idxs.resize(1);
sinfos[0].idxs[0] = 0;
} }
llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@ -1759,8 +1888,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_kv_cache_unified::ubatch_heads heads, llama_kv_cache_unified::slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) { std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
} }
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
@ -1768,7 +1897,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
bool llama_kv_cache_unified_context::next() { bool llama_kv_cache_unified_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) { if (++i_cur >= ubatches.size()) {
return false; return false;
} }
@ -1785,10 +1914,9 @@ bool llama_kv_cache_unified_context::apply() {
return true; return true;
} }
kv->apply_ubatch(heads[i_next], ubatches[i_next]); kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
n_kv = kv->get_n_kv(); n_kv = kv->get_n_kv();
head = heads[i_next];
return true; return true;
} }
@ -1800,7 +1928,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const { const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_cur];
} }
uint32_t llama_kv_cache_unified_context::get_n_kv() const { uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@ -1815,18 +1943,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
return kv->get_v(ctx, il, n_kv); return kv->get_v(ctx, il, n_kv);
} }
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
return kv->cpy_k(ctx, k_cur, il, head); return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
} }
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
return kv->cpy_v(ctx, v_cur, il, head); return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_k_idxs(ctx, ubatch);
}
ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_v_idxs(ctx, ubatch);
} }
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const { void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst); kv->set_input_k_shift(dst);
} }
void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
kv->set_input_kq_mask(dst, ubatch, causal_attn); kv->set_input_kq_mask(dst, ubatch, causal_attn);
} }


@ -24,8 +24,6 @@ public:
// this callback is used to filter out layers that should not be included in the cache // this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>; using layer_filter_cb = std::function<bool(int32_t il)>;
using ubatch_heads = std::vector<uint32_t>;
struct defrag_info { struct defrag_info {
bool empty() const { bool empty() const {
return ids.empty(); return ids.empty();
@ -37,6 +35,32 @@ public:
std::vector<uint32_t> ids; std::vector<uint32_t> ids;
}; };
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
// KV cells. for example, cell indices for each token, such that token[i] goes to cells[idxs[i]]
struct slot_info {
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;
idx_vec_t idxs;
uint32_t head() const {
return idxs.at(0);
}
bool empty() const {
return idxs.empty();
}
void clear() {
idxs.clear();
}
// TODO: implement
//std::vector<idx_vec_t> seq_idxs;
};
using slot_info_vec_t = std::vector<slot_info>;
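
A tiny illustration of the mapping this struct encodes; the values and the slot_info_sketch name are made up. head() stays meaningful only for the contiguous ggml_cpy() fallback, where the indices are consecutive.

#include <cassert>
#include <cstdint>
#include <vector>

// standalone copy of the idxs/head()/empty() part of slot_info, for illustration only
struct slot_info_sketch {
    std::vector<uint32_t> idxs;

    uint32_t head()  const { return idxs.at(0); }
    bool     empty() const { return idxs.empty(); }
};

int main() {
    slot_info_sketch sinfo;
    sinfo.idxs = {7, 8, 12}; // token 0 -> cell 7, token 1 -> cell 8, token 2 -> cell 12

    assert(!sinfo.empty());
    assert(sinfo.head() == 7);
    return 0;
}
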
llama_kv_cache_unified( llama_kv_cache_unified(
const llama_model & model, const llama_model & model,
layer_filter_cb && filter, layer_filter_cb && filter,
@ -102,30 +126,37 @@ public:
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
// store k_cur and v_cur in the cache based on the provided head location // store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
// //
// preparation API // preparation API
// //
// find places for the provided ubatches in the cache, returns the head locations // find places for the provided ubatches in the cache, returns the slot infos
// return empty vector on failure // return empty vector on failure
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches); slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
// return the cell position where we can insert the ubatch // find a slot of kv cells that can hold the ubatch
// return -1 on failure to find a contiguous slot of kv cells // if cont == true, then the slot must be contiguous
int32_t find_slot(const llama_ubatch & ubatch) const; // return empty slot_info on failure
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
// //
// set_input API // input API
// //
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const; void set_input_k_shift (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@ -157,8 +188,13 @@ private:
// SWA // SWA
const uint32_t n_swa = 0; const uint32_t n_swa = 0;
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0; int debug = 0;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
int supports_set_rows = false;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
@ -211,7 +247,7 @@ private:
class llama_kv_cache_unified_context : public llama_memory_context_i { class llama_kv_cache_unified_context : public llama_memory_context_i {
public: public:
// some shorthands // some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads; using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info; using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors // used for errors
@ -231,7 +267,7 @@ public:
// used to create a batch processing context from a batch // used to create a batch processing context from a batch
llama_kv_cache_unified_context( llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
ubatch_heads heads, slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_context(); virtual ~llama_kv_cache_unified_context();
@ -257,11 +293,16 @@ public:
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location // store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_shift (ggml_tensor * dst) const; void set_input_k_shift (ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@ -283,10 +324,10 @@ private:
// batch processing context // batch processing context
// //
// the index of the next ubatch to process // the index of the cur ubatch to process
size_t i_next = 0; size_t i_cur = 0;
ubatch_heads heads; slot_info_vec_t sinfos;
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
@ -297,7 +338,4 @@ private:
// a heuristic, to avoid attending the full cache if it is not yet utilized // a heuristic, to avoid attending the full cache if it is not yet utilized
// as the cache gets filled, the benefit from this heuristic disappears // as the cache gets filled, the benefit from this heuristic disappears
int32_t n_kv; int32_t n_kv;
// the beginning of the current slot in which the ubatch will be inserted
int32_t head;
}; };


@ -105,10 +105,30 @@ public:
res.resize(n); res.resize(n);
for (uint32_t j = 0; j < n; ++j) { for (uint32_t j = 0; j < n; ++j) {
res.pos[j] = pos[i + j]; const auto idx = i + j;
res.seq[j] = seq[i + j];
assert(shift[i + j] == 0); res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
}
return res;
}
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]]
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
res.resize(idxs.size());
for (uint32_t j = 0; j < idxs.size(); ++j) {
const auto idx = idxs[j];
res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
} }
return res; return res;
@ -119,26 +139,58 @@ public:
assert(i + other.pos.size() <= pos.size()); assert(i + other.pos.size() <= pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) { for (uint32_t j = 0; j < other.pos.size(); ++j) {
if (pos[i + j] == -1 && other.pos[j] != -1) { const auto idx = i + j;
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(i + j); used.insert(i + j);
} }
if (pos[i + j] != -1 && other.pos[j] == -1) { if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(i + j); used.erase(i + j);
} }
if (pos[i + j] != -1) { if (pos[idx] != -1) {
seq_pos_rm(i + j); seq_pos_rm(i + j);
} }
pos[i + j] = other.pos[j]; pos[idx] = other.pos[j];
seq[i + j] = other.seq[j]; seq[idx] = other.seq[j];
if (pos[i + j] != -1) { if (pos[idx] != -1) {
seq_pos_add(i + j); seq_pos_add(i + j);
} }
assert(shift[i + j] == 0); assert(shift[idx] == 0);
}
}
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]]
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
const auto idx = idxs[j];
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(idx);
}
if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(idx);
}
if (pos[idx] != -1) {
seq_pos_rm(idx);
}
pos[idx] = other.pos[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
seq_pos_add(idx);
}
assert(shift[idx] == 0);
} }
} }
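
The two index-based overloads above replace the old [i, i + n) range with an explicit list of cells. A minimal round trip over just the pos values (ignoring the seq/shift bookkeeping, with made-up numbers) shows the intended save/restore symmetry:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    std::vector<int32_t> pos = {-1, 10, -1, -1, 11, -1, 12, -1}; // -1 == empty cell

    const std::vector<uint32_t> idxs = {1, 4, 6};

    // cp(idxs): copy the state of exactly those cells
    std::vector<int32_t> backup;
    for (uint32_t idx : idxs) {
        backup.push_back(pos[idx]);
    }

    // clobber them, as applying a ubatch would
    for (uint32_t idx : idxs) {
        pos[idx] = 99;
    }

    // set(idxs, backup): write the saved state back into the same cells
    for (size_t j = 0; j < idxs.size(); ++j) {
        pos[idxs[j]] = backup[j];
    }

    assert(pos[1] == 10 && pos[4] == 11 && pos[6] == 12);
    return 0;
}
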


@ -195,11 +195,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn, slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }


@ -92,6 +92,8 @@ private:
class llama_memory_hybrid_context : public llama_memory_context_i { class llama_memory_hybrid_context : public llama_memory_context_i {
public: public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
// init failure // init failure
explicit llama_memory_hybrid_context(llama_memory_status status); explicit llama_memory_hybrid_context(llama_memory_status status);
@ -107,7 +109,7 @@ public:
// init success // init success
llama_memory_hybrid_context( llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn, slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_context() = default; ~llama_memory_hybrid_context() = default;


@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case {
if (v & 1) { if (v & 1) {
auto ne = ne_a; ne[0] *= 3; auto ne = ne_a; ne[0] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data()); a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a"); ggml_set_name(a, "a");
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view_of_a"); ggml_set_name(a, "view_of_a");
b = ggml_new_tensor(ctx, type, 4, ne.data()); b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(b);
ggml_set_name(b, "b"); ggml_set_name(b, "b");
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0); b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
ggml_set_name(a, "view_of_b"); ggml_set_name(a, "view_of_b");
} else { } else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data()); a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(a);
ggml_set_name(a, "a"); ggml_set_name(a, "a");
b = ggml_new_tensor(ctx, type, 4, ne_a.data()); b = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(b);
ggml_set_name(b, "b"); ggml_set_name(b, "b");
} }
@ -3637,7 +3641,7 @@ struct test_flash_attn_ext : public test_case {
ggml_tensor * m = nullptr; ggml_tensor * m = nullptr;
if (mask) { if (mask) {
m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1); m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
ggml_set_name(m, "m"); ggml_set_name(m, "m");
} }
@ -4751,7 +4755,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
if (ne0 <= 32 && ne1 <= 32) { if (ne0 <= 32 && ne1 <= 32) {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
} }
} }
@ -4932,6 +4936,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));


@ -1,7 +1,3 @@
#include "llama.h" #include "llama.h"
#ifdef GGML_USE_KOMPUTE
#include "ggml-kompute.h"
#endif
int main(void) {} int main(void) {}