Compare commits

...

12 Commits

Author SHA1 Message Date
ddh0 680d43936e
Merge c27b0d3d88 into 6b949d1078 2026-04-01 14:05:51 +03:00
Neo Zhang 6b949d1078
sycl : support nvfp4 type in mul_mat (#21227) 2026-04-01 13:54:15 +03:00
Michael Wand 84f82e846c
ggml-cuda: Add generic NVFP4 MMQ kernel (#21074)
* Introduced NVFP4 generic MMQ kernel

* Added extra FP8 guard, hope to solve ci HIP failure

* Rename tiles and use HIP_FP8_AVAILABLE

* Removed remaning FP8 straggler and added const int

* Const

* Removed DECL_MMQ_CASE artifact

* Removed newline

* Removed space after else

* Changed HIP FP8 NVFP4 conversion gate

* Added new line to bottom of mmq.cu 270

* Removed extra spaces

* Removed single space in front of else on line 814

* Added NVFP4 to generate cu script so HIP can see it, further tightened logic

* Include generated mmq-instance-nvfp4.cu

* Added NVFP4 mmq to HIP Check ignore list

* Update ggml/src/ggml-cuda/mmq.cuh

Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 in tile assert

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Added function name ending for end if

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Added function names to closing endif

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-01 12:04:58 +02:00
Ettore Di Giacinto e1cb817483
memory: respect unified KV cache in hybrid memory for eval tasks (#21224)
The hybrid memory paths (`llama-memory-hybrid.cpp` and
`llama-memory-hybrid-iswa.cpp`) always used sequential equal split,
ignoring the unified KV cache flag. This caused hellaswag, winogrande,
and multiple-choice evaluations to fail on hybrid models (models with
both attention and recurrent/SSM layers, such as Qwen3.5-35B-A3B) with:

  split_equal: sequential split is not supported when there are
  coupled sequences in the input batch (you may need to use the
  -kvu flag)

PR #19954 fixed this for `llama-kv-cache-iswa.cpp` by automatically
enabling unified KV mode and setting n_parallel >= 4 for multi-choice
eval tasks. However, the hybrid memory paths were not updated.

This commit mirrors the iswa fix: use non-sequential split when KV
cache is unified (n_stream == 1), which is automatically set by
llama-perplexity for hellaswag/winogrande/multiple-choice since #19954.

Tested on Qwen3.5-35B-A3B (hybrid attention+SSM MoE model):
- HellaSwag: 83.0% (400 tasks)
- Winogrande: 74.5% (400 tasks)
- MMLU: 41.2%
- ARC-Challenge: 56.2%
- TruthfulQA: 37.7%
All previously failed with llama_decode() error.
2026-04-01 12:50:17 +03:00
uvos 88d5f8ffc3
CUDA/HIP: Fix kernel slection for mmvq mmid kernel to align host selection with device launch bounds (#21238)
The conditions cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE and cc >= GGML_CUDA_CC_TURING match all non-nvidia devices. This causes us to attempt to launch the kernel for batch sizes with larger configurations than our launch bounds on HIP devices. This pr fixes the conditionals in get_mmvq_mmid_max_batch.

Fixes #21191
2026-04-01 10:21:20 +02:00
Georgi Gerganov d43375ff7f
ggml : fix RWKV ops thread assignment (#21226) 2026-04-01 11:10:25 +03:00
Taimur Ahmad 2b86e5cae6
ggml-cpu: fix fallback for RVV kernels without zvfh (#21157)
* ggml-cpu: refactor sgemm; fix rvv checks

* ggml-cpu: refactor rvv kernels; set zvfbfwma default to off
2026-04-01 11:10:03 +03:00
Anav Prasad 88458164c7
CUDA: Add Flash Attention Support for Head Dimension 512 (#20998)
* flash attention support for head dimension 512 added

* FA D=512 - match 576 configs, limit ncols2, revert vec cap

* fix HIP tile kernel build for D=512

* fix HIP tile kernel occupancy for D=512 on AMD

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* fix tile FA compilation

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-01 09:07:24 +02:00
Ed Addario 4951250235
llama : refactor llama_model_quantize_params to expose a pure C interface (#20346)
* Refactor llama_model_quantize_params to expose a pure C interface

* Restore comment and cleanup struct def

* Code review refactoring

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Code review refactoring

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-04-01 08:43:00 +03:00
Reese Levine 82764c341a
ggml webgpu: quantized buffers to u32 + wider browser/device support (#21046)
* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs
2026-04-01 08:38:24 +03:00
ddh0 c27b0d3d88 fix typo
missing word "of"
2026-03-29 17:10:34 -05:00
ddh0 3cf161d29c contrib : clarify code origin guidelines 2026-03-29 16:51:07 -05:00
44 changed files with 1006 additions and 578 deletions

View File

@ -1,4 +1,6 @@
# Contributors
# CONTRIBUTION POLICY
## Contributors
The project differentiates between 3 levels of contributors:
@ -6,7 +8,7 @@ The project differentiates between 3 levels of contributors:
- Collaborators (Triage): people with significant contributions, who may be responsible for some parts of the code, and are expected to maintain and review contributions for the code they own
- Maintainers: responsible for reviewing and merging PRs, after approval from the code owners
# AI Usage Policy
## AI Usage Policy
> [!IMPORTANT]
> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity.
@ -26,7 +28,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
# Pull requests (for contributors & collaborators)
## Pull requests (for contributors & collaborators)
Before submitting your PR:
- Search for existing PRs to prevent duplicating efforts
@ -54,7 +56,7 @@ After submitting your PR:
- If your PR becomes stale, rebase it on top of latest `master` to get maintainers attention
- Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for fixing related issues and reviewing related PRs
# Pull requests (for maintainers)
## Pull requests (for maintainers)
- Squash-merge PRs
- Use the following format for the squashed commit title: `<module> : <commit title> (#<issue_number>)`. For example: `utils : fix typo in utils.py (#1234)`
@ -68,7 +70,17 @@ Maintainers reserve the right to decline review or close pull requests for any r
- The pull request duplicates an existing one.
- The contributor fails to adhere to this contributing guide or the AI policy.
# Coding guidelines
## Code Origin Disclosure Policy
*Contributors* and *collaborators*, as defined above, must disclose the origin of their code, if they did not write it themselves, to avoid potential licensing conflicts or intellectual property right violations.
Contributions that can be shown to originate from one or more of the following sources, in whole or in part, will not be accepted:
- [ik_llama.cpp](https://github.com/ikawrakow/ik_llama.cpp)
- [Iwan Kawrakow](https://github.com/ikawrakow)
Note that maintainers reserve the right to reject code on the basis of its origin, even if not explicitly listed above.
## Coding guidelines
- Avoid adding third-party dependencies, extra files, extra headers, etc.
- Always consider cross-compatibility with other operating systems and architectures
@ -97,7 +109,7 @@ Maintainers reserve the right to decline review or close pull requests for any r
![matmul](media/matmul.png)
# Naming guidelines
## Naming guidelines
- Use `snake_case` for function, variable and type names
- Naming usually optimizes for longest common prefix (see https://github.com/ggml-org/ggml/pull/302#discussion_r1243240963)
@ -156,7 +168,7 @@ Maintainers reserve the right to decline review or close pull requests for any r
- _(TODO: abbreviations usage)_
# Preprocessor directives
## Preprocessor directives
- _(TODO: add guidelines with examples and apply them to the codebase)_
@ -165,7 +177,7 @@ Maintainers reserve the right to decline review or close pull requests for any r
#endif // FOO
```
# Code maintenance
## Code maintenance
- Existing code should have designated collaborators and/or maintainers specified in the [CODEOWNERS](CODEOWNERS) file responsible for:
- Reviewing and merging related PRs
@ -182,13 +194,13 @@ Maintainers reserve the right to decline review or close pull requests for any r
- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md)
# Documentation
## Documentation
- Documentation is a community effort
- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference
- When you notice incorrect or outdated documentation, please update it
# Resources
## Further resources
The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects:

View File

@ -166,15 +166,16 @@ if (NOT MSVC)
option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON)
option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")

View File

@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
{
n_tasks = n_threads;
} break;
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
{
n_tasks = n_threads;
const int64_t n_heads = node->src[1]->ne[1];
n_tasks = MIN(n_threads, n_heads);
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:

View File

@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
}
#endif
#if defined(__riscv_zvfh)
template <>
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
#if defined(__riscv_v_intrinsic)
template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif
#if defined(__riscv_zvfh)
template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif
#if defined(__riscv_zvfbfwma)
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) {
return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -272,7 +277,7 @@ inline float hsum(__m512 x) {
}
#endif // __AVX512F__
#if defined(__riscv_zvfh)
#if defined(__riscv_v_intrinsic)
inline float hsum(vfloat32m1_t x) {
return __riscv_vfmv_f_s_f32m1_f32(
__riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
@ -379,19 +384,7 @@ template <> inline __m256bh load(const float *p) {
}
#endif
#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
#if defined(__riscv_v_intrinsic)
template <> inline vfloat32m1_t load(const float *p) {
return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
}
@ -406,6 +399,21 @@ template <> inline vfloat32m8_t load(const float *p) {
}
#endif
#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
#endif
#if defined(__riscv_zvfbfwma)
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) {
return __riscv_vle16_v_bf16m4(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m4());
}
#endif
#if defined(__riscv_zvfh)
#if defined(__riscv_v_intrinsic)
template <typename T> T set_zero();
template <> inline vfloat16mf2_t set_zero() {
return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t set_zero() {
return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t set_zero() {
return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t set_zero() {
return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
}
template <> inline vfloat32m1_t set_zero() {
return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
}
@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() {
#if defined(__riscv_v_intrinsic)
template <typename T> size_t vlmax() {
if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
#if defined (__riscv_zvfh)
else if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
#if defined (__riscv_zvfbfwma)
else if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
else if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
else if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
else if constexpr (std::is_same_v<T, vbfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
return 0;
}
#endif
@ -3740,7 +3747,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
params->ith, params->nth};
tb.matmul(m, n);
return true;
#elif defined(__riscv_zvfh)
#elif defined(__riscv_v_intrinsic)
#if LMUL == 1
tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
k, (const float *)A, lda,
@ -3804,23 +3811,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
return true;
}
#elif defined(__riscv_zvfbfwma)
#if LMUL == 1
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#elif LMUL == 2
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#else // LMUL = 4
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#endif
return tb.matmul(m, n);
if (Btype == GGML_TYPE_BF16) {
#if LMUL == 1
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#elif LMUL == 2
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#else // LMUL = 4
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#endif
return tb.matmul(m, n);
}
#endif
return false;
}

View File

@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
const int h_start = (HEADS * (ith )) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32(
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
const int h_start = (HEADS * (ith )) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
const int h_start = (HEADS * (ith )) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * r = (float *) dst->src[0]->data;
float * w = (float *) dst->src[1]->data;

View File

@ -126,7 +126,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
const int ggml_f16_epr = sve_register_length / 16; // running when 16
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
const int np = (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t sum_00 = svdup_n_f16(0.0f);
svfloat16_t sum_01 = svdup_n_f16(0.0f);
@ -224,71 +224,75 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
size_t vl = __riscv_vsetvlmax_e32m4();
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
size_t vl = __riscv_vsetvlmax_e32m4();
// initialize accumulators to all zeroes
vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
// initialize accumulators to all zeroes
vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
// calculate step size
const size_t epr = __riscv_vsetvlmax_e16m2();
const size_t step = epr * 2;
int np = (n & ~(step - 1));
// calculate step size
const size_t epr = __riscv_vsetvlmax_e16m2();
const size_t step = epr * 2;
const int np = (n & ~(step - 1));
// unroll by 2 along the row dimension
for (int i = 0; i < np; i += step) {
vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
// unroll by 2 along the row dimension
for (int i = 0; i < np; i += step) {
vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
}
vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
}
vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
// leftovers
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m2(n - i);
vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
// leftovers
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m2(n - i);
vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
}
vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
}
// reduce
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
__riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
__riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
__riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
__riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
// reduce
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
__riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
__riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
__riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
__riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
np = n;
#else
const int np = 0;
#endif
#else
const int np = (n & ~(GGML_F16_STEP - 1));
@ -313,21 +317,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
}
// leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
#else
for (int i = 0; i < n; ++i) {
// scalar path
const int np = 0;
#endif
// scalar and leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
s[i] = (float)sumf[i];
@ -532,40 +532,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic
#if defined (__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
#else
// fall to scalar path
const int np = 0;
#endif
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
@ -584,10 +589,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
}
}
#else
// scalar path
const int np = 0;
#endif
// leftovers
// scalar and leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
@ -785,7 +791,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
const int ggml_f16_step = 2 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np = (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ay1, ay2;
for (int i = 0; i < np; i += ggml_f16_step) {
@ -805,36 +811,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
const int np = (n & ~(step - 1));
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
#else
// fall to scalar path
const int np = 0;
#endif
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
@ -850,17 +863,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
#else
// scalar path
const int np = 0;
#endif
// scalar and leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
}
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }

View File

@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
}
static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
#ifdef FP8_AVAILABLE
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
#if defined(GGML_USE_HIP) && defined(CDNA3)
// ROCm dose not support fp8 in software on devices with fp8 hardware,
#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
// ROCm does not support fp8 in software on devices with fp8 hardware,
// but CDNA3 supports only e4m3_fnuz (no inf).
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
#else
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP)
return static_cast<float>(xf) / 2;
#else
NO_DEVICE_CODE;
#endif // FP8_AVAILABLE
#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
return static_cast<float>(xf) / 2;
#else
if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f
return 0.0f;
}
const int exp = (x >> 3) & 0xF;
const int man = x & 0x7;
float raw;
if (exp == 0) {
raw = ldexpf((float) man, -9);
} else {
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
}
return static_cast<float>(raw / 2);
#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {

View File

@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16(
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
NO_DEVICE_CODE;
return;
}
@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
// The number of viable configurations for Deepseek is very limited:
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);

View File

@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 512: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
} break;
case 576: {
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);

View File

@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
@ -767,7 +785,7 @@ static __global__ void flash_attn_tile(
#ifdef GGML_USE_WMMA_FATTN
(ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
(use_logit_softcap && !(DV == 128 || DV == 256))
(use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512))
) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
if constexpr (DV == 512) {
if constexpr (DKQ == 576) {
if (use_gqa_opt && gqa_ratio % 16 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
}
}
if constexpr (DV <= 256) {
if constexpr (DKQ <= 512) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
if constexpr (DV <= 256) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
GGML_ABORT("fatal error");
}
@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
extern DECL_FATTN_TILE_CASE(576, 512);

View File

@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 512:
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
break;
case 576: {
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
@ -336,7 +340,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
case 128:
case 112:
case 256:
if (V->ne[0] != K->ne[0]) {
case 512:
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
break;
@ -424,7 +429,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
@ -457,7 +462,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use MFMA flash attention for CDNA (MI100+):
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
// MMA vs tile crossover benchmarked on MI300X @ d32768:
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)

View File

@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
#ifdef FP8_AVAILABLE
case GGML_TYPE_NVFP4:
#endif // FP8_AVAILABLE
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:

View File

@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_NVFP4:
mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
}
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

View File

@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_MXFP4:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_NVFP4:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q2_K:
return MMQ_Q8_1_DS_LAYOUT_D2S6;
case GGML_TYPE_Q3_K:
@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
}
}
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kb0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
constexpr int rows_per_warp = warp_size / threads_per_row;
const int kbx = threadIdx.x % threads_per_row;
const int row_in_warp = threadIdx.x / threads_per_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if constexpr (need_check) {
i = min(i, i_max);
}
const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
const int kqs = 16 * kbx;
const int ksc = 4 * kbx;
#pragma unroll
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
#else
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
// Used for Q3_K, IQ2_S, and IQ2_XS
// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -3261,6 +3327,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);

View File

@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
// Host function: returns the max batch size for the current arch+type at runtime.
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
return get_mmvq_mmid_max_batch_pascal_older(type);
}
// AMD
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
}
}
return MMVQ_MAX_BATCH_SIZE;
}

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(512, 512);

View File

@ -3,7 +3,7 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
@ -35,7 +35,7 @@ TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4"
]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@ -83,6 +83,8 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 72:
continue
if head_size_kq == 512 and ncols2 not in (4, 8):
continue
if head_size_kq != 576 and ncols2 in (16, 32):
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32):

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_NVFP4);

View File

@ -23,6 +23,7 @@
#include "ggml-impl.h"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "type.hpp"
#include "sycl_hw.hpp"
namespace syclexp = sycl::ext::oneapi::experimental;
@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
return val;
}
static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
const uint32_t bits = x * (x != 0x7F && x != 0xFF);
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
return static_cast<float>(xf) / 2;
}
#endif // GGML_SYCL_COMMON_HPP

View File

@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
});
}
template <typename dst_t>
static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
GGML_ASSERT(k % QK_NVFP4 == 0);
const int nb = k / QK_NVFP4;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_nvfp4(vx, y, k);
});
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
return nullptr;
}
}
@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16
@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
return nullptr;
}
}

View File

@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr
}
}
template <typename dst_t>
static void dequantize_block_nvfp4(
const void * __restrict__ vx,
dst_t * __restrict__ yy,
const int64_t ne) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i = item_ct1.get_group(2);
const int tid = item_ct1.get_local_id(2);
const int64_t base = i * QK_NVFP4;
if (base >= ne) {
return;
}
const block_nvfp4 * x = (const block_nvfp4 *) vx;
const block_nvfp4 & xb = x[i];
const int sub = tid / (QK_NVFP4_SUB / 2);
const int j = tid % (QK_NVFP4_SUB / 2);
const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
}
#endif // GGML_SYCL_DEQUANTIZE_HPP

View File

@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
}
}
static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_NVFP4 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
case GGML_TYPE_MXFP4:
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_NVFP4:
mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
default:
GGML_ABORT("fatal error");
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
}
}
GGML_UNUSED(src1);

112
ggml/src/ggml-sycl/type.hpp Normal file
View File

@ -0,0 +1,112 @@
#pragma once
#include <sycl/sycl.hpp>
#include <cstdint>
#include <limits>
inline uint8_t float_to_e4m3(float f)
{
if (sycl::isnan(f)) {
return 0x7F; // Canonical NaN (positive)
}
uint32_t bits = sycl::bit_cast<uint32_t>(f);
uint32_t sign = (bits >> 31) & 0x1u;
uint32_t exp = (bits >> 23) & 0xFFu;
uint32_t mant = bits & 0x7FFFFFu;
// Zero
if (exp == 0 && mant == 0) {
return static_cast<uint8_t>(sign << 7);
}
// Extract biased exponent and mantissa for FP8
int e = static_cast<int>(exp) - 127; // true exponent (IEEE bias 127)
uint32_t m = mant;
// Handle very large values → NaN (NVIDIA behavior for E4M3)
if (e > 7) { // max exponent for E4M3 is 7 (biased 14)
return static_cast<uint8_t>((sign << 7) | 0x7F);
}
// Handle subnormals and normal numbers
if (e < -6) { // smallest normal exponent is -6
// Subnormal in FP8: shift mantissa right
int shift = -6 - e;
m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position
if (shift > 23) m = 0;
} else {
// Normal number: adjust exponent bias from 127 to 7
int new_exp = e + 7;
m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1)
m |= (static_cast<uint32_t>(new_exp) << 3);
}
// Round-to-nearest-even (simple guard + round bit)
// For better accuracy you can add sticky bit, but this is sufficient for most use cases
uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits
if (round_bit) {
m += 1;
// Carry into exponent if mantissa overflows
if ((m & 0x8u) != 0) {
m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling
// If exponent overflows after carry → NaN
if ((m >> 3) > 14) {
return static_cast<uint8_t>((sign << 7) | 0x7F);
}
}
}
uint8_t result = static_cast<uint8_t>((sign << 7) | (m & 0x7F));
return result;
}
inline float e4m3_to_float(uint8_t x)
{
if (x == 0) return 0.0f;
uint8_t sign = (x >> 7) & 0x1u;
uint8_t exp = (x >> 3) & 0xFu;
uint8_t mant = x & 0x7u;
// NaN (NVIDIA uses 0x7F / 0xFF as NaN)
if (exp == 0xF && mant != 0) {
return std::numeric_limits<float>::quiet_NaN();
}
if (exp == 0xF) { // 0x7F or 0xFF treated as NaN
return std::numeric_limits<float>::quiet_NaN();
}
float val;
if (exp == 0) {
// Subnormal
val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f);
} else {
// Normal: implicit leading 1 + bias 7
val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast<float>(exp) - 7.0f);
}
return sign ? -val : val;
}
// The actual type definition
struct __nv_fp8_e4m3 {
uint8_t raw;
__nv_fp8_e4m3() = default;
explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}
operator float() const { return e4m3_to_float(raw); }
operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }
// Allow direct access for vector loads/stores
operator uint8_t&() { return raw; }
operator uint8_t() const { return raw; }
};
using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>;
using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>;

View File

@ -15,6 +15,7 @@
#include "dpct/helper.hpp"
#include "ggml.h"
#include "type.hpp"
#include "quants.hpp"
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
return x32;
}
static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
int x32 = x16[2*i32 + 0] << 0;
x32 |= x16[2*i32 + 1] << 16;
return x32;
}
static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
const uint16_t* x16 =
@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
return d * sumi;
}
#define VDR_NVFP4_Q8_1_MMVQ 4
#define VDR_NVFP4_Q8_1_MMQ 8
static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
const block_q8_1 * __restrict__ bq8_1,
const int32_t & iqs) {
const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq;
float sum = 0.0f;
#pragma unroll
for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
const int32_t iqs0 = iqs + 2*i;
const int32_t iqs1 = iqs0 + 1;
const int32_t is = iqs0 >> 1;
const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
const block_q8_1 * bq8 = bq8_1 + (is >> 1);
const int32_t i8 = ((is & 1) << 2);
int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0);
sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi);
sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi);
sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi);
const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0];
sum += d * float(sumi);
}
return sum;
}
static __dpct_inline__ float
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,

View File

@ -1219,9 +1219,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
break;
}
}
@ -1334,9 +1333,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
// Use f16 inside the shader for quantized types
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
variant += std::string("_") + src0_name;
break;

View File

@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#define WEBGPU_NUM_PARAM_BUFS 96u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
@ -171,6 +171,7 @@ struct webgpu_buf_pool {
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
@ -507,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
while (blocking_wait) {
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0);
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6);
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
@ -728,7 +729,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
std::vector<webgpu_command> commands = { command };
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
@ -2694,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
total_offset + (size - remaining_size), remaining_size);
} else {
// wait for WriteBuffer to complete
buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
}
}),
UINT64_MAX);
}
WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
}

View File

@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
}
#endif
#ifdef U32_DEQUANT_HELPERS
fn load_src0_u16_at(byte_offset: u32) -> u32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_src0_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = src0[word_idx];
if (shift == 0u) {
return lo;
}
let hi = src0[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn load_src0_f16_at(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
return f16(packed[0]);
}
#endif
#ifdef Q4_0_T
struct q4_0 {
d: f16,

View File

@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
#define KV_TYPE u32
#else
#define KV_TYPE f16
#endif
@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix;
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define BLOCK_SIZE_BYTES 18u
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define BLOCK_SIZE_BYTES 34u
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#if defined(KV_Q4_0) || defined(KV_Q8_0)
fn load_k_u16_at(byte_offset: u32) -> u32 {
let word = K[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_k_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = K[word_idx];
if (shift == 0u) {
return lo;
}
let hi = K[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn load_v_u16_at(byte_offset: u32) -> u32 {
let word = V[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_v_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = V[word_idx];
if (shift == 0u) {
return lo;
}
let hi = V[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn f16_from_u16(bits: u32) -> f16 {
let packed = unpack2x16float(bits);
return f16(packed[0]);
}
#endif
struct Params {
offset_q: u32,
offset_k: u32,
@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;

View File

@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 20u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_0
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 22u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = load_src0_f16_at(block_byte_base);
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_1
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 24u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 34u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 36u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 42u;
const BLOCK_SIZE_BYTES = 84u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx + 40u];
let dmin = src0[scale_idx + 41u];
let d = load_src0_f16_at(block_byte_base + 80u);
let dmin = load_src0_f16_at(block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let is = k_in_block / 16u;
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 55u;
const BLOCK_SIZE_BYTES = 110u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx + 54u];
let d = load_src0_f16_at(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
let scale_0 = src0[scale_idx + 48u + (2u*i)];
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
let hmask_0 = src0[scale_idx + (2u*i)];
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
let qs_0 = src0[scale_idx + 16u + (2u*i)];
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 72u;
const BLOCK_SIZE_BYTES = 144u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// Map k_in_block to loop structure:
@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 88u;
const BLOCK_SIZE_BYTES = 176u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// The original loop processes elements in groups of 64
@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
const BLOCK_SIZE_BYTES = 210u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13_word = ql13_flat / 4u;
let ql13 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql13_word],
src0[scale_idx + 2u * ql13_word + 1u]
));
let ql13_b = get_byte(ql13, ql13_flat % 4u);
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24_word = ql24_flat / 4u;
let ql24 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql24_word],
src0[scale_idx + 2u * ql24_word + 1u]
));
let ql24_b = get_byte(ql24, ql24_flat % 4u);
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh_word = qh_flat / 4u;
let qh = bitcast<u32>(vec2(
src0[scale_idx + 64u + 2u * qh_word],
src0[scale_idx + 64u + 2u * qh_word + 1u]
));
let qh_b = get_byte(qh, qh_flat % 4u);
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc_word = sc_idx / 4u;
let sc = bitcast<u32>(vec2(
src0[scale_idx + 96u + 2u * sc_word],
src0[scale_idx + 96u + 2u * sc_word + 1u]
));
let sc_val = get_byte_i32(sc, sc_idx % 4u);
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = src0[scale_idx + 104u];
let d = load_src0_f16_at(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {

View File

@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q4_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 18u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q4_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 20u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 10u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = f32(src0[scale_idx + 1u]);
let d = f32(load_src0_f16_at(block_byte_base));
let m = f32(load_src0_f16_at(block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q5_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 22u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 11u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = f32(load_src0_f16_at(block_byte_base));
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q5_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 24u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 12u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q8_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 34u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 17u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q8_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 36u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 18u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
let aligned = byte_offset & ~3u;
let idx = bbase + aligned / 2u;
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
}
const BLOCK_SIZE_BYTES = 210u;
fn byte_of(v: u32, b: u32) -> u32 {
return (v >> (b * 8u)) & 0xFFu;
@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
let d_raw = load_u32_at(bbase, 208u);
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
let d = f32(load_src0_f16_at(bbase + 208u));
let ql1_u32 = load_u32_at(bbase, q_offset_l);
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);

View File

@ -380,22 +380,33 @@ extern "C" {
size_t n_samplers;
};
struct llama_model_tensor_override {
const char * pattern;
enum ggml_type type;
};
struct llama_model_imatrix_data {
const char * name;
const float * data;
size_t size;
};
// model quantization parameters
typedef struct llama_model_quantize_params {
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
bool dry_run; // calculate and show the final quantization size without performing quantization
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
bool dry_run; // calculate and show the final quantization size without performing quantization
const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data
const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides
const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides
const int32_t * prune_layers; // pointer to layer indices to prune
} llama_model_quantize_params;
typedef struct llama_logit_bias {

View File

@ -139,7 +139,11 @@ def main():
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
}
functions = parse_log_file(log_file)

View File

@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_base()->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
}
if (ubatch.n_tokens == 0) {

View File

@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
}
if (ubatch.n_tokens == 0) {

View File

@ -84,7 +84,6 @@ static std::string remap_imatrix(const std::string & orig_name, const std::map<i
for (const auto & p : mapped) {
if (p.second == blk) {
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
}
}
@ -188,10 +187,9 @@ struct quantize_state_impl {
model(model), params(params)
{
// compile regex patterns once - they are expensive
if (params->tensor_types) {
const auto & tensor_types = *static_cast<const std::vector<tensor_type_option> *>(params->tensor_types);
for (const auto & [tname, qtype] : tensor_types) {
tensor_type_patterns.emplace_back(std::regex(tname), qtype);
if (params->tt_overrides) {
for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) {
tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type);
}
}
}
@ -857,12 +855,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
constexpr bool use_mmap = false;
#endif
llama_model_kv_override * kv_overrides = nullptr;
if (params->kv_overrides) {
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data();
}
const llama_model_kv_override * kv_overrides = params->kv_overrides;
std::vector<std::string> splits = {};
llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr,
fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
@ -879,9 +872,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (params->only_copy) {
ftype = ml.ftype;
}
std::unordered_map<std::string, std::vector<float>> i_data;
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
if (params->imatrix) {
imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) {
i_data.emplace(p->name, std::vector<float>(p->data, p->data + p->size));
}
imatrix_data = & i_data;
if (imatrix_data) {
LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n",
__func__, (int)imatrix_data->size());
@ -902,7 +899,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
std::vector<int> prune_list = {};
if (params->prune_layers) {
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
for (const int32_t * p = params->prune_layers; * p != -1; p++) {
prune_list.push_back(* p);
}
}
// copy the KV pairs from the input file
@ -916,20 +915,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
if (params->kv_overrides) {
const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
for (const auto & o : overrides) {
if (o.key[0] == 0) break;
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) {
if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64);
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
// Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64));
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64));
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool);
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
gguf_set_val_str(ctx_out.get(), o->key, o->val_str);
} else {
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key);
}
}
}

View File

@ -13,13 +13,10 @@
#include <unordered_map>
#include <map>
#include <fstream>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <filesystem>
// result of parsing --tensor-type option
// (changes to this struct must be reflected in src/llama-quant.cpp)
// changes to this struct must also be reflected in src/llama-quant.cpp
struct tensor_type_option {
std::string name;
ggml_type type = GGML_TYPE_COUNT;
@ -491,7 +488,6 @@ static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
if (argc < 3) {
usage(argv[0]);
}
@ -584,8 +580,16 @@ int main(int argc, char ** argv) {
std::vector<std::string> imatrix_datasets;
std::unordered_map<std::string, std::vector<float>> imatrix_data;
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
std::vector<llama_model_imatrix_data> i_data;
std::vector<llama_model_tensor_override> t_override;
if (!imatrix_data.empty()) {
params.imatrix = &imatrix_data;
i_data.reserve(imatrix_data.size() + 1);
for (const auto & kv : imatrix_data) {
i_data.push_back({kv.first.c_str(), kv.second.data(), kv.second.size()});
}
i_data.push_back({nullptr, nullptr, 0}); // array terminator
params.imatrix = i_data.data();
{
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
@ -603,7 +607,6 @@ int main(int argc, char ** argv) {
kvo.val_str[127] = '\0';
kv_overrides.emplace_back(std::move(kvo));
}
{
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
@ -611,7 +614,6 @@ int main(int argc, char ** argv) {
kvo.val_i64 = imatrix_data.size();
kv_overrides.emplace_back(std::move(kvo));
}
if (m_last_call > 0) {
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS);
@ -623,13 +625,19 @@ int main(int argc, char ** argv) {
if (!kv_overrides.empty()) {
kv_overrides.emplace_back();
kv_overrides.back().key[0] = 0;
params.kv_overrides = &kv_overrides;
params.kv_overrides = kv_overrides.data();
}
if (!tensor_type_opts.empty()) {
params.tensor_types = &tensor_type_opts;
t_override.reserve(tensor_type_opts.size() + 1);
for (const auto & tt : tensor_type_opts) {
t_override.push_back({tt.name.c_str(), tt.type});
}
t_override.push_back({nullptr, GGML_TYPE_COUNT}); // array terminator
params.tt_overrides = t_override.data();
}
if (!prune_layers.empty()) {
params.prune_layers = &prune_layers;
prune_layers.push_back(-1); // array terminator
params.prune_layers = prune_layers.data();
}
llama_backend_init();