Merge branch 'master' into compilade/refactor-kv-cache
commit 4682e21c46
@@ -40,7 +40,7 @@ body:
 attributes:
 label: GGML backends
 description: Which GGML backends do you know to be affected?
-options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
+options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
 multiple: true
 validations:
 required: true

@@ -42,7 +42,7 @@ body:
 attributes:
 label: GGML backends
 description: Which GGML backends do you know to be affected?
-options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
+options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
 multiple: true
 validations:
 required: true

@@ -1,10 +1,4 @@
 # https://github.com/actions/labeler
-Kompute:
-- changed-files:
-- any-glob-to-any-file:
-- ggml/include/ggml-kompute.h
-- ggml/src/ggml-kompute/**
-- README-kompute.md
 Apple Metal:
 - changed-files:
 - any-glob-to-any-file:

@@ -740,9 +740,6 @@ jobs:
 - build: 'llvm-arm64-opencl-adreno'
 arch: 'arm64'
 defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
-# - build: 'kompute-x64'
-# arch: 'x64'
-# defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON'

 steps:
 - name: Clone

@@ -756,12 +753,6 @@ jobs:
 variant: ccache
 evict-old-files: 1d

-- name: Clone Kompute submodule
-id: clone_kompute
-if: ${{ matrix.build == 'kompute-x64' }}
-run: |
-git submodule update --init ggml/src/ggml-kompute/kompute
-
 - name: Download OpenBLAS
 id: get_openblas
 if: ${{ matrix.build == 'openblas-x64' }}

@@ -777,7 +768,7 @@ jobs:

 - name: Install Vulkan SDK
 id: get_vulkan
-if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
+if: ${{ matrix.build == 'vulkan-x64' }}
 run: |
 curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
 & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install

@@ -1,3 +0,0 @@
-[submodule "kompute"]
-path = ggml/src/ggml-kompute/kompute
-url = https://github.com/nomic-ai/kompute.git

@@ -120,7 +120,6 @@ endfunction()

 llama_option_depr(FATAL_ERROR LLAMA_CUBLAS GGML_CUDA)
 llama_option_depr(WARNING LLAMA_CUDA GGML_CUDA)
-llama_option_depr(WARNING LLAMA_KOMPUTE GGML_KOMPUTE)
 llama_option_depr(WARNING LLAMA_METAL GGML_METAL)
 llama_option_depr(WARNING LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY)
 llama_option_depr(WARNING LLAMA_NATIVE GGML_NATIVE)

@@ -4408,9 +4408,6 @@ class Gemma3NModel(Gemma3Model):
 ]

 def set_vocab(self):
-with open(self.dir_model / "chat_template.jinja") as f:
-# quick hack to make sure chat template is added
-self.gguf_writer.add_chat_template(f.read())
 super().set_vocab()

 def set_gguf_parameters(self):

@@ -181,7 +181,6 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou
 option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
 option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
 option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
-option(GGML_KOMPUTE "ggml: use Kompute" OFF)
 option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
 option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
 option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)

@@ -266,7 +265,6 @@ set(GGML_PUBLIC_HEADERS
 include/ggml-cann.h
 include/ggml-cpp.h
 include/ggml-cuda.h
-include/ggml-kompute.h
 include/ggml-opt.h
 include/ggml-metal.h
 include/ggml-rpc.h

@@ -1,50 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#define GGML_KOMPUTE_MAX_DEVICES 16
-
-struct ggml_vk_device {
-int index;
-int type; // same as VkPhysicalDeviceType
-size_t heapSize;
-const char * name;
-const char * vendor;
-int subgroupSize;
-uint64_t bufferAlignment;
-uint64_t maxAlloc;
-};
-
-struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
-bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
-bool ggml_vk_has_vulkan(void);
-bool ggml_vk_has_device(void);
-struct ggml_vk_device ggml_vk_current_device(void);
-
-//
-// backend API
-//
-
-// forward declaration
-typedef struct ggml_backend * ggml_backend_t;
-
-GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
-
-GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
-
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
-
-GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
-
-#ifdef __cplusplus
-}
-#endif

@@ -1986,12 +1986,13 @@ extern "C" {
 // q: [n_embd_k, n_batch, n_head, ne3 ]
 // k: [n_embd_k, n_kv, n_head_kv, ne3 ]
 // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !!
-// mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+// mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
 // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !!
 //
 // broadcast:
 // n_head % n_head_kv == 0
-// ne3 % ne32 == 0
+// n_head % ne32 == 0
+// ne3 % ne33 == 0
 //
 GGML_API struct ggml_tensor * ggml_flash_attn_ext(
 struct ggml_context * ctx,

@@ -365,7 +365,6 @@ ggml_add_backend(BLAS)
 ggml_add_backend(CANN)
 ggml_add_backend(CUDA)
 ggml_add_backend(HIP)
-ggml_add_backend(Kompute)
 ggml_add_backend(METAL)
 ggml_add_backend(MUSA)
 ggml_add_backend(RPC)

@@ -61,10 +61,6 @@
 #include "ggml-cann.h"
 #endif

-#ifdef GGML_USE_KOMPUTE
-#include "ggml-kompute.h"
-#endif
-
 // disable C++17 deprecation warning for std::codecvt_utf8
 #if defined(__clang__)
 # pragma clang diagnostic push

@@ -189,9 +185,6 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_RPC
 register_backend(ggml_backend_rpc_reg());
 #endif
-#ifdef GGML_USE_KOMPUTE
-register_backend(ggml_backend_kompute_reg());
-#endif
 #ifdef GGML_USE_CPU
 register_backend(ggml_backend_cpu_reg());
 #endif

@@ -575,7 +568,6 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
 ggml_backend_load_best("cann", silent, dir_path);
 ggml_backend_load_best("cuda", silent, dir_path);
 ggml_backend_load_best("hip", silent, dir_path);
-ggml_backend_load_best("kompute", silent, dir_path);
 ggml_backend_load_best("metal", silent, dir_path);
 ggml_backend_load_best("rpc", silent, dir_path);
 ggml_backend_load_best("sycl", silent, dir_path);

@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
 return false;
 }
 } break;
+case GGML_OP_SET_ROWS:
+{
+// TODO: add support
+// ref: https://github.com/ggml-org/llama.cpp/pull/14274
+return false;
+} break;
 case GGML_OP_CPY: {
 ggml_tensor *src = op->src[0];
 if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||

@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 memset(VKQ32, 0, DV*sizeof(float));
 }

-const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;

 // k indices
 const int ik3 = iq3 / rk3;

@@ -175,6 +175,20 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif

+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+do { \
+static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
+const int id = ggml_cuda_get_device(); \
+if (!shared_memory_limit_raised[id]) { \
+CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
+shared_memory_limit_raised[id] = true; \
+} \
+} while (0)
+#else
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+
 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else

@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
 ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

 if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-if (!shared_memory_limit_raised[id]) {
-CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-shared_memory_limit_raised[id] = true;
-}
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
 cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
 } else {
 cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);

@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
 const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

 if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-if (!shared_memory_limit_raised[id]) {
-CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-shared_memory_limit_raised[id] = true;
-}
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
 cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
 } else {
 cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);

@@ -3390,7 +3390,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 return false;
 }
 // TODO: support broadcast
-// ref: https://github.com/ggml-org/llama.cpp/pull/14435
+// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
 if (op->src[0]->ne[3] != 1) {
 return false;
 }

@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a

 const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);

-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-if (!shared_memory_limit_raised[id]) {
-CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-shared_memory_limit_raised[id] = true;
-}
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
+CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);

 const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
 const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;

@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>

 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {

@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
 }
 }

+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+const int id = ggml_cuda_get_device();
+const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+auto launch_kernel = [=](auto I) -> bool {
+constexpr int ncols = decltype(I)::value;
+constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+if (p.ncols == ncols) {
+CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+(x, mask, dst, p);
+return true;
+}
+return false;
+};
+
+// unary fold over launch_kernel
+if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+return;
+}
+
+//default case
+CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template<typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
 int nth = WARP_SIZE;

@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
 static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");


-// FIXME: this limit could be raised by ~2-4x on Ampere or newer
-if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-switch (ncols_x) {
-case 32:
-soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-case 64:
-soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-case 128:
-soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-case 256:
-soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-case 512:
-soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-case 1024:
-soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-case 2048:
-soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-case 4096:
-soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-default:
-soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-(x, mask, dst, params);
-break;
-}
+const int id = ggml_cuda_get_device();
+const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+if (nbytes_shared <= smpbo) {
+launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
 } else {
 const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
 soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);

@@ -1,166 +0,0 @@
-
-find_package(Vulkan COMPONENTS glslc REQUIRED)
-find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
-
-if (NOT glslc_executable)
-message(FATAL_ERROR "glslc not found")
-endif()
-
-ggml_add_backend_library(ggml-kompute
-ggml-kompute.cpp
-../../include/ggml-kompute.h
-)
-
-target_link_libraries(ggml-kompute PRIVATE ggml-base kompute)
-target_include_directories(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
-
-add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
-
-function(compile_shader)
-set(options)
-set(oneValueArgs)
-set(multiValueArgs SOURCES)
-cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-foreach(source ${compile_shader_SOURCES})
-get_filename_component(filename ${source} NAME)
-set(spv_file ${filename}.spv)
-add_custom_command(
-OUTPUT ${spv_file}
-DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
-${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
-${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
-${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
-${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
-COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
-COMMENT "Compiling ${source} to ${spv_file}"
-)
-
-get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
-set(FILE_NAME "shader${RAW_FILE_NAME}")
-string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
-string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
-string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
-set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
-message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
-if(CMAKE_GENERATOR MATCHES "Visual Studio")
-add_custom_command(
-OUTPUT ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-DEPENDS ${spv_file} xxd
-COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
-)
-else()
-add_custom_command(
-OUTPUT ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
-COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-DEPENDS ${spv_file} xxd
-COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
-)
-endif()
-endforeach()
-endfunction()
-
-if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
-message(STATUS "Kompute found")
-set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
-add_subdirectory(kompute)
-
-# Compile our shaders
-compile_shader(SOURCES
-kompute-shaders/op_scale.comp
-kompute-shaders/op_scale_8.comp
-kompute-shaders/op_add.comp
-kompute-shaders/op_addrow.comp
-kompute-shaders/op_mul.comp
-kompute-shaders/op_silu.comp
-kompute-shaders/op_relu.comp
-kompute-shaders/op_gelu.comp
-kompute-shaders/op_softmax.comp
-kompute-shaders/op_norm.comp
-kompute-shaders/op_rmsnorm.comp
-kompute-shaders/op_diagmask.comp
-kompute-shaders/op_mul_mat_mat_f32.comp
-kompute-shaders/op_mul_mat_f16.comp
-kompute-shaders/op_mul_mat_q8_0.comp
-kompute-shaders/op_mul_mat_q4_0.comp
-kompute-shaders/op_mul_mat_q4_1.comp
-kompute-shaders/op_mul_mat_q4_k.comp
-kompute-shaders/op_mul_mat_q6_k.comp
-kompute-shaders/op_getrows_f32.comp
-kompute-shaders/op_getrows_f16.comp
-kompute-shaders/op_getrows_q4_0.comp
-kompute-shaders/op_getrows_q4_1.comp
-kompute-shaders/op_getrows_q6_k.comp
-kompute-shaders/op_rope_norm_f16.comp
-kompute-shaders/op_rope_norm_f32.comp
-kompute-shaders/op_rope_neox_f16.comp
-kompute-shaders/op_rope_neox_f32.comp
-kompute-shaders/op_cpy_f16_f16.comp
-kompute-shaders/op_cpy_f16_f32.comp
-kompute-shaders/op_cpy_f32_f16.comp
-kompute-shaders/op_cpy_f32_f32.comp
-)
-
-# Create a custom target for our generated shaders
-add_custom_target(generated_shaders DEPENDS
-shaderop_scale.h
-shaderop_scale_8.h
-shaderop_add.h
-shaderop_addrow.h
-shaderop_mul.h
-shaderop_silu.h
-shaderop_relu.h
-shaderop_gelu.h
-shaderop_softmax.h
-shaderop_norm.h
-shaderop_rmsnorm.h
-shaderop_diagmask.h
-shaderop_mul_mat_mat_f32.h
-shaderop_mul_mat_f16.h
-shaderop_mul_mat_q8_0.h
-shaderop_mul_mat_q4_0.h
-shaderop_mul_mat_q4_1.h
-shaderop_mul_mat_q4_k.h
-shaderop_mul_mat_q6_k.h
-shaderop_getrows_f32.h
-shaderop_getrows_f16.h
-shaderop_getrows_q4_0.h
-shaderop_getrows_q4_1.h
-shaderop_getrows_q6_k.h
-shaderop_rope_norm_f16.h
-shaderop_rope_norm_f32.h
-shaderop_rope_neox_f16.h
-shaderop_rope_neox_f32.h
-shaderop_cpy_f16_f16.h
-shaderop_cpy_f16_f32.h
-shaderop_cpy_f32_f16.h
-shaderop_cpy_f32_f32.h
-)
-
-# Create a custom command that depends on the generated_shaders
-add_custom_command(
-OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
-COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
-DEPENDS generated_shaders
-COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
-)
-
-# Add the stamp to the main sources to ensure dependency tracking
-target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
-else()
-message(WARNING "Kompute not found")
-endif()

File diff suppressed because it is too large
@@ -1 +0,0 @@
-Subproject commit 4565194ed7c32d1d2efa32ceab4d3c6cae006306

@@ -1,112 +0,0 @@
-#extension GL_EXT_shader_16bit_storage: require
-#extension GL_EXT_shader_8bit_storage: require
-#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
-#extension GL_EXT_control_flow_attributes: enable
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_EXT_debug_printf : enable
-
-#define QK4_0 32
-#define QK4_1 32
-
-#define GELU_COEF_A 0.044715
-#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
-#define TWOPI_F 6.283185307179586f
-
-#define QK_K 256
-#define K_SCALE_SIZE 12
-
-#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
-#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
-#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
-#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
-
-#define sizeof_block_q4_0 0x12
-struct block_q4_0 {
-float16_t d;
-uint8_t qs[QK4_0 / 2];
-};
-mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
-const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
-const float d2 = d1 / 256.f;
-const float md = -8.f * xb.d;
-const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
-const uint16_t mask1 = mask0 << 8;
-
-mat4 reg;
-for (int i=0;i<8;i++) {
-uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
-reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
-reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
-}
-return reg;
-}
-
-#define sizeof_block_q4_1 0x14
-struct block_q4_1 {
-float16_t d;
-float16_t m;
-uint8_t qs[QK4_1 / 2];
-};
-mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
-const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
-const float d2 = d1 / 256.f;
-const float m = xb.m;
-const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
-const uint16_t mask1 = mask0 << 8;
-
-mat4 reg;
-for (int i=0;i<8;i++) {
-uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
-reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
-reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
-}
-return reg;
-}
-
-#define sizeof_block_q4_k 144
-struct block_q4_k {
-float16_t d;
-float16_t dmin;
-uint8_t scales[K_SCALE_SIZE];
-uint8_t qs[QK_K/2];
-};
-
-#define sizeof_block_q6_k 210
-struct block_q6_k {
-uint8_t ql[QK_K/2]; // quants, lower 4 bits
-uint8_t qh[QK_K/4]; // quants, upper 2 bits
-int8_t scales[QK_K/16]; // scales, quantized with 8 bits
-float16_t d; // super-block scale
-};
-mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
-const float16_t d_all = xb.d;
-
-const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-const uint qhIndex = 32*(il/8) + 16*(il&1);
-float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
-il = (il/2) & 3;
-
-const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
-const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F);
-const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f);
-const float16_t ml = float16_t(d_all * sc * 32.f);
-const float16_t dl = float16_t(d_all * sc * coef);
-mat4 reg;
-for (int i = 0; i < 16; ++i) {
-const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
-: ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
-reg[i/4][i%4] = dl * q - ml;
-}
-return reg;
-}
-
-
-#define QK8_0 32
-// struct block_q8_0 {
-// float16_t d; // delta
-// int8_t qs[QK8_0]; // quants
-// };
-#define sizeof_block_q8_0 34

@@ -1,58 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1024) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-uint inAOff;
-uint inBOff;
-uint outOff;
-int ne00;
-int nb00;
-int nb01;
-int nb02;
-int nb03;
-int ne10;
-int ne11;
-int ne12;
-int ne13;
-int nb10;
-int nb11;
-int nb12;
-int nb13;
-int ne0;
-int nb0;
-int nb1;
-int nb2;
-int nb3;
-//int offs; // TODO: needed for GGML_OP_ACC, see metal code
-} pcs;
-
-// general-purpose kernel for addition of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
-// cons: not very efficient
-void main() {
-const uint i03 = gl_WorkGroupID.z;
-const uint i02 = gl_WorkGroupID.y;
-const uint i01 = gl_WorkGroupID.x;
-
-const uint i13 = i03 % pcs.ne13;
-const uint i12 = i02 % pcs.ne12;
-const uint i11 = i01 % pcs.ne11;
-
-int offs = 0; // TMP (see above)
-
-uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
-uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4);
-uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4);
-
-for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
-const uint i10 = i0 % pcs.ne10;
-out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
-}
-}

@@ -1,25 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-uint inAOff;
-uint inBOff;
-uint outOff;
-uint row;
-} pcs;
-
-void main() {
-const uint baseIndex = gl_WorkGroupID.x * 4;
-
-for (uint x = 0; x < 4; x++) {
-const uint i = baseIndex + x;
-out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
-}
-}

@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float16_t
-#define IN_TYPE_SIZE 2
-#define OUT_TYPE float16_t
-#define OUT_TYPE_SIZE 2
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inOff;
-uint outOff;
-int ne00;
-int ne01;
-int ne02;
-uint nb00;
-uint nb01;
-uint nb02;
-uint nb03;
-int ne0;
-int ne1;
-int ne2;
-uint nb0;
-uint nb1;
-uint nb2;
-uint nb3;
-} pcs;
-
-void main() {
-const uint i03 = gl_WorkGroupID.z;
-const uint i02 = gl_WorkGroupID.y;
-const uint i01 = gl_WorkGroupID.x;
-
-const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-out_[dst_data+i00] = OUT_TYPE(in_[src]);
-}
-}

@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float16_t
-#define IN_TYPE_SIZE 2
-#define OUT_TYPE float
-#define OUT_TYPE_SIZE 4
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inOff;
-uint outOff;
-int ne00;
-int ne01;
-int ne02;
-uint nb00;
-uint nb01;
-uint nb02;
-uint nb03;
-int ne0;
-int ne1;
-int ne2;
-uint nb0;
-uint nb1;
-uint nb2;
-uint nb3;
-} pcs;
-
-void main() {
-const uint i03 = gl_WorkGroupID.z;
-const uint i02 = gl_WorkGroupID.y;
-const uint i01 = gl_WorkGroupID.x;
-
-const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-out_[dst_data+i00] = OUT_TYPE(in_[src]);
-}
-}

@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float
-#define IN_TYPE_SIZE 4
-#define OUT_TYPE float16_t
-#define OUT_TYPE_SIZE 2
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inOff;
-uint outOff;
-int ne00;
-int ne01;
-int ne02;
-uint nb00;
-uint nb01;
-uint nb02;
-uint nb03;
-int ne0;
-int ne1;
-int ne2;
-uint nb0;
-uint nb1;
-uint nb2;
-uint nb3;
-} pcs;
-
-void main() {
-const uint i03 = gl_WorkGroupID.z;
-const uint i02 = gl_WorkGroupID.y;
-const uint i01 = gl_WorkGroupID.x;
-
-const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-out_[dst_data+i00] = OUT_TYPE(in_[src]);
-}
-}

@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float
-#define IN_TYPE_SIZE 4
-#define OUT_TYPE float
-#define OUT_TYPE_SIZE 4
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inOff;
-uint outOff;
-int ne00;
-int ne01;
-int ne02;
-uint nb00;
-uint nb01;
-uint nb02;
-uint nb03;
-int ne0;
-int ne1;
-int ne2;
-uint nb0;
-uint nb1;
-uint nb2;
-uint nb3;
-} pcs;
-
-void main() {
-const uint i03 = gl_WorkGroupID.z;
-const uint i02 = gl_WorkGroupID.y;
-const uint i01 = gl_WorkGroupID.x;
-
-const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-out_[dst_data+i00] = OUT_TYPE(in_[src]);
-}
-}

@@ -1,30 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-uint inOff;
-uint outOff;
-uint n_past;
-int ne00;
-int ne01;
-} pcs;
-
-void main() {
-const uint i02 = gl_WorkGroupID.z;
-const uint i01 = gl_WorkGroupID.y;
-const uint i00 = gl_WorkGroupID.x;
-
-const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
-
-if (i00 > pcs.n_past + i01) {
-out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
-} else {
-out_[index + pcs.outOff] = in_[index + pcs.inOff];
-}
-}

@@ -1,22 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-layout(push_constant) uniform PushConstants {
-uint inOff;
-uint outOff;
-} pcs;
-
-void main() {
-const uint baseIndex = gl_WorkGroupID.x * 8;
-
-for (uint x = 0; x < 8; x++) {
-const uint i = baseIndex + x;
-const float y = in_[i + pcs.inOff];
-out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
-}
-}

@ -1,17 +0,0 @@
|
||||||
void main() {
|
|
||||||
const uint i = gl_WorkGroupID.x;
|
|
||||||
const int r = inB[i + pcs.inBOff];
|
|
||||||
|
|
||||||
int z = 0;
|
|
||||||
for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
|
|
||||||
const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
|
|
||||||
const mat4 result = dequantize_block(inIndex, ind%NL);
|
|
||||||
for (uint j = 0; j < 4; ++j) {
|
|
||||||
for (uint k = 0; k < 4; ++k) {
|
|
||||||
const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
|
|
||||||
out_[outIndex] = result[j][k];
|
|
||||||
++z;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@@ -1,31 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inAOff;
-uint inBOff;
-uint outOff;
-int ne00;
-int nb01;
-int nb1;
-} pcs;
-
-void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
-for (int j = 0; j < k; j++) {
-out_[y + j] = inA[x + j];
-}
-}
-
-void main() {
-const uint i = gl_WorkGroupID.x;
-const int r = inB[i + pcs.inBOff];
-
-dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
-}

@@ -1,31 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { float inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inAOff;
-uint inBOff;
-uint outOff;
-int ne00;
-int nb01;
-int nb1;
-} pcs;
-
-void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
-for (int j = 0; j < k; j++) {
-out_[y + j] = inA[x + j];
-}
-}
-
-void main() {
-const uint i = gl_WorkGroupID.x;
-const int r = inB[i + pcs.inBOff];
-
-dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
-}

@@ -1,38 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 2
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q4_0
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inAOff;
-uint inBOff;
-uint outOff;
-int ne00;
-int nb01;
-int nb1;
-} pcs;
-
-block_q4_0 get_unaligned_block_q4_0(uint index) {
-block_q4_0 fres;
-fres.d = u8BufToFloat16(inA, index);
-[[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
-fres.qs[it] = inA[index+2+it];
-}
-return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-const block_q4_0 block = get_unaligned_block_q4_0(index);
-return dequantize_q4_0(block, il);
-}
-
-#include "op_getrows.comp"

@@ -1,39 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 2
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q4_1
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inAOff;
-uint inBOff;
-uint outOff;
-int ne00;
-int nb01;
-int nb1;
-} pcs;
-
-block_q4_1 get_unaligned_block_q4_1(uint index) {
-block_q4_1 fres;
-fres.d = u8BufToFloat16(inA, index);
-fres.m = u8BufToFloat16(inA, index+2);
-[[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
-fres.qs[it] = inA[index+4+it];
-}
-return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-const block_q4_1 block = get_unaligned_block_q4_1(index);
-return dequantize_q4_1(block, il);
-}
-
-#include "op_getrows.comp"

@@ -1,44 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 16
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q6_k
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-uint inAOff;
-uint inBOff;
-uint outOff;
-int ne00;
-int nb01;
-int nb1;
-} pcs;
-
-block_q6_k get_unaligned_block_q6_k(uint index) {
-block_q6_k fres;
-[[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
-fres.ql[it] = inA[index + it];
-}
-[[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
-fres.qh[it] = inA[index + QK_K/2 + it];
-}
-[[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
-fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
-}
-fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
-return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-const block_q6_k block = get_unaligned_block_q6_k(index);
-return dequantize_q6_k(block, il);
-}
-
-#include "op_getrows.comp"

@ -1,52 +0,0 @@
|
||||||
#version 450
|
|
||||||
|
|
||||||
#include "common.comp"
|
|
||||||
|
|
||||||
layout(local_size_x = 1024) in;
|
|
||||||
|
|
||||||
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
|
|
||||||
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
|
|
||||||
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
|
|
||||||
|
|
||||||
layout(push_constant) uniform PushConstants {
|
|
||||||
uint inAOff;
|
|
||||||
uint inBOff;
|
|
||||||
uint outOff;
|
|
||||||
int ne00;
|
|
||||||
int nb00;
|
|
||||||
int nb01;
|
|
||||||
int nb02;
|
|
||||||
int nb03;
|
|
||||||
int ne10;
|
|
||||||
int ne11;
|
|
||||||
int ne12;
|
|
||||||
int ne13;
|
|
||||||
int nb10;
|
|
||||||
int nb11;
|
|
||||||
int nb12;
|
|
||||||
int nb13;
|
|
||||||
int ne0;
|
|
||||||
int nb0;
|
|
||||||
int nb1;
|
|
||||||
int nb2;
|
|
||||||
int nb3;
|
|
||||||
} pcs;
|
|
||||||
|
|
||||||
void main() {
|
|
||||||
const uint i03 = gl_WorkGroupID.z;
|
|
||||||
const uint i02 = gl_WorkGroupID.y;
|
|
||||||
const uint i01 = gl_WorkGroupID.x;
|
|
||||||
|
|
||||||
const uint i13 = i03 % pcs.ne13;
|
|
||||||
const uint i12 = i02 % pcs.ne12;
|
|
||||||
const uint i11 = i01 % pcs.ne11;
|
|
||||||
|
|
||||||
uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
|
|
||||||
uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
|
|
||||||
uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1) / 4);
|
|
||||||
|
|
||||||
for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
|
|
||||||
const uint i10 = i0 % pcs.ne10;
|
|
||||||
out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
#version 450
|
|
||||||
|
|
||||||
#include "common.comp"
|
|
||||||
|
|
||||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
|
||||||
|
|
||||||
layout(local_size_x_id = 0) in;
|
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
|
|
||||||
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
|
|
||||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
|
||||||
uint inAOff;
|
|
||||||
uint inBOff;
|
|
||||||
uint outOff;
|
|
||||||
int ne00;
|
|
||||||
int ne01;
|
|
||||||
int ne02;
|
|
||||||
uint nb00;
|
|
||||||
uint nb01;
|
|
||||||
uint nb02;
|
|
||||||
uint nb03;
|
|
||||||
int ne10;
|
|
||||||
int ne11;
|
|
||||||
int ne12;
|
|
||||||
uint nb10;
|
|
||||||
uint nb11;
|
|
||||||
uint nb12;
|
|
||||||
uint nb13;
|
|
||||||
int ne0;
|
|
||||||
int ne1;
|
|
||||||
uint r2;
|
|
||||||
uint r3;
|
|
||||||
} pcs;
|
|
||||||
|
|
||||||
#define N_F16_F32 4
|
|
||||||
|
|
||||||
void main() {
|
|
||||||
const uint r0 = gl_WorkGroupID.x;
|
|
||||||
const uint rb = gl_WorkGroupID.y*N_F16_F32;
|
|
||||||
const uint im = gl_WorkGroupID.z;
|
|
||||||
|
|
||||||
const uint i12 = im%pcs.ne12;
|
|
||||||
const uint i13 = im/pcs.ne12;
|
|
||||||
|
|
||||||
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
|
|
||||||
|
|
||||||
const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
|
|
||||||
|
|
||||||
for (uint row = 0; row < N_F16_F32; ++row) {
|
|
||||||
uint r1 = rb + row;
|
|
||||||
if (r1 >= pcs.ne11) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
|
|
||||||
|
|
||||||
float sumf = 0;
|
|
||||||
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
|
|
||||||
sumf += float(inA[x+i]) * float(inB[y+i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
const float all_sum = subgroupAdd(sumf);
|
|
||||||
if (subgroupElect()) {
|
|
||||||
out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,51 +0,0 @@
|
||||||
#version 450
|
|
||||||
|
|
||||||
#include "common.comp"
|
|
||||||
|
|
||||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
|
||||||
#extension GL_EXT_debug_printf : enable
|
|
||||||
|
|
||||||
// device subgroup size
|
|
||||||
layout (local_size_x_id = 0) in;
|
|
||||||
|
|
||||||
layout(binding = 0) readonly buffer tensorInA { float inA[]; };
|
|
||||||
layout(binding = 1) readonly buffer tensorInB { float inB[]; };
|
|
||||||
layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
|
|
||||||
|
|
||||||
layout(push_constant) uniform parameter {
|
|
||||||
uint inAOff;
|
|
||||||
uint inBOff;
|
|
||||||
uint outOff;
|
|
||||||
int ne00;
|
|
||||||
int ne01;
|
|
||||||
int ne02;
|
|
||||||
int ne11;
|
|
||||||
int ne12;
|
|
||||||
uint nb01;
|
|
||||||
uint nb02;
|
|
||||||
uint nb11;
|
|
||||||
uint nb12;
|
|
||||||
uint nb1;
|
|
||||||
uint nb2;
|
|
||||||
}
|
|
||||||
pcs;
|
|
||||||
|
|
||||||
|
|
||||||
void main() {
|
|
||||||
uvec3 gid = gl_WorkGroupID;
|
|
||||||
|
|
||||||
uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
|
|
||||||
uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
|
|
||||||
|
|
||||||
const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
|
|
||||||
const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
|
|
||||||
float sum = 0.0f;
|
|
||||||
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
|
|
||||||
sum += float(inA[x+i]) * float(inB[y+i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
const float all_sum = subgroupAdd(sum);
|
|
||||||
if (subgroupElect()) {
|
|
||||||
out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
#version 450
|
|
||||||
|
|
||||||
#include "common.comp"
|
|
||||||
|
|
||||||
#define BLOCKS_IN_QUANT QK4_0
|
|
||||||
#define SIZE_OF_BLOCK sizeof_block_q4_0
|
|
||||||
#define N_ROWS 4
|
|
||||||
|
|
||||||
#include "op_mul_mv_q_n_pre.comp"
|
|
||||||
|
|
||||||
// The q4_0 version of this function
|
|
||||||
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
|
|
||||||
vec2 acc = vec2(0.0, 0.0);
|
|
||||||
const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
|
|
||||||
float d = float(u8BufToFloat16(inA, index));
|
|
||||||
float sumy = 0.0f;
|
|
||||||
for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
|
|
||||||
const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
|
|
||||||
|
|
||||||
const float yl0 = inB[yb + i];
|
|
||||||
const float yl1 = inB[yb + i + 1];
|
|
||||||
const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
|
|
||||||
const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
|
|
||||||
|
|
||||||
sumy += yl0 + yl1 + yl8 + yl9;
|
|
||||||
|
|
||||||
acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
|
|
||||||
acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
|
|
||||||
}
|
|
||||||
return d * (sumy * -8.f + acc[0] + acc[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#include "op_mul_mv_q_n.comp"
|
|
||||||
|
|
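For readers unfamiliar with the q4_0 packing this shader relies on: each 4-bit quant stores its value offset by 8, so the dot product can be accumulated on the raw nibbles and the offset folded out through the sum of the y values, which is exactly the `d * (sumy * -8.f + acc[0] + acc[1])` line above. A minimal scalar C sketch of the same identity (the struct and helper names here are illustrative, not code from the repository):

#include <stdint.h>

#define QK4_0 32

// Illustrative q4_0 block: a scale d followed by QK4_0/2 bytes,
// each byte packing two 4-bit quants (low nibble = element i,
// high nibble = element i + QK4_0/2).
typedef struct {
    float   d;
    uint8_t qs[QK4_0 / 2];
} block_q4_0_ref;

// Reference dot product of one q4_0 block with QK4_0 values of y.
// Each quant represents (nibble - 8) * d, hence the -8 * sumy term.
static float block_q4_0_dot_y_ref(const block_q4_0_ref *b, const float *y) {
    float acc  = 0.0f;
    float sumy = 0.0f;
    for (int i = 0; i < QK4_0 / 2; ++i) {
        const int q_lo = b->qs[i] & 0x0F;
        const int q_hi = b->qs[i] >> 4;
        sumy += y[i] + y[i + QK4_0 / 2];
        acc  += y[i] * q_lo + y[i + QK4_0 / 2] * q_hi;
    }
    return b->d * (sumy * -8.0f + acc);
}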
@@ -1,35 +0,0 @@
#version 450

#include "common.comp"

#define BLOCKS_IN_QUANT QK4_1
#define SIZE_OF_BLOCK sizeof_block_q4_1
#define N_ROWS 4

#include "op_mul_mv_q_n_pre.comp"

// The q4_1 version of this function
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
    vec2 acc = vec2(0.0, 0.0);
    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
    float d = float(u8BufToFloat16(inA, index));
    float m = float(u8BufToFloat16(inA, index+2));

    float sumy = 0.0f;
    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
        const uint16_t b = u8BufToU16(inA, index + 4 + il + i);

        const float yl0 = inB[yb + i];
        const float yl1 = inB[yb + i + 1];
        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];

        sumy += yl0 + yl1 + yl8 + yl9;

        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
    }
    return d * (acc[0] + acc[1]) + sumy * m;
}

#include "op_mul_mv_q_n.comp"
@@ -1,140 +0,0 @@
#version 450

#include "common.comp"

#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k

layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;

layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne10;
    int ne0;
    int ne1;
    int ne01;
    int ne02;
    int ne12;
    uint nb01;
    uint nb02;
    uint nb03;
    uint nb11;
    uint nb12;
    uint nb13;
    uint r2;
    uint r3;
} pcs;

void main() {
    const uint16_t kmask1 = uint16_t(0x3f3f);
    const uint16_t kmask2 = uint16_t(0x0f0f);
    const uint16_t kmask3 = uint16_t(0xc0c0);

    const uint ix = gl_SubgroupInvocationID/8;  // 0...3
    const uint it = gl_SubgroupInvocationID%8;  // 0...7
    const uint iq = it/4;                       // 0 or 1
    const uint ir = it%4;                       // 0...3

    const uint nb = pcs.ne00/QK_K;

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    const uint first_row = r0 * N_DST;
    const uint ib_row = first_row * nb;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
    const uint offset1 = r1*pcs.nb11 + (i12)*pcs.nb12 + (i13)*pcs.nb13;

    const uint xblk = offset0 + pcs.inAOff;
    const uint y = (offset1 / 4) + pcs.inBOff;

    float yl[16];
    float yh[16];
    float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
    float all_sum = 0.f;

    uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    for (uint ib = ix; ib < nb; ib += 4) {
        const uint blk_idx = ib + xblk;

        float sumy[4] = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+0] = inB[y4+i+  0]; sumy[0] += yl[i+0];
            yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
            yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
        }

        for (int row = 0; row < N_DST; row++) {
            uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);

            uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
            uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
            uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
            uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
            uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);

            uint16_t sc16[4];
            sc16[0] = sc_0 & kmask1;
            sc16[1] = sc_2 & kmask1;
            sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
            sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);

            float acc1[4] = {0.f, 0.f, 0.f, 0.f};
            float acc2[4] = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
                uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
                acc1[0] += yl[i+0] * (q1 & 0x000F);
                acc1[1] += yl[i+1] * (q1 & 0x0F00);
                acc1[2] += yl[i+8] * (q1 & 0x00F0);
                acc1[3] += yl[i+9] * (q1 & 0xF000);
                acc2[0] += yh[i+0] * (q2 & 0x000F);
                acc2[1] += yh[i+1] * (q2 & 0x0F00);
                acc2[2] += yh[i+8] * (q2 & 0x00F0);
                acc2[3] += yh[i+9] * (q2 & 0xF000);
            }

            uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
            uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
            uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
            uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
            uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
            uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
            uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
            uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );

            float dall = float(inA[blk_idx + row_idx].d);
            float dmin = float(inA[blk_idx + row_idx].dmin);
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
                         dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
        }

        y4 += 4 * QK_K;
    }

    for (int row = 0; row < N_DST; ++row) {
        all_sum = subgroupAdd(sumf[row]);
        if (subgroupElect()) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
        }
    }
}
@@ -1,106 +0,0 @@
#version 450

#include "common.comp"

#define SIZE_OF_BLOCK sizeof_block_q6_k

layout(local_size_x_id = 0) in;
layout(local_size_y_id = 1) in;
layout(local_size_z = 1) in;

layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne10;
    int ne0;
    int ne1;
    int ne01;
    int ne02;
    int ne12;
    uint nb01;
    uint nb02;
    uint nb03;
    uint nb11;
    uint nb12;
    uint nb13;
    uint r2;
    uint r3;
} pcs;

void main() {
    const uint8_t kmask1 = uint8_t(0x03);
    const uint8_t kmask2 = uint8_t(0x0C);
    const uint8_t kmask3 = uint8_t(0x30);
    const uint8_t kmask4 = uint8_t(0xC0);

    const uint nb = pcs.ne00/QK_K;

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
    const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;

    float sumf = 0;

    // bits of invocation ID for gl_SubgroupSize=32:
    //  x  x  x  x  x
    //  4  3  2  1  0
    // (    tid    ) ix
    //  ip (  il   )

    const uint block_stride = gl_SubgroupSize / 16;          // number of blocks each subgroup processes
    const uint tid = gl_SubgroupInvocationID/block_stride;   // first block_stride groups have tid=0
    const uint ix  = gl_SubgroupInvocationID%block_stride;   // first block is 0..block_stride-1
    const uint ip  = tid/8;         // first or second half of block (0 or 1)
    const uint il  = tid%8;         // each half has 8 parts, one per scale
    const uint n   = 4;             // 4 scales at a time (and 4 sums)
    const uint l0  = n*il;          // offset into half-block, 0..28
    const uint is  = 8*ip + l0/16;  // 0, 1, 8, 9

    const uint y_offset = 128*ip + l0;
    const uint q_offset_l = 64*ip + l0;
    const uint q_offset_h = 32*ip + l0;

    for (uint i = ix; i < nb; i += block_stride) {

        const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;

        const uint qlIndex = q_offset_l;
        const uint q2Index = qlIndex + QK_K/8;
        const uint qhIndex = q_offset_h;
        const uint y = yy + i * QK_K + y_offset;

        float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (uint l = 0; l < n; ++l) {
            const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
            const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
            const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];

            sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
            sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
            sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32);
            sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32);
        }

        float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
        sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
    }

    const float tot = subgroupAdd(sumf);
    if (subgroupElect()) {
        out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
    }
}
@@ -1,73 +0,0 @@
#version 450

#include "common.comp"

#include "op_mul_mv_q_n_pre.comp"

#define SIZE_OF_D 2

#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

#define NB_Q8_0 8

void main() {
    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
    if (gl_SubgroupInvocationID > 31)
        return;

    const int nr = N_DST;
    const int nsg = N_SIMDGROUP;
    const int nw = N_SIMDWIDTH;

    const int nb = pcs.ne00/QK8_0;
    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB

    float yl[NB_Q8_0];
    float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};

    const uint ix = gl_SubgroupInvocationID.x/4;
    const uint il = gl_SubgroupInvocationID.x%4;

    uint yb = y + ix * QK8_0 + NB_Q8_0*il;

    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
    for (uint ib = ix; ib < nb; ib += nw/4) {
        for (int i = 0; i < NB_Q8_0; ++i) {
            yl[i] = inB[yb + i];
        }

        for (int row = 0; row < nr; row++) {
            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
            float sumq = 0.f;
            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
                sumq += qs_iq * yl[iq];
            }
            const float16_t d = u8BufToFloat16(inA, x + block_offset);
            sumf[row] += sumq*d;
        }

        yb += NB_Q8_0 * nw;
    }

    for (int row = 0; row < nr; ++row) {
        const float tot = subgroupAdd(sumf[row]);
        if (subgroupElect() && first_row + row < pcs.ne01) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
        }
    }
}
@@ -1,52 +0,0 @@
void main() {
    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
    if (gl_SubgroupInvocationID > 31)
        return;

    const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    // pointers to src0 rows
    uint ax[N_ROWS];
    for (int row = 0; row < N_ROWS; ++row) {
        const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);

        ax[row] = offset0 + pcs.inAOff;
    }

    const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;

    float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};

    const uint ix = gl_SubgroupInvocationID/2;
    const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);

    uint yb = y + ix * BLOCKS_IN_QUANT + il;

    //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
    //    gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
    //    gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);

    for (uint ib = ix; ib < nb; ib += 16) {
        for (int row = 0; row < N_ROWS; row++) {
            sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
        }

        yb += BLOCKS_IN_QUANT * 16;
    }

    for (int row = 0; row < N_ROWS; ++row) {
        const float tot = subgroupAdd(sumf[row]);
        if (first_row + row < pcs.ne01 && subgroupElect()) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
        }
    }
}
@@ -1,28 +0,0 @@
layout(local_size_x_id = 0) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;

layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne01;
    int ne02;
    int ne10;
    int ne12;
    int ne0;
    int ne1;
    uint nb01;
    uint nb02;
    uint nb03;
    uint nb11;
    uint nb12;
    uint nb13;
    uint r2;
    uint r3;
} pcs;
@@ -1,84 +0,0 @@
#version 450

#include "common.comp"

layout(local_size_x = 256) in;

layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict tensorOut { float out_[]; };

layout(push_constant) uniform PushConstants {
    uint inOff;
    uint outOff;
    uint ne00;
    uint nb01;
    float eps;
} pcs;

shared float sum[gl_WorkGroupSize.x];

void main() {
    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
    // MEAN
    // parallel sum
    sum[gl_LocalInvocationID.x] = 0.0;
    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
        sum[gl_LocalInvocationID.x] += in_[x+i00];
    }

    // reduce
    barrier();
    memoryBarrierShared();
    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
        if (gl_LocalInvocationID.x < i) {
            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
        }
        barrier();
        memoryBarrierShared();
    }

    // broadcast
    if (gl_LocalInvocationID.x == 0) {
        sum[0] /= float(pcs.ne00);
    }
    barrier();
    memoryBarrierShared();
    const float mean = sum[0];

    // recenter
    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
        out_[y+i00] = in_[x+i00] - mean;
    }

    // VARIANCE
    // parallel sum
    sum[gl_LocalInvocationID.x] = 0.0;
    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
        sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
    }

    // reduce
    barrier();
    memoryBarrierShared();
    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
        if (gl_LocalInvocationID.x < i) {
            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
        }
        barrier();
        memoryBarrierShared();
    }

    // broadcast
    if (gl_LocalInvocationID.x == 0) {
        sum[0] /= float(pcs.ne00);
    }
    barrier();
    memoryBarrierShared();
    const float variance = sum[0];

    const float scale = 1.0f/sqrt(variance + pcs.eps);
    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
        out_[y+i00] *= scale;
    }
}
@@ -1,21 +0,0 @@
#version 450

#include "common.comp"

layout(local_size_x = 1) in;

layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
    uint inOff;
    uint outOff;
} pcs;

void main() {
    const uint baseIndex = gl_WorkGroupID.x * 4;

    for (uint x = 0; x < 4; x++) {
        const uint i = baseIndex + x;
        out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
    }
}
@@ -1,53 +0,0 @@
#version 450

#include "common.comp"

layout(local_size_x = 512) in;

layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict tensorOut { float out_[]; };

layout(push_constant) uniform PushConstants {
    uint inOff;
    uint outOff;
    uint ne00;
    uint nb01;
    float eps;
} pcs;

shared float sum[gl_WorkGroupSize.x];

void main() {
    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_

    // parallel sum
    sum[gl_LocalInvocationID.x] = 0.0;
    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
        sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
    }

    // reduce
    barrier();
    memoryBarrierShared();
    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
        if (gl_LocalInvocationID.x < i) {
            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
        }
        barrier();
        memoryBarrierShared();
    }

    // broadcast
    if (gl_LocalInvocationID.x == 0) {
        sum[0] /= float(pcs.ne00);
    }
    barrier();
    memoryBarrierShared();

    const float scale = 1.0f/sqrt(sum[0] + pcs.eps);

    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
        out_[y+i00] = in_[x+i00] * scale;
    }
}
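The two normalization shaders above implement the usual definitions. With $n = \mathrm{ne00}$ and $\epsilon = \mathrm{pcs.eps}$, the first (mean/variance) variant computes

y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad \mu = \frac{1}{n}\sum_j x_j, \qquad \sigma^2 = \frac{1}{n}\sum_j (x_j - \mu)^2,

while the RMS variant skips the mean and uses

y_i = \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_j x_j^2 + \epsilon}}.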
@@ -1,52 +0,0 @@
#version 450

#include "rope_common.comp"

layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };

void main() {
    const uint i3 = gl_WorkGroupID.z;
    const uint i2 = gl_WorkGroupID.y;
    const uint i1 = gl_WorkGroupID.x;

    float corr_dims[2];
    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);

    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);

    float theta_base = float(inB[pcs.inBOff + i2]);
    float inv_ndims = -1.f/pcs.n_dims;

    float cos_theta;
    float sin_theta;

    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
        if (i0 < pcs.n_dims) {
            uint ic = i0/2;

            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);

            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;

            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_

            const float x0 = float(inA[src]);
            const float x1 = float(inA[src+pcs.n_dims/2]);

            out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
        } else {
            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_

            out_[dst_data] = inA[src];
            out_[dst_data+1] = inA[src+1];
        }
    }
}
@@ -1,52 +0,0 @@
#version 450

#include "rope_common.comp"

layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };

void main() {
    const uint i3 = gl_WorkGroupID.z;
    const uint i2 = gl_WorkGroupID.y;
    const uint i1 = gl_WorkGroupID.x;

    float corr_dims[2];
    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);

    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);

    float theta_base = float(inB[pcs.inBOff + i2]);
    float inv_ndims = -1.f/pcs.n_dims;

    float cos_theta;
    float sin_theta;

    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
        if (i0 < pcs.n_dims) {
            uint ic = i0/2;

            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);

            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;

            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 4) + pcs.outOff; // Based from out_

            const float x0 = inA[src];
            const float x1 = inA[src+pcs.n_dims/2];

            out_[dst_data] = x0*cos_theta - x1*sin_theta;
            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
        } else {
            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_

            out_[dst_data] = inA[src];
            out_[dst_data+1] = inA[src+1];
        }
    }
}
@@ -1,52 +0,0 @@
#version 450

#include "rope_common.comp"

layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };

void main() {
    const uint i3 = gl_WorkGroupID.z;
    const uint i2 = gl_WorkGroupID.y;
    const uint i1 = gl_WorkGroupID.x;

    float corr_dims[2];
    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);

    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);

    float theta_base = float(inB[pcs.inBOff + i2]);
    float inv_ndims = -1.f/pcs.n_dims;

    float cos_theta;
    float sin_theta;

    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
        if (i0 < pcs.n_dims) {
            uint ic = i0/2;

            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);

            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;

            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_

            const float x0 = float(inA[src]);
            const float x1 = float(inA[src+1]);

            out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
        } else {
            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_

            out_[dst_data] = inA[src];
            out_[dst_data+1] = inA[src+1];
        }
    }
}
@@ -1,52 +0,0 @@
#version 450

#include "rope_common.comp"

layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };

void main() {
    const uint i3 = gl_WorkGroupID.z;
    const uint i2 = gl_WorkGroupID.y;
    const uint i1 = gl_WorkGroupID.x;

    float corr_dims[2];
    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);

    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);

    float theta_base = float(inB[pcs.inBOff + i2]);
    float inv_ndims = -1.f/pcs.n_dims;

    float cos_theta;
    float sin_theta;

    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
        if (i0 < pcs.n_dims) {
            uint ic = i0/2;

            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);

            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;

            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_

            const float x0 = inA[src];
            const float x1 = inA[src+1];

            out_[dst_data] = x0*cos_theta - x1*sin_theta;
            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
        } else {
            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_

            out_[dst_data] = inA[src];
            out_[dst_data+1] = inA[src+1];
        }
    }
}
@@ -1,19 +0,0 @@
#version 450

#include "common.comp"

layout(local_size_x = 1) in;

layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };

layout(push_constant) uniform PushConstants {
    uint inOff;
    uint outOff;
    float scale;
} pcs;

void main() {
    const uint i = gl_WorkGroupID.x;
    out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
@@ -1,23 +0,0 @@
#version 450

#include "common.comp"

layout(local_size_x = 1) in;

layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };

layout(push_constant) uniform PushConstants {
    uint inOff;
    uint outOff;
    float scale;
} pcs;

void main() {
    const uint baseIndex = gl_WorkGroupID.x * 8;

    for (uint x = 0; x < 8; x++) {
        const uint i = baseIndex + x;
        out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
    }
}
@@ -1,22 +0,0 @@
#version 450

#include "common.comp"

layout(local_size_x = 1) in;

layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
    uint inOff;
    uint outOff;
} pcs;

void main() {
    const uint baseIndex = gl_WorkGroupID.x * 4;

    for (uint x = 0; x < 4; x++) {
        const uint i = baseIndex + x;
        const float y = in_[i + pcs.inOff];
        out_[i + pcs.outOff] = y / (1.0 + exp(-y));
    }
}
@@ -1,72 +0,0 @@
// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)

#version 450

#include "common.comp"

layout(local_size_x_id = 0) in;

layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };

layout(push_constant) uniform PushConstants {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne01;
    int ne02;
    float scale;
    float max_bias;
    float m0;
    float m1;
    uint n_head_log2;
    int mask;
} pcs;

void main() {
    if (gl_SubgroupInvocationID > 31)
        return;

    const uint i03 = gl_WorkGroupID.z;
    const uint i02 = gl_WorkGroupID.y;
    const uint i01 = gl_WorkGroupID.x;

    const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
    const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
    const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
    const uint pdst = extra_off + pcs.outOff; // Based from out_

    float slope = 1.0f;

    // ALiBi
    if (pcs.max_bias > 0.0f) {
        int64_t h = i02;

        float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
        int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;

        slope = pow(base, float(exp));
    }

    // parallel max
    float localMax = uintBitsToFloat(0xFF800000);
    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
    }
    float max_ = subgroupMax(localMax);

    // parallel sum
    float localSum = 0.0f;
    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
        localSum += exp_psrc0;
        out_[pdst + i00] = exp_psrc0;
    }

    const float sum = subgroupAdd(localSum);
    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
        out_[pdst + i00] /= sum;
    }
}
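In terms of the push constants above, this shader evaluates a numerically stabilized softmax over each row with the optional ALiBi mask folded into the logits:

p_i = \frac{\exp\bigl(s\,x_i + m\,b_i - M\bigr)}{\sum_j \exp\bigl(s\,x_j + m\,b_j - M\bigr)}, \qquad M = \max_j\,(s\,x_j + m\,b_j),

where $s$ is pcs.scale, $b$ is the mask row read from inB (zero when pcs.mask is 0), and $m$ is the per-head ALiBi slope derived from m0, m1 and n_head_log2.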
@@ -1,71 +0,0 @@
#include "common.comp"

#define GGML_ROPE_TYPE_NEOX 2

// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;

layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint inCOff;
    uint outOff;
    int n_dims;
    int mode;
    int n_ctx_orig;
    float freq_base;
    float freq_scale;
    bool has_freq_factors;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;
    uint nb00;
    uint nb01;
    uint nb02;
    uint nb03;
    int ne0;
    uint nb0;
    uint nb1;
    uint nb2;
    uint nb3;
} pcs;

float rope_yarn_ramp(const float low, const float high, const float i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    // Get n-d rotational scaling corrected for extrapolation
    float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }
    cos_theta = cos(theta) * mscale;
    sin_theta = sin(theta) * mscale;
}

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
}

void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    // start and end correction dims
    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
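As background for the correction-factor comment carried in this common ROPE header: dimension pair $d$ rotates with frequency $b^{-2d/n_{\mathrm{dims}}}$ (with $b$ = freq_base), so over the original context it completes

n_{\mathrm{rot}} = \frac{n_{\mathrm{ctx\_orig}}}{2\pi}\, b^{-2d/n_{\mathrm{dims}}}

full turns. Solving for $d$ gives the expression used by rope_yarn_corr_factor,

d = \frac{n_{\mathrm{dims}}\,\ln\!\bigl(n_{\mathrm{ctx\_orig}} / (2\pi\, n_{\mathrm{rot}})\bigr)}{2\,\ln b},

and evaluating it at beta_fast and beta_slow yields the start and end correction dims. This reading of the comment is an interpretation, not text from the repository.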
@@ -230,8 +230,10 @@ typedef struct {
    uint64_t nb22;
    uint64_t nb23;
    int32_t  ne32;
    int32_t  ne33;
    uint64_t nb31;
    uint64_t nb32;
    uint64_t nb33;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
@@ -5018,8 +5018,10 @@ static bool ggml_metal_encode_node(
                /*.nb22 =*/ nb22,
                /*.nb23 =*/ nb23,
                /*.ne32 =*/ ne32,
                /*.ne33 =*/ ne33,
                /*.nb31 =*/ nb31,
                /*.nb32 =*/ nb32,
                /*.nb33 =*/ nb33,
                /*.ne1  =*/ ne1,
                /*.ne2  =*/ ne2,
                /*.scale =*/ scale,
@@ -3857,7 +3857,7 @@ kernel void kernel_flash_attn_ext(
        // load the mask in shared memory
        #pragma unroll(Q)
        for (short j = 0; j < Q; ++j) {
            device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
            device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

            const float m = pm[ic + tiisg];
@@ -4343,7 +4343,7 @@ kernel void kernel_flash_attn_ext_vec(
    const bool has_mask = mask != q;

    // pointer to the mask
    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

    float slope = 1.0f;
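The two Metal changes above generalize the mask lookup so that a mask with fewer heads or batches than the attention tensor is broadcast by modulo indexing along dims 2 and 3. A small C sketch of the addressing pattern, with half stood in by uint16_t and all names illustrative:

#include <stdint.h>

// Byte-strided, broadcast lookup of one mask row: the indices along dims 2/3
// wrap with a modulo so a smaller mask is reused across heads and batches.
static inline const uint16_t * mask_row(const char *mask,
                                        int64_t iq1, int64_t iq2, int64_t iq3,
                                        int64_t ne32, int64_t ne33,
                                        uint64_t nb31, uint64_t nb32, uint64_t nb33) {
    return (const uint16_t *) (mask
        + iq1 * nb31               // row inside one mask slice
        + (iq2 % ne32) * nb32      // broadcast over heads
        + (iq3 % ne33) * nb33);    // broadcast over batches
}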
@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
                default:
                    return false;
            }
        case GGML_OP_SET_ROWS:
            {
                // TODO: add support
                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
                return false;
            } break;
        case GGML_OP_CPY:
        case GGML_OP_DUP:
        case GGML_OP_CONT:
@@ -5757,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;

    const int ne00 = src0 ? src0->ne[0] : 0;
    const int ne00 = src0->ne[0];
    const int ne01 = src0 ? src0->ne[1] : 0;
    const int ne01 = src0->ne[1];
    const int ne02 = src0 ? src0->ne[2] : 0;
    const int ne02 = src0->ne[2];
    const int ne03 = src0 ? src0->ne[3] : 0;
    const int ne03 = src0->ne[3];

    const cl_long nb01 = src0->nb[1];
    const cl_long nb02 = src0->nb[2];
    const cl_long nb03 = src0->nb[3];

    const int ne12 = src1 ? src1->ne[2] : 0;
    const int ne13 = src1 ? src1->ne[3] : 0;

    const cl_long nb11 = src1 ? src1->nb[1] : 0;
    const cl_long nb12 = src1 ? src1->nb[2] : 0;
    const cl_long nb13 = src1 ? src1->nb[3] : 0;

    const cl_long nb1 = dst->nb[1];
    const cl_long nb2 = dst->nb[2];
    const cl_long nb3 = dst->nb[3];

    float scale, max_bias;
    memcpy(&scale,    dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, dst->op_params + 1, sizeof(float));

    const int nrows_x = ggml_nrows(src0);
    const int n_head = src0->ne[2];
    const int nrows_y = src0->ne[1];

    const int n_head = nrows_x/nrows_y;
    const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
|
||||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
|
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
|
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
|
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
|
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
|
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
|
||||||
|
|
||||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||||
|
|
|
||||||
|
|
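The argument changes above replace element counts with byte strides so the kernel can address non-contiguous tensors directly, mirroring how the rest of ggml indexes data. A minimal C sketch of that addressing convention (an illustrative helper, not code from the repository):

#include <stdint.h>

// Address of element (i0, i1, i2, i3) given the tensor's base pointer and
// per-dimension byte strides nb[0..3] -- the same pattern the soft_max
// kernels now apply with nb01/nb02/nb03 for src0 and nb1/nb2/nb3 for dst.
static inline float * tensor_elem_f32(void *data, const uint64_t nb[4],
                                      int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    char *p = (char *) data + i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0];
    return (float *) p;
}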
@@ -22,32 +22,45 @@
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max_4_f16(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global half * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    src1 = (global half *)((global char *)src1 + offset1);
-    dst = (global float *)((global char *)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;

     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);

-    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
-    global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

     float slope = 1.0f;
@@ -22,32 +22,45 @@
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max_4(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global float * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    src1 = (global float*)((global char*)src1 + offset1);
-    dst = (global float*)((global char*)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;

     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);

-    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
-    global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

     float slope = 1.0f;
@@ -22,32 +22,45 @@
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max_f16(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global half * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    src1 = (global half *)((global char *)src1 + offset1);
-    dst = (global float *)((global char *)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;

     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);

-    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
-    global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

     float slope = 1.0f;
@@ -22,32 +22,45 @@
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global float * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    src1 = (global float*)((global char*)src1 + offset1);
-    dst = (global float*)((global char*)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;

     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);

-    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
-    global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

     float slope = 1.0f;
@@ -83,7 +83,7 @@ static ggml_sycl_device_info ggml_sycl_init() {

         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-        info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
+        info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }

@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
             }
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -224,6 +224,21 @@ enum vk_device_architecture {
     INTEL_XE2,
 };

+// HSK x HSV
+enum FaHeadSizes {
+    FA_HEAD_SIZE_64,
+    FA_HEAD_SIZE_80,
+    FA_HEAD_SIZE_96,
+    FA_HEAD_SIZE_112,
+    FA_HEAD_SIZE_128,
+    FA_HEAD_SIZE_192,
+    FA_HEAD_SIZE_192_128,
+    FA_HEAD_SIZE_256,
+    FA_HEAD_SIZE_576_512,
+    FA_HEAD_SIZE_UNSUPPORTED,
+    FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
+};
+
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
     vk::PhysicalDeviceProperties props = device.getProperties();

@@ -467,26 +482,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;

     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];

-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];

-    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];

     vk_pipeline pipeline_flash_attn_split_k_reduce;
@@ -1003,7 +1003,7 @@ struct ggml_backend_vk_context {

     // number of additional consecutive nodes that are being fused with the
     // node currently being processed
-    uint32_t num_additional_fused_ops {};
+    int num_additional_fused_ops {};
 };

 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT

@@ -1699,6 +1699,35 @@ enum FaCodePath {
     FA_COOPMAT2,
 };

+static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
+    if (hsk != 192 && hsk != 576 && hsk != hsv) {
+        return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+    switch (hsk) {
+    case 64: return FA_HEAD_SIZE_64;
+    case 80: return FA_HEAD_SIZE_80;
+    case 96: return FA_HEAD_SIZE_96;
+    case 112: return FA_HEAD_SIZE_112;
+    case 128: return FA_HEAD_SIZE_128;
+    case 192:
+        if (hsv == 192) {
+            return FA_HEAD_SIZE_192;
+        } else if (hsv == 128) {
+            return FA_HEAD_SIZE_192_128;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    case 256: return FA_HEAD_SIZE_256;
+    case 576:
+        if (hsv == 512) {
+            return FA_HEAD_SIZE_576_512;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    default: return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+}
+
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
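A brief usage sketch of the fa_get_head_sizes() helper added above; the call values are illustrative assumptions, not taken from this commit:

    // Hypothetical illustration: map a K/V head-size pair to the new pipeline index.
    FaHeadSizes hs = fa_get_head_sizes(192, 128);   // yields FA_HEAD_SIZE_192_128
    if (hs == FA_HEAD_SIZE_UNSUPPORTED) {
        // reject the op, as ggml_backend_vk_device_supports_op now does further down
    }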
@@ -1719,8 +1748,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
     }
 }

-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);
+    GGML_UNUSED(hsv);

     if (path == FA_SCALAR) {
         if (small_rows) {

@@ -1744,7 +1774,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
     }

     // small cols to reduce register count
-    if (ggml_is_quantized(type) || D == 256) {
+    if (ggml_is_quantized(type) || hsk >= 256) {
         return {64, 32};
     }
     return {64, 64};

@@ -2037,19 +2067,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };

-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
     };

-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
         // For scalar, use 128 (arbitrary)
+        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
+        const uint32_t D = (hsk|hsv);
         uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
             ? scalar_flash_attention_workgroup_size
             : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
+        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);

         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2058,26 +2090,29 @@ static void ggml_vk_load_shaders(vk_device& device) {

         // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
         GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
-        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
     };

-#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \

 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)

     CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
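As a cross-reference, the order of the specialization constants returned above ({wg_size, Br, Bc, HSK, HSV, Clamp, D_split}) matches the constant_id layout that flash_attn_base.comp declares later in this commit. A hedged C++ summary of that mapping (the enum is illustrative, not part of the patch):

    // Illustrative only: positions in the vector produced by fa_spec_constants,
    // matching constant_id 0..6 in the updated flash_attn_base.comp.
    enum FaSpecConstantSlot {
        FA_SC_WORKGROUP_SIZE = 0,
        FA_SC_BR             = 1,
        FA_SC_BC             = 2,
        FA_SC_HSK            = 3,
        FA_SC_HSV            = 4,
        FA_SC_CLAMP          = 5,
        FA_SC_D_SPLIT        = 6,
    };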
@@ -3688,7 +3723,6 @@ static void ggml_vk_instance_init() {

     }

-    size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
     vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;

     // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
|
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
|
||||||
// Needs to be kept up to date on shader changes
|
// Needs to be kept up to date on shader changes
|
||||||
|
GGML_UNUSED(hsv);
|
||||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||||
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
||||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||||
|
|
||||||
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||||
|
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
|
||||||
|
|
||||||
|
const uint32_t masksh = Bc * Br * sizeof(float);
|
||||||
|
|
||||||
|
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
|
||||||
|
|
||||||
|
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
||||||
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||||
|
|
||||||
|
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||||
|
|
||||||
|
return supported;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||||
|
// Needs to be kept up to date on shader changes
|
||||||
|
GGML_UNUSED(hsv);
|
||||||
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||||
|
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
|
||||||
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||||
|
|
||||||
const uint32_t acctype = f32acc ? 4 : 2;
|
const uint32_t acctype = f32acc ? 4 : 2;
|
||||||
const uint32_t f16vec4 = 8;
|
const uint32_t f16vec4 = 8;
|
||||||
|
|
||||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||||
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
||||||
|
|
||||||
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
|
const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
|
||||||
|
|
||||||
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
||||||
const uint32_t sfsh = Bc * sfshstride * acctype;
|
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||||
|
|
||||||
const uint32_t kshstride = D / 4 + 2;
|
const uint32_t kshstride = hsk / 4 + 2;
|
||||||
const uint32_t ksh = Bc * kshstride * f16vec4;
|
const uint32_t ksh = Bc * kshstride * f16vec4;
|
||||||
|
|
||||||
const uint32_t slope = Br * sizeof(float);
|
const uint32_t slope = Br * sizeof(float);
|
||||||
|
|
@ -6027,7 +6084,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
||||||
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
||||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||||
|
|
||||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||||
|
|
||||||
return supported;
|
return supported;
|
||||||
}
|
}
|
||||||
|
|
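To make the shared-memory budget concrete, a standalone sketch that evaluates the scalar-path formula above with assumed tile sizes (wg_size, Br, Bc and the 48 KiB limit are placeholders, not values from this commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed values for illustration only.
        const uint32_t wg_size = 128, Br = 8, Bc = 64, hsk = 576;
        const uint32_t tmpsh   = wg_size * sizeof(float);          // per-invocation reduction scratch
        const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);      // vec4 reduction scratch
        const uint32_t masksh  = Bc * Br * sizeof(float);          // mask tile
        const uint32_t Qf      = Br * (hsk / 4 + 2) * 4 * sizeof(float); // Q tile with padding
        const uint32_t total   = tmpsh + tmpshv4 + masksh + Qf;
        std::printf("total=%u bytes, fits in 48 KiB: %d\n", total, total <= 48u * 1024u);
        return 0;
    }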
@@ -6051,11 +6108,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;

-    const uint32_t D = neq0;
+    const uint32_t HSK = nek0;
+    const uint32_t HSV = nev0;
     uint32_t N = neq1;
     const uint32_t KV = nek1;

-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == HSV);
     GGML_ASSERT(ne2 == N);

     // input tensor rows must be contiguous

@@ -6063,12 +6121,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));

-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == HSK);

     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);

     GGML_ASSERT(nev1 == nek1);

@@ -6089,7 +6144,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
                                          (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);

-    const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
+    const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);

     if (!coopmat_shape_supported || !coopmat_shmem_supported) {
         path = FA_SCALAR;
@@ -6142,47 +6197,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         path = FA_SCALAR;
     }

+    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
+    if (path == FA_SCALAR &&
+        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
+        small_rows = true;
+    }
+
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;

+    FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
+
     switch (path) {
     case FA_SCALAR:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
         break;
     case FA_COOPMAT1:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
         break;
     case FA_COOPMAT2:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
         break;
     default:
         GGML_ASSERT(0);
@@ -6212,7 +6245,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
+        split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.

@@ -6224,7 +6257,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

     // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
     // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
+    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
         GGML_ABORT("Requested preallocation size is too large");
     }

@@ -6342,7 +6375,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });

     ggml_vk_sync_buffers(subctx);
-    const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
+    const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
         {
             vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
@@ -10241,19 +10274,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
             auto device = ggml_vk_get_device(ctx->device);
             bool coopmat2 = device->coopmat2;
-            switch (op->src[0]->ne[0]) {
-            case 64:
-            case 80:
-            case 96:
-            case 112:
-            case 128:
-            case 256:
-                break;
-            default:
-                return false;
-            }
-            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                // different head sizes of K and V are not supported yet
+            FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
+            if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                 return false;
             }
             if (op->src[0]->type != GGML_TYPE_F32) {

@@ -10265,6 +10287,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
+            // TODO: support broadcast
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
+            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
+            if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
+                return false;
+            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
@@ -10333,6 +10361,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return false;
                 }
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CONT:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
@@ -11,7 +11,8 @@
 #include "types.comp"
 #include "flash_attn_base.comp"

-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;

 const uint32_t cols_per_iter = WorkGroupSize / D_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;

@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-    uint32_t offset = (iq2 + r) * D + c;
+    uint32_t offset = (iq2 + r) * HSV + c;
     data_o[o_offset + offset] = D_TYPE(elem);
     return elem;
 }

@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
 shared vec4 tmpshv4[WorkGroupSize];

 shared float masksh[Bc][Br];
-shared vec4 Qf[Br][D / 4];
+shared vec4 Qf[Br][HSK / 4];

 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM

@@ -53,18 +54,18 @@ void main() {

     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

-    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
-        uint32_t d = (idx + tid) % (D / 4);
-        uint32_t r = (idx + tid) / (D / 4);
-        if (r < Br && d < D / 4 &&
+    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (HSK / 4);
+        uint32_t r = (idx + tid) / (HSK / 4);
+        if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
             Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
         }
     }
     barrier();

-    vec4 Of[Br][D_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    vec4 Of[Br][HSV_per_thread / 4];
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] = vec4(0.0);
         }
@@ -116,7 +117,7 @@ void main() {

         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                 uint ib = coord / BLOCK_SIZE;

@@ -195,14 +196,14 @@ void main() {
             Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
         }

-        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
             [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                 Of[r][d] = eMf[r] * Of[r][d];
             }
         }

         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                 uint ib = coord / BLOCK_SIZE;

@@ -259,7 +260,7 @@ void main() {
             Lf[r] = tmpsh[d_tid];
             barrier();

-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

                 Of[r][d] = eMf * Of[r][d];
                 tmpshv4[tid] = Of[r][d];

@@ -281,11 +282,11 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
+        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);

         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }
@@ -293,7 +294,7 @@ void main() {
             }
         }

-        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
                 perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);

@@ -309,18 +310,18 @@ void main() {
         Lfrcp[r] = 1.0 / Lf[r];
     }

-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] *= Lfrcp[r];
         }
     }

-    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }

@@ -330,9 +331,9 @@ void main() {
     } else {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (i * Br + r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                     }
                 }
             }
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
 layout (constant_id = 1) const uint32_t Br = 1;
 layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-layout (constant_id = 4) const uint32_t Clamp = 0;
-layout (constant_id = 5) const uint32_t D_split = 16;
+layout (constant_id = 3) const uint32_t HSK = 32;
+layout (constant_id = 4) const uint32_t HSV = 32;
+layout (constant_id = 5) const uint32_t Clamp = 0;
+layout (constant_id = 6) const uint32_t D_split = 16;

 layout (push_constant) uniform parameter {
     uint32_t N;
@@ -13,7 +13,9 @@
 #include "types.comp"
 #include "flash_attn_base.comp"

-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;

 const uint32_t row_split = 4;
 const uint32_t rows_per_thread = Br / row_split;
 const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;

@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-    uint32_t offset = (iq2 + r) * D + c;
+    uint32_t offset = (iq2 + r) * HSV + c;
     data_o[o_offset + offset] = D_TYPE(elem);
     return elem;
 }

@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
 shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];

-const uint32_t qstride = D / 4 + 2; // in units of f16vec4
+const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];

-// Avoid padding for D==256 to make it fit in 48KB shmem.
-const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+// Avoid padding for hsk==256 to make it fit in 48KB shmem.
+const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
 shared ACC_TYPE sfsh[Bc * sfshstride];

-const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
+const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 ksh[Bc * kshstride];

 shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {

     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

-    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
-        uint32_t d = (idx + tid) % (D / 4);
-        uint32_t r = (idx + tid) / (D / 4);
-        if (r < Br && d < D / 4 &&
+    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (HSK / 4);
+        uint32_t r = (idx + tid) / (HSK / 4);
+        if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
             Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
     }
     barrier();

-    ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Of[r][d] = ACC_TYPEV4(0.0);
         }

@@ -131,10 +133,10 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {

-        [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
-            uint32_t d = (idx + tid) % (D / 4);
-            uint32_t c = (idx + tid) / (D / 4);
-            if (c < Bc && d < D / 4) {
+        [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+            uint32_t d = (idx + tid) % (HSK / 4);
+            uint32_t c = (idx + tid) / (HSK / 4);
+            if (c < Bc && d < HSK / 4) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                 uint ib = coord / BLOCK_SIZE;

@@ -149,14 +151,14 @@ void main() {
         }
         barrier();

-        // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
-        // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
+        // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
+        // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
         // This is written transposed in order to allow for N being 8 if implementations need it
         coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
         coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
         coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;

-        for (uint32_t d = 0; d < D / 16; ++d) {
+        for (uint32_t d = 0; d < HSK / 16; ++d) {
             coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);

             uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;

@@ -206,7 +208,7 @@ void main() {
             eMf[r] = exp(Moldf - Mf[r]);
|
eMf[r] = exp(Moldf - Mf[r]);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||||
}
|
}
|
||||||
|
|
@ -221,7 +223,7 @@ void main() {
|
||||||
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
||||||
Lf[r] += Pf[r];
|
Lf[r] += Pf[r];
|
||||||
}
|
}
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
#if BLOCK_SIZE > 1
|
#if BLOCK_SIZE > 1
|
||||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||||
uint ib = coord / BLOCK_SIZE;
|
uint ib = coord / BLOCK_SIZE;
|
||||||
|
|
@ -284,7 +286,7 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
|
|
||||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||||
tmpshv4[tid] = Of[r][d];
|
tmpshv4[tid] = Of[r][d];
|
||||||
|
|
@ -304,11 +306,11 @@ void main() {
|
||||||
// If there is split_k, then the split_k resolve shader does the final
|
// If there is split_k, then the split_k resolve shader does the final
|
||||||
// division by L. Store the intermediate O value and per-row m and L values.
|
// division by L. Store the intermediate O value and per-row m and L values.
|
||||||
if (p.k_num > 1) {
|
if (p.k_num > 1) {
|
||||||
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
|
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||||
|
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
if (tile_row(r) < N) {
|
if (tile_row(r) < N) {
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||||
}
|
}
|
||||||
|
|
@ -316,7 +318,7 @@ void main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
if (tile_row(r) < N) {
|
if (tile_row(r) < N) {
|
||||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||||
|
|
@ -332,18 +334,18 @@ void main() {
|
||||||
Lfrcp[r] = 1.0 / Lf[r];
|
Lfrcp[r] = 1.0 / Lf[r];
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
Of[r][d] *= float16_t(Lfrcp[r]);
|
Of[r][d] *= float16_t(Lfrcp[r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
|
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||||
|
|
||||||
if (p.gqa_ratio > 1) {
|
if (p.gqa_ratio > 1) {
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
if (tile_row(r) < N) {
|
if (tile_row(r) < N) {
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||||
}
|
}
|
||||||
|
|
@ -353,9 +355,9 @@ void main() {
|
||||||
} else {
|
} else {
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
if (i * Br + tile_row(r) < N) {
|
if (i * Br + tile_row(r) < N) {
|
||||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||||
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
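Note on the HSK/HSV split above (explanatory aside, not part of the diff): the single head-size constant D is replaced by separate key and value head sizes so the same kernel can serve models whose K heads and V heads have different widths. Per tile of Br query rows and Bc key/value columns the shapes are

    S = Q K^{\top} \in \mathbb{R}^{B_r \times B_c}, \quad Q \in \mathbb{R}^{B_r \times \mathrm{HSK}}, \; K \in \mathbb{R}^{B_c \times \mathrm{HSK}}
    O = \mathrm{softmax}(S)\, V \in \mathbb{R}^{B_r \times \mathrm{HSV}}, \quad V \in \mathbb{R}^{B_c \times \mathrm{HSV}}

so everything on the K side (Qf, ksh, the S^T loop) is sized by HSK, while the accumulator Of, the output offsets and the GQA store are sized by HSV.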
@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-    if (r < N && c < D) {
-        uint32_t offset = (iq2 + r) * D + c;
+    if (r < N && c < HSV) {
+        uint32_t offset = (iq2 + r) * HSV + c;
         data_o[o_offset + offset] = D_TYPE(elem);
     }
     return elem;
@@ -86,9 +86,9 @@ void main() {
     tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
 #endif
 
-    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
-    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
-    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
+    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
+    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
+    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
 
     // hint to the compiler that strides are aligned for the aligned variant of the shader
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@@ -104,16 +104,16 @@ void main() {
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
     tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
-    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
+    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
-    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
+    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
 
-    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
+    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
     Qf16 *= float16_t(p.scale);
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
@@ -140,10 +140,10 @@ void main() {
 
         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
 
         uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
         if (p.logit_softcap != 0.0f) {
@@ -208,42 +208,42 @@ void main() {
         rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
         rowsum = coopMatMulAdd(P_A, One, rowsum);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
+        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
         uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
-        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
+        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
 
         L = eM*L + rowsum;
 
         // This is the "diagonal" matrix in the paper, but since we do componentwise
         // multiply rather than matrix multiply it has the diagonal element smeared
         // across the row
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
 
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
         // multiply with fp16 accumulation, then add to O.
-        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
         PV = coopMatMulAdd(P_A, V, PV);
 
-        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
 
-        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
+        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 
-        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
         coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
         return;
     }
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
 
     // resize L by using smear/reduce
     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -255,18 +255,18 @@ void main() {
 
     O = Ldiag*O;
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
-    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
     if (p.gqa_ratio > 1) {
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
     } else {
         tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
-        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
 
         // permute dimensions
         tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
 
-        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
     }
 }
@@ -3674,7 +3674,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
         GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@@ -4704,12 +4703,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
 
-        GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
@@ -6051,13 +6050,28 @@ static void ggml_compute_backward(
             }
             GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
+        case GGML_OP_GLU: {
+            switch (ggml_get_glu_op(tensor)) {
+                case GGML_GLU_OP_SWIGLU: {
+                    if (src0_needs_grads) {
+                        GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
+                    }
+                    if (src1_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
+                    }
+                } break;
+                default: {
+                    GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
+                } //break;
+            }
+        } break;
         case GGML_OP_NONE: {
             // noop
         } break;
         case GGML_OP_COUNT:
         default: {
-            fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
-            GGML_ABORT("fatal error");
+            GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
        } //break;
    }
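Note on the new GGML_OP_GLU backward case (derivation, not part of the diff): for the split SwiGLU y = silu(a) * b with a = src0, b = src1 and silu(a) = a * sigma(a), the chain rule gives

    \frac{\partial L}{\partial b} = \mathrm{silu}(a) \odot \frac{\partial L}{\partial y}
    \frac{\partial L}{\partial a} = \mathrm{silu}'(a) \odot b \odot \frac{\partial L}{\partial y}, \qquad \mathrm{silu}'(a) = \sigma(a)\,\bigl(1 + a\,(1 - \sigma(a))\bigr)

which matches the two expressions added above: the gate gradient is ggml_silu_back applied at src0 to grad * src1, and the linear-input gradient is silu(src0) * grad. Only the split (two-tensor) SwiGLU is handled, hence the assert on src1.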
@@ -714,8 +714,8 @@ class GGUFWriter:
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 
-    def add_shared_kv_layers(self, value: float) -> None:
-        self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
+    def add_shared_kv_layers(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
 
     def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
         self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
@@ -245,9 +245,18 @@ class SpecialVocab:
         if not tokenizer_config:
             return True
         chat_template_alt = None
-        chat_template_file = path / 'chat_template.json'
-        if chat_template_file.is_file():
-            with open(chat_template_file, encoding = 'utf-8') as f:
+        chat_template_json = path / 'chat_template.json'
+        chat_template_jinja = path / 'chat_template.jinja'
+        if chat_template_jinja.is_file():
+            with open(chat_template_jinja, encoding = 'utf-8') as f:
+                chat_template_alt = f.read()
+            if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
+                chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
+                for template_path in additional_templates:
+                    with open(template_path, encoding = 'utf-8') as fp:
+                        chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
+        elif chat_template_json.is_file():
+            with open(chat_template_json, encoding = 'utf-8') as f:
                 chat_template_alt = json.load(f).get('chat_template')
         chat_template = tokenizer_config.get('chat_template', chat_template_alt)
         if chat_template is None or isinstance(chat_template, (str, list)):
@@ -83,7 +83,6 @@ while read c; do
         src/ggml-cpu/* \
         src/ggml-cuda/* \
         src/ggml-hip/* \
-        src/ggml-kompute/* \
         src/ggml-metal/* \
         src/ggml-musa/* \
         src/ggml-opencl/* \
@@ -141,7 +140,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
     # src/ggml-cpu/* -> ggml/src/ggml-cpu/*
     # src/ggml-cuda/* -> ggml/src/ggml-cuda/*
     # src/ggml-hip/* -> ggml/src/ggml-hip/*
-    # src/ggml-kompute/* -> ggml/src/ggml-kompute/*
     # src/ggml-metal/* -> ggml/src/ggml-metal/*
     # src/ggml-musa/* -> ggml/src/ggml-musa/*
     # src/ggml-opencl/* -> ggml/src/ggml-opencl/*
@@ -174,7 +172,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
         -e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \
-        -e 's/([[:space:]]| [ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \
@@ -15,7 +15,6 @@ cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
 cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/
 cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/
 cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/
-cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/
 cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/
 cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/
 cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/
@@ -281,20 +281,23 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
-    if (self_kq_mask_swa) {
-        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-    }
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     GGML_ASSERT(cross_kq_mask);
@@ -332,7 +335,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -1155,8 +1159,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1187,8 +1193,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1247,11 +1256,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1343,8 +1356,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1355,8 +1370,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
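Note on the k_idxs/v_idxs plumbing above (illustrative, not part of the diff): instead of copying into a contiguous view that starts at the cache head, build_attn now passes per-token cell indices and the cache write becomes a row scatter. The helper name and the plain loop below are assumptions used only to show the semantics of such a set-rows write; the real work is done by the ggml_set_rows operator inside llama_kv_cache_unified::cpy_k/cpy_v.

    #include <cstdint>
    #include <cstddef>
    #include <vector>

    // Illustrative scatter: write each source row (one K or V row per token) into
    // the destination buffer at the cell index chosen for that token. The indices
    // do not have to be contiguous, which is what lifts the old "single contiguous
    // slot" restriction on the KV cache.
    static void set_rows(std::vector<float> & dst, size_t n_embd,
                         const std::vector<float> & src,
                         const std::vector<int64_t> & idxs) {
        for (size_t i = 0; i < idxs.size(); ++i) {
            const size_t row = (size_t) idxs[i];          // cell index for token i
            for (size_t e = 0; e < n_embd; ++e) {
                dst[row*n_embd + e] = src[i*n_embd + e];  // overwrite that cell's row
            }
        }
    }

With the old ggml_cpy path the indices are implicitly head, head+1, ..., head+n_tokens-1; the I64 index tensors make that mapping explicit and batch dependent.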
@@ -249,8 +249,14 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -274,9 +280,19 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
@@ -309,7 +325,7 @@ public:
     llm_graph_input_one() {}
     virtual ~llm_graph_input_one() = default;
 
-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * one = nullptr; // F32
 };
@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-            this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+            this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-            this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+            this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
@@ -74,6 +74,8 @@ private:
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
@@ -90,8 +92,8 @@ public:
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-            std::vector<uint32_t> heads_base,
-            std::vector<uint32_t> heads_swa,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();
@@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
@@ -353,13 +360,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-            this, std::move(heads), std::move(ubatches));
+            this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -402,12 +409,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
-        uint32_t head_new; // new position of the head, after placing the ubatch
+
+        slot_info sinfo; // slot info for the ubatch
 
         llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
     };
@@ -418,26 +426,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
     bool success = true;
 
     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch, cont);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remeber the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->head_new, it->cells);
+        cells.set(it->sinfo.idxs, it->cells);
         head = it->head_old;
     }
@@ -539,7 +550,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -552,7 +563,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }
 
     if (debug > 0) {
@@ -615,15 +626,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     uint32_t n_tested = 0;
 
+    // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+    // for non-continuous slots, we test the tokens one by one
+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    auto & idxs = res.idxs;
+
+    idxs.reserve(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }
 
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos pos = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
@@ -633,19 +655,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             // - (disabled) mask causally, if the sequence is the same as the one we are inserting
             // - mask SWA, using current max pos for that sequence in the cache
             // always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(head_cur + i);
+            bool can_use = cells.is_empty(idx);
 
-            if (!can_use && cells.seq_count(head_cur + i) == 1) {
-                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+            if (!can_use && cells.seq_count(idx) == 1) {
+                const llama_pos pos_cell = cells.pos_get(idx);
 
                 // (disabled) causal mask
                 // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //if (cells.seq_has(idx, seq_id)) {
                 //    can_use = pos_cell >= pos;
                 //}
 
                 if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                     // SWA mask
                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -654,28 +676,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 }
             }
 
-            if (!can_use) {
-                found = false;
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                idxs.push_back(idx);
+            } else {
                 break;
             }
         }
 
-        if (found) {
+        if (idxs.size() == n_tokens) {
             break;
         }
 
+        if (cont) {
+            idxs.clear();
+        }
+
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }
 
-    return head_cur;
+    // we didn't find a suitable slot - return empty result
+    if (idxs.size() < n_tokens) {
+        res.clear();
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
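Note (assumption, not part of the diff): the slot_info type itself is declared in llama-kv-cache-unified.h, which is not shown in these hunks. The sketch below is a hypothetical mirror of its shape, inferred only from how it is used here (idxs, empty(), clear(), head(), and the slot_info_vec_t returned by prepare()).

    #include <cstdint>
    #include <vector>

    // Hypothetical reconstruction for illustration only; see llama-kv-cache-unified.h
    // for the real definition.
    struct slot_info {
        std::vector<uint32_t> idxs; // cell indices chosen for the ubatch, possibly non-contiguous

        // first cell of the slot; meaningful for the contiguous (ggml_cpy fallback) path
        uint32_t head() const { return idxs.at(0); }

        bool empty() const { return idxs.empty(); }
        void clear()       { idxs.clear(); }
    };

    using slot_info_vec_t = std::vector<slot_info>; // one entry per ubatch, as returned by prepare()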
@ -683,22 +716,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
||||||
seq_pos_max_rm[s] = -1;
|
seq_pos_max_rm[s] = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
assert(ubatch.n_tokens == sinfo.idxs.size());
|
||||||
if (!cells.is_empty(head_cur + i)) {
|
|
||||||
assert(cells.seq_count(head_cur + i) == 1);
|
|
||||||
|
|
||||||
const llama_seq_id seq_id = cells.seq_get(head_cur + i);
|
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||||
const llama_pos pos = cells.pos_get(head_cur + i);
|
const auto idx = sinfo.idxs.at(i);
|
||||||
|
|
||||||
|
if (!cells.is_empty(idx)) {
|
||||||
|
assert(cells.seq_count(idx) == 1);
|
||||||
|
|
||||||
|
const llama_seq_id seq_id = cells.seq_get(idx);
|
||||||
|
const llama_pos pos = cells.pos_get(idx);
|
||||||
|
|
||||||
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
||||||
|
|
||||||
cells.rm(head_cur + i);
|
cells.rm(idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
cells.pos_set(head_cur + i, ubatch.pos[i]);
|
cells.pos_set(idx, ubatch.pos[i]);
|
||||||
|
|
||||||
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
||||||
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
|
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -719,7 +756,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
||||||
}
|
}
|
||||||
|
|
||||||
// move the head at the end of the slot
|
// move the head at the end of the slot
|
||||||
head = head_cur + ubatch.n_tokens;
|
head = sinfo.idxs.back() + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified::get_can_shift() const {
|
bool llama_kv_cache_unified::get_can_shift() const {
|
||||||
@@ -772,47 +809,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
            0);
}

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

+    const int64_t n_embd_k_gqa = k->ne[0];
    const int64_t n_tokens = k_cur->ne[2];

+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    // will be removed when ggml_set_rows() is adopted by all backends
+
    ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());

    return ggml_cpy(ctx, k_cur, k_view);
}

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

+    const int64_t n_embd_v_gqa = v->ne[0];
    const int64_t n_tokens = v_cur->ne[2];

-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

+    if (v_idxs && supports_set_rows) {
+        if (!v_trans) {
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
+        }
+
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+        // note: the V cache is transposed when not using flash attention
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+
+        // note: we can be more explicit here at the cost of extra cont
+        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
+        //v_cur = ggml_transpose(ctx, v_cur);
+        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+        // we broadcast the KV indices n_embd_v_gqa times
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    // will be removed when ggml_set_rows() is adopted by all backends
+
    ggml_tensor * v_view = nullptr;

    if (!v_trans) {
        v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
-
        v_cur = ggml_transpose(ctx, v_cur);
+
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+                (v->ne[1] )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
    }

    return ggml_cpy(ctx, v_cur, v_view);
}

+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    const uint32_t n_tokens = ubatch->n_tokens;

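The hunk above wires three pieces together: build_input_k_idxs()/build_input_v_idxs() create an I64 index tensor per ubatch, set_input_k_idxs()/set_input_v_idxs() fill it on the host from sinfo.idxs, and cpy_k()/cpy_v() consume it via ggml_set_rows() so that token i lands in cache row idxs[i], while the old ggml_cpy() view path is kept as a fallback until all backends support ggml_set_rows(). A self-contained toy of the scatter semantics, in plain C++ rather than the ggml API:

#include <cstdint>
#include <cstdio>
#include <vector>

// toy model of the set_rows write path used by cpy_k():
//   dst (n_cells x n_embd) <- src (n_tokens x n_embd), written at rows idxs[i]
static void set_rows(std::vector<float> & dst, size_t n_embd,
                     const std::vector<float> & src,
                     const std::vector<int64_t> & idxs) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        for (size_t c = 0; c < n_embd; ++c) {
            dst[idxs[i]*n_embd + c] = src[i*n_embd + c];
        }
    }
}

int main() {
    const size_t n_embd = 4, n_cells = 8;
    std::vector<float>   k_cache(n_cells*n_embd, 0.0f);
    std::vector<float>   k_cur   = {1,1,1,1, 2,2,2,2};   // 2 tokens
    std::vector<int64_t> k_idxs  = {6, 3};               // filled from sinfo.idxs
    set_rows(k_cache, n_embd, k_cur, k_idxs);
    std::printf("cell 3 starts with %.0f, cell 6 with %.0f\n",
                k_cache[3*n_embd], k_cache[6*n_embd]);
    return 0;
}

For the transposed V cache the diff instead reshapes the destination so each scattered row is a single element and the same n_tokens indices are broadcast across n_embd_v_gqa, which is why v_idxs keeps the shape [n_tokens] in this version.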
@@ -1552,13 +1675,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
            ubatch.seq_id[i] = &dest_seq_id;
        }

-        const auto head_cur = find_slot(ubatch);
-        if (head_cur < 0) {
+        const auto sinfo = find_slot(ubatch, true);
+        if (sinfo.empty()) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }

-        apply_ubatch(head_cur, ubatch);
+        apply_ubatch(sinfo, ubatch);
+
+        const auto head_cur = sinfo.head();

        // keep the head at the old position because we will read the KV data into it in state_read_data()
        head = head_cur;
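state_read_meta() now asks for a contiguous slot (find_slot(ubatch, true)) and checks sinfo.empty() instead of a negative head. A small illustrative sketch of the contiguity contract this code relies on, assuming only what the diff shows (head() is then reused as the old-style offset):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    // what find_slot(ubatch, true) would return for a 4-token ubatch
    std::vector<uint32_t> idxs = {12, 13, 14, 15};

    // cont == true means the cells are consecutive: idxs = {head, head+1, ..., head+n-1}
    for (size_t i = 1; i < idxs.size(); ++i) {
        assert(idxs[i] == idxs[0] + i);
    }

    uint32_t head_cur = idxs.front();   // sinfo.head(), kept for state_read_data()
    return head_cur == 12 ? 0 : 1;
}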
@@ -1744,7 +1869,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();
-    head = 0;
+
+    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
}

llama_kv_cache_unified_context::llama_kv_cache_unified_context(

@@ -1759,8 +1888,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(

llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}

llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

@@ -1768,7 +1897,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
bool llama_kv_cache_unified_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
        return false;
    }

@@ -1785,10 +1914,9 @@ bool llama_kv_cache_unified_context::apply() {
        return true;
    }

-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);

    n_kv = kv->get_n_kv();
-    head = heads[i_next];

    return true;
}

@@ -1800,7 +1928,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    return ubatches[i_next];
+    return ubatches[i_cur];
}

uint32_t llama_kv_cache_unified_context::get_n_kv() const {

@@ -1815,18 +1943,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
    return kv->get_v(ctx, il, n_kv);
}

-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}

-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
}

void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
    kv->set_input_k_shift(dst);
}

+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    kv->set_input_kq_mask(dst, ubatch, causal_attn);
}

@@ -24,8 +24,6 @@ public:
    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

-    using ubatch_heads = std::vector<uint32_t>;
-
    struct defrag_info {
        bool empty() const {
            return ids.empty();

@@ -37,6 +35,32 @@ public:
        std::vector<uint32_t> ids;
    };

+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        idx_vec_t idxs;
+
+        uint32_t head() const {
+            return idxs.at(0);
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+
+        // TODO: implement
+        //std::vector<idx_vec_t> seq_idxs;
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
    llama_kv_cache_unified(
            const llama_model & model,
            layer_filter_cb && filter,

@@ -102,30 +126,37 @@ public:
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

    // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

-    // find places for the provided ubatches in the cache, returns the head locations
+    // find places for the provided ubatches in the cache, returns the slot infos
    // return empty vector on failure
-    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);

-    // return the cell position where we can insert the ubatch
-    // return -1 on failure to find a contiguous slot of kv cells
-    int32_t find_slot(const llama_ubatch & ubatch) const;
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

-    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
-    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

    //
-    // set_input API
+    // input API
    //

+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_k_shift (ggml_tensor * dst) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

@@ -157,8 +188,13 @@ private:
    // SWA
    const uint32_t n_swa = 0;

+    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    int supports_set_rows = false;
+
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    std::vector<ggml_context_ptr> ctxs;

@@ -211,7 +247,7 @@ private:
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
    // some shorthands
-    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
    using defrag_info = llama_kv_cache_unified::defrag_info;

    // used for errors

@@ -231,7 +267,7 @@ public:
    // used to create a batch procesing context from a batch
    llama_kv_cache_unified_context(
            llama_kv_cache_unified * kv,
-            ubatch_heads heads,
+            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_context();

@@ -257,11 +293,16 @@ public:
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift (ggml_tensor * dst) const;

    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

@@ -283,10 +324,10 @@ private:
    // batch processing context
    //

-    // the index of the next ubatch to process
-    size_t i_next = 0;
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;

-    ubatch_heads heads;
+    slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;

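In the context object the per-ubatch heads vector becomes a vector of slot_info (sinfos), and the cursor is renamed from i_next to i_cur; apply() and get_ubatch() index both vectors with the same cursor. A toy mirror of that iteration pattern, illustrative only:

#include <cstddef>
#include <cstdio>
#include <vector>

// toy mirror of the batch-processing loop in llama_kv_cache_unified_context:
// one slot_info per ubatch, advanced with the i_cur cursor
struct slot_info { std::vector<unsigned> idxs; };

int main() {
    // what prepare(ubatches) would return: one entry per ubatch
    std::vector<slot_info> sinfos = { {{0, 1, 2}}, {{3, 4}}, {{5}} };

    size_t i_cur = 0;
    do {
        // kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]) would be called here
        std::printf("ubatch %zu -> %zu cells, head = %u\n",
                    i_cur, sinfos[i_cur].idxs.size(), sinfos[i_cur].idxs.front());
    } while (++i_cur < sinfos.size());

    return 0;
}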
@@ -297,7 +338,4 @@ private:
    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // as the cache gets filled, the benefit from this heuristic disappears
    int32_t n_kv;
-
-    // the beginning of the current slot in which the ubatch will be inserted
-    int32_t head;
};

@@ -105,10 +105,30 @@ public:
        res.resize(n);

        for (uint32_t j = 0; j < n; ++j) {
-            res.pos[j] = pos[i + j];
-            res.seq[j] = seq[i + j];
+            const auto idx = i + j;

-            assert(shift[i + j] == 0);
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
+        }
+
+        return res;
+    }
+
+    // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells_unified res;
+
+        res.resize(idxs.size());
+
+        for (uint32_t j = 0; j < idxs.size(); ++j) {
+            const auto idx = idxs[j];
+
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
        }

        return res;

@@ -119,26 +139,58 @@ public:
        assert(i + other.pos.size() <= pos.size());

        for (uint32_t j = 0; j < other.pos.size(); ++j) {
-            if (pos[i + j] == -1 && other.pos[j] != -1) {
+            const auto idx = i + j;
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
                used.insert(i + j);
            }

-            if (pos[i + j] != -1 && other.pos[j] == -1) {
+            if (pos[idx] != -1 && other.pos[j] == -1) {
                used.erase(i + j);
            }

-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                seq_pos_rm(i + j);
            }

-            pos[i + j] = other.pos[j];
-            seq[i + j] = other.seq[j];
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];

-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                seq_pos_add(i + j);
            }

-            assert(shift[i + j] == 0);
+            assert(shift[idx] == 0);
+        }
+    }
+
+    // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+        assert(idxs.size() == other.pos.size());
+
+        for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            const auto idx = idxs[j];
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
+                used.insert(idx);
+            }
+
+            if (pos[idx] != -1 && other.pos[j] == -1) {
+                used.erase(idx);
+            }
+
+            if (pos[idx] != -1) {
+                seq_pos_rm(idx);
+            }
+
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
+
+            if (pos[idx] != -1) {
+                seq_pos_add(idx);
+            }
+
+            assert(shift[idx] == 0);
        }
    }

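The new cp(idxs)/set(idxs, other) overloads in llama_kv_cells_unified mirror the existing range-based pair, but over an arbitrary index list, so a caller can snapshot exactly the cells named by a slot_info and later restore them. A reduced, self-contained sketch of that back-up/restore round trip (positions only; the real class also tracks sequence sets and shifts):

#include <cassert>
#include <cstdint>
#include <vector>

struct cells_t {
    std::vector<int32_t> pos;   // -1 == empty

    // copy the state of the cells named by idxs, in idxs order
    cells_t cp(const std::vector<uint32_t> & idxs) const {
        cells_t res;
        res.pos.resize(idxs.size());
        for (size_t j = 0; j < idxs.size(); ++j) {
            res.pos[j] = pos[idxs[j]];
        }
        return res;
    }

    // write the saved state back into the same cells
    void set(const std::vector<uint32_t> & idxs, const cells_t & other) {
        assert(idxs.size() == other.pos.size());
        for (size_t j = 0; j < idxs.size(); ++j) {
            pos[idxs[j]] = other.pos[j];
        }
    }
};

int main() {
    cells_t cells; cells.pos = {10, 11, 12, 13, 14};
    const std::vector<uint32_t> idxs = {1, 3};

    cells_t backup = cells.cp(idxs);     // snapshot cells 1 and 3
    cells.pos[1] = -1; cells.pos[3] = -1;
    cells.set(idxs, backup);             // restore them

    return (cells.pos[1] == 11 && cells.pos[3] == 13) ? 0 : 1;
}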
@@ -195,11 +195,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(

llama_memory_hybrid_context::llama_memory_hybrid_context(
        llama_memory_hybrid * mem,
-        std::vector<uint32_t> heads_attn,
+        slot_info_vec_t sinfos_attn,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

@@ -92,6 +92,8 @@ private:

class llama_memory_hybrid_context : public llama_memory_context_i {
public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
    // init failure
    explicit llama_memory_hybrid_context(llama_memory_status status);

@@ -107,7 +109,7 @@ public:
    // init success
    llama_memory_hybrid_context(
            llama_memory_hybrid * mem,
-            std::vector<uint32_t> heads_attn,
+            slot_info_vec_t sinfos_attn,
            std::vector<llama_ubatch> ubatches);

    ~llama_memory_hybrid_context() = default;

@@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case {
        if (v & 1) {
            auto ne = ne_a; ne[0] *= 3;
            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(a);
            ggml_set_name(a, "a");

            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
            ggml_set_name(a, "view_of_a");

            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(b);
            ggml_set_name(b, "b");

            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
            ggml_set_name(a, "view_of_b");
        } else {
            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(a);
            ggml_set_name(a, "a");

            b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(b);
            ggml_set_name(b, "b");
        }

@@ -3637,7 +3641,7 @@ struct test_flash_attn_ext : public test_case {

        ggml_tensor * m = nullptr;
        if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
            ggml_set_name(m, "m");
        }

@@ -4751,7 +4755,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));

            if (ne0 <= 32 && ne1 <= 32) {
-                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
            }
        }

@@ -4932,6 +4936,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));

    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

@@ -1,7 +1,3 @@
#include "llama.h"

-#ifdef GGML_USE_KOMPUTE
-#include "ggml-kompute.h"
-#endif
-
int main(void) {}