Compare commits

...

6 Commits

Author SHA1 Message Date
Aman Gupta 116f488d3b
Merge 459b75b3cd into 2634ed207a 2026-02-01 11:24:07 +01:00
Neo Zhang 2634ed207a
create test.sh to enhance the parameters for testing, update the guide, rm useless script (#19243) 2026-02-01 18:24:00 +08:00
Aman Gupta 459b75b3cd templatize multi_token_path 2026-01-31 05:56:47 +01:00
Aman Gupta 3183b72b16 Fix perf issue on ampere. Use mmvf mm-id only for non-nvidia GPUs 2026-01-28 09:36:56 +01:00
Aman Gupta e708ba6adb add mmvq too 2026-01-28 09:36:56 +01:00
Aman Gupta 3dadad6f62 CUDA: use mmvq for mul-mat-id for small batch sizes 2026-01-28 09:36:56 +01:00
10 changed files with 373 additions and 180 deletions

View File

@ -119,7 +119,7 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
*Notes:*
- **Memory**
- The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-cli`.
- The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-completion`.
- Please make sure the GPU shared memory from the host is large enough to account for the model's size. For e.g. the *llama-2-7b.Q4_0* requires at least 8.0GB for integrated GPU and 4.0GB for discrete GPU.
- **Execution Unit (EU)**
@ -423,16 +423,12 @@ Choose one of following methods to run.
- Use device 0:
```sh
./examples/sycl/run-llama2.sh 0
# OR
./examples/sycl/run-llama3.sh 0
./examples/sycl/test.sh -mg 0
```
- Use multiple devices:
```sh
./examples/sycl/run-llama2.sh
# OR
./examples/sycl/run-llama3.sh
./examples/sycl/test.sh
```
2. Command line
@ -455,13 +451,13 @@ Examples:
- Use device 0:
```sh
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0 --mmap
```
- Use multiple devices:
```sh
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer --mmap
```
*Notes:*
@ -577,13 +573,13 @@ Or, use CMake presets to build:
```sh
cmake --preset x64-windows-sycl-release
cmake --build build-x64-windows-sycl-release -j --target llama-cli
cmake --build build-x64-windows-sycl-release -j --target llama-completion
cmake -DGGML_SYCL_F16=ON --preset x64-windows-sycl-release
cmake --build build-x64-windows-sycl-release -j --target llama-cli
cmake --build build-x64-windows-sycl-release -j --target llama-completion
cmake --preset x64-windows-sycl-debug
cmake --build build-x64-windows-sycl-debug -j --target llama-cli
cmake --build build-x64-windows-sycl-debug -j --target llama-completion
```
#### 3. Visual Studio
@ -608,7 +604,7 @@ You can use Visual Studio to open the `llama.cpp` folder directly as a CMake pro
- For a minimal experimental setup, you can build only the inference executable using:
```Powershell
cmake --build build --config Release -j --target llama-cli
cmake --build build --config Release -j --target llama-completion
```
##### - Generating a Visual Studio Solution
@ -714,13 +710,7 @@ Choose one of following methods to run.
1. Script
```
examples\sycl\win-run-llama-2.bat
```
or
```
examples\sycl\win-run-llama-3.bat
examples\sycl\win-test.bat
```
2. Command line
@ -744,13 +734,13 @@ Examples:
- Use device 0:
```
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0
build\bin\llama-completion.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0 --mmap
```
- Use multiple devices:
```
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer
build\bin\llama-completion.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer --mmap
```

View File

@ -18,13 +18,14 @@ CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
LOAD_MODE='--mmap'
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU"
#use signle GPU only
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE}
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE}
fi

View File

@ -1,31 +0,0 @@
#!/usr/bin/env bash
# MIT license
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: MIT
# If you want more control, DPC++ Allows selecting a specific device through the
# following environment variable
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "Using $GGML_SYCL_DEVICE as the main GPU"
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi

130
examples/sycl/test.sh Executable file
View File

@ -0,0 +1,130 @@
#!/bin/bash
# MIT license
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT
Help() {
cat << EOF
Usage: $(basename "$0") [OPTIONS]
This script processes files with specified options.
Options:
-h, --help Display this help message and exit.
-c, --context <value> Set context length. Bigger need more memory.
-p, --promote <value> Prompt to start generation with.
-m, --model <value> Full model file path.
-mg,--main-gpu <value> Set main GPU ID (0 - n) for single GPU mode.
-sm,--split-mode <value> How to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
-ngl,--n-gpu-layers <value> Max. number of layers to store in VRAM (default: -1)
-lv,--log-verbosity <value> Set the verbosity threshold. Messages with a higher verbosity will be
ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
EOF
}
BIN_FILE=./build/bin/llama-completion
SEED=0
GPUS_SETTING=""
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=99
CONTEXT=4096
GGML_SYCL_DEVICE=-1
SPLIT_MODE=layer
LOG_VERBOSE=3
while [[ $# -gt 0 ]]; do
case "$1" in
-c|--context)
CONTEXT=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-p|--promote)
# Option that is a simple flag (boolean)
INPUT_PROMPT="$2"
# Shift once to consume the option flag
shift
shift
;;
-m|--model)
MODEL_FILE="$2"
# Shift twice to consume both the option flag and its value
shift
shift
;;
-mg|--main-gpu)
GGML_SYCL_DEVICE=$2
SPLIT_MODE=none
# Shift twice to consume both the option flag and its value
shift
shift
;;
-sm|--split-mode)
SPLIT_MODE=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-ngl|--n-gpu-layers)
NGL=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-lv|--log-verbosity)
LOG_VERBOSE=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-h|--help)
Help
exit 0
;;
*)
# Handle unknown options or stop processing options
echo "Invalid option: $1"
# Optional: exit script or shift to treat remaining as positional args
exit 1
;;
esac
done
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
echo "UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=${UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS}"
if [ $GGML_SYCL_DEVICE -ne -1 ]; then
echo "Use $GGML_SYCL_DEVICE as main GPU"
#use signle GPU only
GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
export ONEAPI_DEVICE_SELECTOR="level_zero:${$GGML_SYCL_DEVICE}"
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
else
echo "Use all Intel GPUs, including iGPU & dGPU"
fi
echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap

View File

@ -7,5 +7,5 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0
set LOAD_MODE="--mmap"
.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE%

View File

@ -7,5 +7,5 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-completion.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -no-cnv -p %INPUT2% -n 400 -s 0 -e -ngl 99
set LOAD_MODE="--mmap"
.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE%

View File

@ -2279,13 +2279,19 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ne2 == 1) {
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
if (ne2 <= 4) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
} else {
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
if (GGML_CUDA_CC_IS_AMD(cc)) {
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
return;
}
}
return;
}
if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {

View File

@ -4,26 +4,48 @@
#include "mmvf.cuh"
#include "convert.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const int ids_stride) {
const int row = blockIdx.x;
// for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
const int channel_dst = blockIdx.y;
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z;
const int tid = threadIdx.x;
int token_idx;
int channel_x;
int channel_y;
int sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
token_idx = ids ? blockIdx.z : 0;
channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
sample_dst = ids ? 0 : blockIdx.z;
}
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y2*2;
dst += token_idx*stride_col_dst;
}
bool use_gate = false;
bool use_bias = false;
@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f(
if (use_gate) {
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
}
const int channel_bias = ids ? channel_x : channel_dst;
if constexpr (has_fusion) {
const int channel_bias = ids ? channel_x : channel_dst;
if (use_bias) {
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
}
@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f(
}
}
template<typename T, typename type_acc, int ncols_dst, int block_size>
template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
static void mul_mat_vec_f_switch_fusion(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t ncols, const uint3 nchannels_y,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
template <typename T, typename type_acc, int ncols_dst>
template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 64: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 96: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 128: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 160: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 192: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 224: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 256: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
default: {
GGML_ABORT("fatal error");
@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
const int64_t ids_stride, cudaStream_t stream) {
const bool has_ids = ids != nullptr;
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
ncols_dst, ids_stride, stream);
return;
}
if (has_ids) {
// Single-token MUL_MAT_ID path
constexpr int c_ncols_dst = 1;
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
ncols_dst, ids_stride, stream);
return;
}
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
case 2:
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
case 3:
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
case 4:
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
case 5:
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
case 6:
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
case 7:
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
case 8:
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break;
default:
GGML_ABORT("fatal error");
@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
return;
}
}
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
}
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0);
@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_col_dst = ids ? s2 : s1;
const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1);
const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f(
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));

View File

@ -1,5 +1,7 @@
#include "common.cuh"
#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels.
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion = nullptr);

View File

@ -137,15 +137,15 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}
// tell the compiler to use as many registers as it wants, see nwarps definition below
template <ggml_type type, int ncols_dst, bool has_fusion>
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
const uint32_t ids_stride) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@ -162,11 +162,25 @@ static __global__ void mul_mat_vec_q(
const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
const uint32_t channel_dst = blockIdx.y;
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
const uint32_t sample_dst = blockIdx.z;
uint32_t token_idx = 0;
uint32_t channel_x;
uint32_t channel_y;
uint32_t sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
}
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
@ -188,11 +202,11 @@ static __global__ void mul_mat_vec_q(
active_glu = fusion.glu_op;
}
const uint32_t channel_bias = ids ? channel_x : channel_dst;
float x_biases[ncols_dst] = { 0.0f };
float gate_biases[ncols_dst] = { 0.0f };
if constexpr (has_fusion) {
const uint32_t channel_bias = ids ? channel_x : channel_dst;
if (use_bias) {
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
// 1. Hide latency by prefetching bias and gate here
@ -222,6 +236,9 @@ static __global__ void mul_mat_vec_q(
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y;
}
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@ -275,6 +292,10 @@ static __global__ void mul_mat_vec_q(
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
if constexpr (is_multi_token_id) {
dst += token_idx*stride_col_dst;
}
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
@ -335,40 +356,41 @@ static __global__ void mul_mat_vec_q(
}
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst>
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
const uint32_t ids_stride, cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
template <ggml_type type>
@ -379,7 +401,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) {
const int ids_stride, cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
@ -393,8 +415,19 @@ static void mul_mat_vec_q_switch_ncols_dst(
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const bool has_ids = ids != nullptr;
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
return;
}
GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
@ -402,7 +435,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 2: {
constexpr int c_ncols_dst = 2;
@ -410,7 +443,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 3: {
constexpr int c_ncols_dst = 3;
@ -418,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 4: {
constexpr int c_ncols_dst = 4;
@ -426,7 +459,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 5: {
constexpr int c_ncols_dst = 5;
@ -434,7 +467,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 6: {
constexpr int c_ncols_dst = 6;
@ -442,7 +475,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 7: {
constexpr int c_ncols_dst = 7;
@ -450,7 +483,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 8: {
constexpr int c_ncols_dst = 8;
@ -458,7 +491,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
dims.first, dims.second, 0, ids_stride, stream);
} break;
default:
GGML_ABORT("fatal error");
@ -474,127 +507,127 @@ static void mul_mat_vec_q_switch_type(
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) {
const int ids_stride, cudaStream_t stream) {
switch (type_x) {
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ1_M:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
default:
GGML_ABORT("fatal error");
@ -622,7 +655,7 @@ void ggml_cuda_mul_mat_vec_q(
GGML_ASSERT( nb0 == ts_dst);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
@ -693,11 +726,13 @@ void ggml_cuda_mul_mat_vec_q(
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, stream);
ne03, ne3, s03, s13, s3, ids_stride, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
@ -726,7 +761,7 @@ void ggml_cuda_op_mul_mat_vec_q(
ggml_cuda_mm_fusion_args_device fusion_local{};
mul_mat_vec_q_switch_type(
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
}