feat: perf opt add set rows (#59)

* Add power management utilities to NPU device context and update DCVS settings

* Update DCVS settings in power_utils to use v3 API and enhance power management

* wip

* Enhance dequantization functions by adding load_dequant_table support and updating signatures for improved performance

* use lut

* wip

* fix test failure

* wip

* Refactor load_qual_block_generic to improve block handling and optimize vector operations

* Enhance load_dual_block_generic and load_qual_block_generic to accept a mask parameter for improved block handling

* Refactor flash_attn_impl to optimize mask l2 prefetch

* wip

* wip

* wip

* wip

* add log

* link against shared libraries instead of static ones

* fix swiglu

* wip

* refactor expf_fix to handle overflow for different data types

* enhance is_glu_op_supported to validate shapes for multiple sources

* wip

* refactor logging macros to use hexagon namespace and improve formatting

* fix printf format error

* wip

* refactor: update static_assert messages for block size validation and add HVX_VectorPred_x3 type alias

* rename

* feat: enhance fa with mask

* wip

* wip

* refactor: replace instances of Q6_V_vzero() with kZeroV for consistency

* wip

* wip

* wip

* fix: improve address alignment check in HVX_Vector handling

* refactor: streamline vector dot product implementations for improved readability

* refactor: q4k add hvx intrinsic impl

* refactor: enhance dequantize_row_q4_K for clarity and performance

* refactor: optimize scale mask usage in dequantization functions for improved performance

* refactor: optimize dequantize_row_q4_K for intrinsic usage and performance improvements

* refactor: move GLU operation implementation into separated file

* sync after swiglu

* wip

* wip

* wip

* feat: increase prc main thread stack size

* fix: replace hardcoded stack size with NPU_THREAD_STACK_SIZE constant

* wip

* feat: add optimized vector operations for exponential and division with overflow handling

* wip

* feat: refactor exponential function to handle overflow and underflow with improved logic

* wip

* wip

* feat: add vector loading and scaling functions for improved performance in block processing

* wip

* feat: optimize block loading by refactoring scale index handling for improved performance

* use Q6_Vb_vlut32_VbVbR_nomatch instead

* feat: enhance scale loading by adding static assertion and restructuring block handling

* wip

* feat: refactor vec_dot_product_mixed_impl for improved clarity and performance

* wip

* feat: simplify vector loading functions and improve alignment handling

* wip

* feat: enhance scale loading mask with quantization block size validation

* wip

* feat: implement make_scale_load_mask function and refactor vector handling in vec_ops

* feat: enhance load_dual_block_generic to include scale indices for improved vector loading

* revert q8 dequant

* wip

* feat: optimize dequantization functions by removing unnecessary masking and updating lookup methods

* wip

* wip

* add qurt_mutex

* Add DMA transfer class and integrate into thread pool

* Enhance DMA transfer functionality by adding support for multiple descriptors and initiating transfers in parallel

* fix dma crash

* fix failed unit tests

* wip

* use alignas

* Improve DMA transfer error handling and update descriptor completion check

* Fix VTCM cache size calculation in element-wise operations

* Add cache clean operations before DMA transfers in element-wise operations

* reduce cache clean operations

* Refactor DMA transfer functions to support 1D operations and rename for clarity

* Enhance DMA transfer functionality by adding 2D submission support and improving descriptor initialization

* Update read buffer method to support forced invalidation and remove unnecessary invalidation calls in element-wise operations

* wip

* Improve DMA transfer handling in mul_mat_gemv_impl by replacing memcpy with initiate_dma_row_transfer and adding wait_for_dma logic

* fix 2d dma

* feat: add DMA plane cache

* rename

* wip

* use memcpy for debug

* fix cache plane calc

* refactor: remove debug logging from mul_mat_impl and optimize cache handling

* rename

* fix 2d dma type

* refactor: enhance DMA transfer handling in mul_mat_gemv_impl and wait functions

* refactor: optimize DMA transfer handling in mul_mat_gemv_impl and wait functions

* wip

* wip

* move op impl into sub dir

* add log

* fix: correct pointer usage in mul_mat_gemv_impl for next plane access

* fix: improve DMA transfer error handling in mul_mat_impl and mul_mat_gemv_impl

* fix: fix crash by using the entire row bytes

* wip

* wip

* fix: prevent parallelization for scalar src1 in is_mul_mat_supported

* fix: add dimension checks for 2D DMA transfers and fallback to 1D if necessary

* wip

* fix: enable thread barrier for matrix multiplication operations

* feat: add synchronization checks for tensor operations and update related functions

* wip

* fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations

* Revert "fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations"

This reverts commit af3441e67e706b2e5122369dc160353796867dd3.

* wip

* wip

* add comment

* fix: improve DMA transfer handling in mul_mat_gemv_impl for quantized source tensors

* add log

* try fix mulmat gemv

* wip

* fix: enhance DMA transfer handling in mul_mat_gemv_impl for quantized source tensors

* fix: optimize cache offset calculation and remove redundant swap in mul_mat_gemv_impl

* fix: refactor DMA transfer handling in mul_mat_gemv_impl for improved clarity and maintainability

* wip

* wip

* wip

* fix: enhance mul_mat_impl for improved cache handling and clarity

* fix: refactor tensor unflattening and DMA transfer initialization for improved clarity and type safety

* fix: improve cache handling of quant

* wip

* fix: improve cache handling in mul_mat_impl and mul_mat_gemv_impl for better memory efficiency

* rename

* add load_hexa_block_generic

* wip

* extract dequant block into separated function

* refactor: enhance dequantization functions with table parameter

* fix load_dual_block_generic

* refactor: rename dequantization functions for clarity and enhance block handling

* refactor: simplify dequantization logic by consolidating block handling and removing unused parameters

* wip

* wip

* feat: add make_qs_load_mask function and update load_dual_block_generic to use qs_indices

* fix load_dual_block_generic

* refactor: update load functions to use qs_indices for improved block loading

* wip

* fix: update loop indices and boundary checks to use size_t for better efficiency

* wip

* update make_scale_load_mask to make it available for q8

* feat: add vec_dot_product_quant_impl for quantized dot product computation

* refactoring: move some quant funcs to a dedicated file

* refactor: rename dequantization functions for clarity and consistency

* wip

* feat: enhance vec_dot_product_quant_impl with dual dequantization and improved assertions

* add vec_dot_product_vqf32_q40_f32

* wip

* wip

* wip

* wip

* implement vec_mpy_qf32_qf32_qf32 function and update vec_dot_product_vqf32_q40_f32 to use it

* wip

* add src0_plane_write_cache_offset

* wip

* enhance mul_mat_f32 to handle NPU_DATA_TYPE_Q4_0 for quantized matrix multiplication

* wip

* wip

* update test func

* refactor mul_mat_gemv_quant_impl to use get_nb for row stride and remove unused test function in init_f16_f32_table

* wip

* Add support for 4-block dequantization in vec_quant and update dot product implementation

* Refactor vec_dot_product_quant_impl to improve variable handling and enhance readability

* Refactor vec_dot_product_quant_impl to replace the template function with inline vector operations

* use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32

* Revert "use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32"

This reverts commit 54839166fddbe40a0392adee5863c59070ccdbe4.

* wip

* improve log print in graph

* Refactor batched_row_dot to accept additional arguments and remove batched_row_dot_with_table

* Refactor synchronization functions to include previous operation and NE type parameters

* Refactor synchronization checks in several operations

* Update synchronization checks to include NPU_OP_COUNT in required conditions

* Add performance tracking to buffer management functions

* add memset

* add log

* fix: update backend device type from ACCEL to IGPU

* fix comment

* add get/set rows

* feat: implement row operation support checks in is_rows_supported

* feat: add support for I64 data type in rows operations

* feat: implement set_rows functionality for I32 and I64 data types

* wip

* fix set_rows

* feat: extend is_rows_supported to allow F32 data type in destination

* wip

* feat: rename set_rows to set_rows_generic

* disable q4_k

* move ops to separated file

* rename: op_impl -> op_registry

* refactor: update get_data_type struct to include output type for unary operations

* refactor: simplify vec_trans_impl by removing parameterized overload and using variadic templates

* add vec_trans_with_half_ret_impl

* add NPU_OP_CPY

* refactor: enhance is_unary_op_supported to handle non-continuous rows and add type support logging

* refactor: update vec_trans_with_half_ret_impl to use processed_bytes for clarity and accuracy

* wip

* refactor: optimize dequantize_vec_q40_qf32_4blocks by improving shuffling logic and reducing redundancy

* refactor: improve performance of vec_dot_product and dequantize functions by optimizing shuffling logic

* wip

* add dequantize_vec_q40_qf32_6blocks

* feat: add load_dequant_vec_q40_qf32_6blocks function for 6-block dequantization

* feat: enhance vec_dot_product_quant_impl with 6-element processing loop for improved performance

* Revert "feat: enhance vec_dot_product_quant_impl with 6-element processing loop for improved performance"

This reverts commit a5c8fa3e4d9a2d89c8c0821c936c0466e0af7869.

since there's a performance degradation

* fix: correct load_hexa_block_generic return type and update dequantization logic

* wip

* wip

* feat: add make_q40_qs_load_mask function and update vec_dot_product_vqf32_q40_f32

* fix dequant load

* add debug log

* wip

* wip

* fix shuffle index array

* refactor: simplify load mask generation and improve index shuffling for q4 blocks

* wip

* wip

* fix comment

* wip

* update ops.md

* update ops.md by create_ops_docs.py

# Conflicts:
#	docs/ops.md

Commit e6a5f7baa6 (parent 38ae191c55), authored by nullname on 2025-10-30 21:51:15 +08:00 and committed via GitHub.
19 changed files with 18427 additions and 501 deletions


@ -12,105 +12,105 @@ Legend:
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | qualcomm | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ |
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |

docs/ops/hexagon-npu.csv (new file, 17663 lines added): diff suppressed because it is too large


@ -1,7 +1,7 @@
#include "graph.hpp"
#include "hexagon_npu.h"
#include "op_impl.hpp"
#include "op_registry.hpp"
#include "remote.h"
#include "tensor.hpp"
#include "thread_pool.hpp"


@ -1,7 +1,7 @@
#include "graph.hpp"
#include "op_impl.hpp"
#include "op_registry.hpp"
#include "util.hpp"
#include "vtcm_mem.hpp"


@ -1,18 +1,10 @@
#pragma once
#include "op_impl.hpp"
#include "op_flash_attn.hpp"
#include "op_glu.hpp"
#include "op_mul_mat.hpp"
#include "op_rope.hpp"
#include "op_types.hpp"
#include "type_traits.hpp"
#include "vec_ops.hpp"
#include <cmath>
#include <type_traits>
namespace {
namespace hexagon {
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
inline void vec_op_f32_f32(const float * src0, const float * src1, float * dst, size_t count) {
@ -41,6 +33,14 @@ inline void vec_op_f16_f16(const npu_device_fp16_t * src0,
vec_trans_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, dst, count);
}
template <HVX_Vector (*_OpUnaryTransform)(HVX_VectorPair)>
inline void unary_vec_op_f16_f32(const float * src, npu_device_fp16_t * dst, size_t count, size_t) {
// TODO: remove the unused param
using namespace hexagon::vec;
vec_trans_with_half_ret_impl<_OpUnaryTransform, float, npu_device_fp16_t>(src, dst, count);
}
inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
// TODO: fix this since qf16 has less precision than fp16
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_VhfVhf(a, b));
@ -55,16 +55,25 @@ inline HVX_Vector vmul_f16_f16(HVX_Vector a, HVX_Vector b) {
return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
}
inline HVX_Vector vequals_f16_f32(HVX_VectorPair a) {
const HVX_Vector kZeroV = Q6_V_vzero();
HVX_Vector lo = Q6_Vqf32_vadd_Vqf32Vsf(kZeroV, Q6_V_lo_W(a));
HVX_Vector hi = Q6_Vqf32_vadd_Vqf32Vsf(kZeroV, Q6_V_hi_W(a));
a = Q6_W_vcombine_VV(hi, lo);
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(a));
}
template <typename T> struct get_data_type {};
template <typename _TyData> struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t)> {
using type = _TyData;
};
template <typename _TyData, typename _TyParam>
struct get_data_type<void (*)(const _TyData *, _TyData *, size_t, _TyParam)> {
using type = _TyData;
using param_type = typename std::remove_cv<typename std::remove_reference<_TyParam>::type>::type;
template <typename _TyInput, typename _TyOutput, typename _TyParam>
struct get_data_type<void (*)(const _TyInput *, _TyOutput *, size_t, _TyParam)> {
using type = _TyInput;
using output_type = _TyOutput;
using param_type = typename std::remove_cv<typename std::remove_reference<_TyParam>::type>::type;
};
template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::compute_params * params) {
@ -280,8 +289,9 @@ void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
// TODO: merge with element_wise_op?
template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_params * params) {
using data_type = typename get_data_type<decltype(_RowFunc)>::type;
using param_type = typename get_data_type<decltype(_RowFunc)>::param_type;
using input_type = typename get_data_type<decltype(_RowFunc)>::type;
using output_type = typename get_data_type<decltype(_RowFunc)>::output_type;
using param_type = typename get_data_type<decltype(_RowFunc)>::param_type;
if (!out) {
return false;
@ -311,7 +321,7 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
const auto param = out->get_op_param<param_type>(0);
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(input_type);
for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
const auto i03 = ir / rows_per_cube;
const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
@ -323,7 +333,7 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
}
_RowFunc(reinterpret_cast<const data_type *>(src0_row), reinterpret_cast<data_type *>(dst_row),
_RowFunc(reinterpret_cast<const input_type *>(src0_row), reinterpret_cast<output_type *>(dst_row),
static_cast<size_t>(out->get_ne(0)), param);
}
@ -336,7 +346,7 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_RMS_NORM) {
if (op != NPU_OP_RMS_NORM && op != NPU_OP_CPY) {
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
return false;
}
@ -347,23 +357,38 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
}
const auto & src0 = srcs[0];
if (dst->type != src0.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
return false;
}
if (dst->type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
return false;
}
if (!hexagon::is_same_shape(src0, *dst)) {
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
return false;
}
if (op == NPU_OP_RMS_NORM) {
if (dst->type != src0.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
return false;
}
if (dst->type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
return false;
}
} else {
if (dst->nb[1] < dst->nb[0] || src0.nb[1] < src0.nb[0]) {
// TODO: support non-continuous row
DEVICE_LOG_DEBUG("[%s]unsupported non-continuous row\n", hexagon::op_get_name(op));
return false;
}
if (dst->type != NPU_DATA_TYPE_F16 || src0.type != NPU_DATA_TYPE_F32) {
// TODO: support more types
DEVICE_LOG_DEBUG("[%s]unsupported data type src:%s dst:%s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
return false;
}
}
return true;
}
@ -378,132 +403,4 @@ bool is_unary_op_required_sync(npu_device_tensor_op prev_op,
prev_op != NPU_OP_COUNT;
}
struct op_capabilities {
npu_device_tensor_op op;
hexagon::op_is_supported_func_type is_supported;
hexagon::op_required_sync_func_type requires_thread_barrier_func;
hexagon::compute_func_type compute_funcs[NPU_DATA_TYPE_COUNT];
};
constexpr const op_capabilities kOpCapabilities[] = {
{
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
hexagon::is_mul_mat_required_sync,
{
hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_ADD, is_element_wise_op_supported,
is_element_wise_op_required_sync, {
element_wise_op<vec_op_f32_f32<vadd_f32_f32>>, // NPU_DATA_TYPE_F32
element_wise_op<vec_op_f16_f16<vadd_f16_f16>>, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_SUB, is_element_wise_op_supported,
is_element_wise_op_required_sync, {
element_wise_op<vec_op_f32_f32<vsub_f32_f32>>, // NPU_DATA_TYPE_F32
element_wise_op<vec_op_f16_f16<vsub_f16_f16>>, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_MUL, is_element_wise_op_supported,
is_element_wise_op_required_sync, {
element_wise_op<vec_op_f32_f32<vmul_f32_f32>>, // NPU_DATA_TYPE_F32
element_wise_op<vec_op_f16_f16<vmul_f16_f16>>, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_RMS_NORM, is_unary_op_supported,
is_unary_op_required_sync, {
unary_op<rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
hexagon::is_flash_attn_required_sync,
{
hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_ROPE, hexagon::is_rope_supported,
hexagon::is_rope_required_sync,
{
hexagon::rope_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_GLU, hexagon::is_glu_op_supported,
hexagon::is_glu_required_sync,
{
hexagon::glu_f32, // NPU_DATA_TYPE_F32
hexagon::glu_f16, // NPU_DATA_TYPE_F16
}, },
};
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
"kOpArray[NPU_OP_MUL_MAT] != mul_mat_f32");
static_assert(std::size(kOpCapabilities) == NPU_OP_COUNT);
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].op == NPU_OP_MUL_MAT, "kOpArray[NPU_OP_MUL_MAT].op != NPU_OP_MUL_MAT");
static_assert(kOpCapabilities[NPU_OP_MUL].op == NPU_OP_MUL, "kOpArray[NPU_OP_MUL].op != NPU_OP_MUL");
static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
"kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU");
hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
if (op >= NPU_OP_COUNT) {
return nullptr;
}
return kOpCapabilities[op].compute_funcs[type];
}
} // namespace
namespace hexagon {
compute_func_type get_compute_func(tensor * dst) {
return get_compute_func_impl(dst->get_op(), dst->get_type());
}
bool requires_thread_barrier(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
if (op >= NPU_OP_COUNT) {
return false;
}
auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func;
return requires_thread_barrier_func && requires_thread_barrier_func(prev_op, prev_ne, op, ne);
}
bool support_op(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
if (!op_spec) {
DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n");
return false;
}
const auto op = op_spec->op;
auto is_supported_func = kOpCapabilities[op].is_supported;
if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) {
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
return false;
}
if (get_compute_func_impl(op, dst->type) == nullptr) {
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
get_type_name(dst->type));
return false;
}
return true;
}
} // namespace hexagon
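
For context on the trait change above, a minimal standalone sketch of the function-pointer type-trait technique, assuming only standard C++17; the names below are illustrative, not the project's:

    #include <cstddef>
    #include <type_traits>

    template <typename T> struct row_fn_traits {};

    // Specialization for unary row kernels of the form void(const In *, Out *, size_t, Param),
    // mirroring how get_data_type now exposes separate input/output/param types.
    template <typename In, typename Out, typename Param>
    struct row_fn_traits<void (*)(const In *, Out *, std::size_t, Param)> {
        using input_type  = In;
        using output_type = Out;
        using param_type  = std::remove_cv_t<std::remove_reference_t<Param>>;
    };

    // Illustrative kernel: converts float rows to 16-bit values with a scale parameter.
    inline void scale_row_f32_i16(const float * src, short * dst, std::size_t n, float scale) {
        for (std::size_t i = 0; i < n; ++i) {
            dst[i] = static_cast<short>(src[i] * scale);
        }
    }

    // A wrapper instantiated on the kernel can recover all three types from its signature.
    template <auto RowFunc> void check_types() {
        using traits = row_fn_traits<decltype(RowFunc)>;
        static_assert(std::is_same_v<typename traits::input_type, float>);
        static_assert(std::is_same_v<typename traits::output_type, short>);
        static_assert(std::is_same_v<typename traits::param_type, float>);
    }

    template void check_types<scale_row_f32_i16>();  // forces the checks at compile time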


@ -58,10 +58,10 @@ void flash_attn_impl(hexagon::tensor * out,
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const auto & k_type_traits = hexagon::get_type_traits(kKvDataType);
const auto q_to_vec_dot = k_type_traits.from_float;
const auto q_to_kv_type = k_type_traits.from_float;
constexpr const auto kq_vec_dot = _IsKvF16 ? hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16> :
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>;
if (!q_to_vec_dot) {
if (!q_to_kv_type) {
DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
return;
}
@ -134,7 +134,7 @@ void flash_attn_impl(hexagon::tensor * out,
(iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
nullptr;
q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);
q_to_kv_type(reinterpret_cast<const float *>(q_data), Q_q, DK);
if (kHasMask) {
hexagon::l2fetch_row(reinterpret_cast<const uint8_t *>(mp), mask->get_nb(1));


@ -48,8 +48,7 @@ inline void glu_vec_op_f32_f32(const float * src0,
size_t count,
hexagon::HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec;
vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(src0, src1, dst, count,
coeff);
vec_trans_impl<hexagon::vec_swiglu_f32_f32, float, hexagon::HVX_VectorPair_x4>(src0, src1, dst, count, coeff);
}
template <auto _GluRowFunc, auto _CoeffLoadFunc>


@ -0,0 +1,178 @@
#include "op_registry.hpp"
#include "op_eltwise.hpp"
#include "op_flash_attn.hpp"
#include "op_glu.hpp"
#include "op_mul_mat.hpp"
#include "op_rope.hpp"
#include "op_rows.hpp"
#include <cmath>
#include <cstddef>
#include <type_traits>
namespace {
struct op_capabilities {
npu_device_tensor_op op;
hexagon::op_is_supported_func_type is_supported;
hexagon::op_required_sync_func_type requires_thread_barrier_func;
hexagon::compute_func_type compute_funcs[NPU_DATA_TYPE_COUNT];
};
constexpr const op_capabilities kOpCapabilities[] = {
{
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
hexagon::is_mul_mat_required_sync,
{
hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_ADD, hexagon::is_element_wise_op_supported,
hexagon::is_element_wise_op_required_sync,
{
hexagon::element_wise_op<hexagon::vec_op_f32_f32<hexagon::vadd_f32_f32>>, // NPU_DATA_TYPE_F32
hexagon::element_wise_op<hexagon::vec_op_f16_f16<hexagon::vadd_f16_f16>>, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_SUB, hexagon::is_element_wise_op_supported,
hexagon::is_element_wise_op_required_sync,
{
hexagon::element_wise_op<hexagon::vec_op_f32_f32<hexagon::vsub_f32_f32>>, // NPU_DATA_TYPE_F32
hexagon::element_wise_op<hexagon::vec_op_f16_f16<hexagon::vsub_f16_f16>>, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_MUL, hexagon::is_element_wise_op_supported,
hexagon::is_element_wise_op_required_sync,
{
hexagon::element_wise_op<hexagon::vec_op_f32_f32<hexagon::vmul_f32_f32>>, // NPU_DATA_TYPE_F32
hexagon::element_wise_op<hexagon::vec_op_f16_f16<hexagon::vmul_f16_f16>>, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_RMS_NORM, hexagon::is_unary_op_supported,
hexagon::is_unary_op_required_sync,
{
hexagon::unary_op<hexagon::rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
hexagon::is_flash_attn_required_sync,
{
hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_ROPE, hexagon::is_rope_supported,
hexagon::is_rope_required_sync,
{
hexagon::rope_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_GLU, hexagon::is_glu_op_supported,
hexagon::is_glu_required_sync,
{
hexagon::glu_f32, // NPU_DATA_TYPE_F32
hexagon::glu_f16, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_GET_ROWS, hexagon::is_rows_supported,
hexagon::is_rows_required_sync,
{
hexagon::get_rows_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, },
{
NPU_OP_SET_ROWS, hexagon::is_rows_supported,
hexagon::is_rows_required_sync,
{
hexagon::set_rows_generic, // NPU_DATA_TYPE_F32
hexagon::set_rows_generic, // NPU_DATA_TYPE_F16
nullptr, // NPU_DATA_TYPE_I32
nullptr, // NPU_DATA_TYPE_I64
hexagon::set_rows_generic, // NPU_DATA_TYPE_Q8_0
hexagon::set_rows_generic, // NPU_DATA_TYPE_Q4_0
nullptr, // TODO: figure out why failed on NPU_DATA_TYPE_Q4_K
}, },
{
NPU_OP_CPY, hexagon::is_unary_op_supported,
hexagon::is_unary_op_required_sync,
{
nullptr, // NPU_DATA_TYPE_F32
hexagon::unary_op<hexagon::unary_vec_op_f16_f32<hexagon::vequals_f16_f32>>, // NPU_DATA_TYPE_F16
}, },
};
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
"kOpArray[NPU_OP_MUL_MAT] != mul_mat_f32");
static_assert(std::size(kOpCapabilities) == NPU_OP_COUNT);
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].op == NPU_OP_MUL_MAT, "kOpArray[NPU_OP_MUL_MAT].op != NPU_OP_MUL_MAT");
static_assert(kOpCapabilities[NPU_OP_MUL].op == NPU_OP_MUL, "kOpArray[NPU_OP_MUL].op != NPU_OP_MUL");
static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
"kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU");
static_assert(kOpCapabilities[NPU_OP_GET_ROWS].op == NPU_OP_GET_ROWS,
"kOpArray[NPU_OP_GET_ROWS].op != NPU_OP_GET_ROWS");
static_assert(kOpCapabilities[NPU_OP_SET_ROWS].op == NPU_OP_SET_ROWS,
"kOpArray[NPU_OP_SET_ROWS].op != NPU_OP_SET_ROWS");
hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
if (op >= NPU_OP_COUNT) {
return nullptr;
}
return kOpCapabilities[op].compute_funcs[type];
}
} // namespace
namespace hexagon {
compute_func_type get_compute_func(tensor * dst) {
return get_compute_func_impl(dst->get_op(), dst->get_type());
}
bool requires_thread_barrier(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
if (op >= NPU_OP_COUNT) {
return false;
}
auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func;
return requires_thread_barrier_func && requires_thread_barrier_func(prev_op, prev_ne, op, ne);
}
bool support_op(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
if (!op_spec) {
DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n");
return false;
}
const auto op = op_spec->op;
auto is_supported_func = kOpCapabilities[op].is_supported;
if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) {
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
return false;
}
if (get_compute_func_impl(op, dst->type) == nullptr) {
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
get_type_name(dst->type));
return false;
}
return true;
}
} // namespace hexagon
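
The new op_registry.cpp centralizes a pattern worth noting: a constexpr capability table indexed by the op enum, with static_asserts keeping table and enum in sync. A reduced sketch of the same pattern, using illustrative enum values and functions rather than the project's:

    #include <cstddef>
    #include <iterator>

    enum demo_op { DEMO_OP_ADD, DEMO_OP_MUL, DEMO_OP_COUNT };
    enum demo_dtype { DEMO_TYPE_F32, DEMO_TYPE_F16, DEMO_TYPE_COUNT };

    using compute_fn = bool (*)(float * dst, const float * src, std::size_t n);

    inline bool add_rows_f32(float *, const float *, std::size_t) { return true; }
    inline bool mul_rows_f32(float *, const float *, std::size_t) { return true; }

    struct op_caps {
        demo_op    op;
        compute_fn funcs[DEMO_TYPE_COUNT];  // one slot per data type; nullptr means unsupported
    };

    constexpr op_caps kCaps[] = {
        { DEMO_OP_ADD, { add_rows_f32, nullptr } },
        { DEMO_OP_MUL, { mul_rows_f32, nullptr } },
    };

    // Keep the table and the enum in lock-step: same length, same ordering.
    static_assert(std::size(kCaps) == DEMO_OP_COUNT, "capability table size mismatch");
    static_assert(kCaps[DEMO_OP_MUL].op == DEMO_OP_MUL, "capability table order mismatch");

    inline compute_fn get_compute_fn(demo_op op, demo_dtype type) {
        return op < DEMO_OP_COUNT ? kCaps[op].funcs[type] : nullptr;
    }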


@ -0,0 +1,155 @@
#include "op_rows.hpp"
#include "type_traits.hpp"
namespace {
template <typename idx_t> void set_rows_impl(hexagon::tensor * out, hexagon::compute_params * params) {
auto * src0 = out->get_src(0);
auto * src1 = out->get_src(1);
const auto total_rows = src0->get_ne(3) * src0->get_ne(2) * src0->get_ne(1);
const auto start_end = params->get_work_slice(total_rows);
if (start_end.first >= start_end.second) {
return;
}
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("set_rows_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
hexagon::get_type_name(out->get_type()));
return;
}
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
const uint8_t * src1_ptr = src1->get_read_buffer(true); // TODO: avoid invalidation
const size_t rows_per_cube = src0->get_ne(2) * src0->get_ne(1);
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
auto from_float = hexagon::get_type_traits(out->get_type()).from_float;
for (size_t ir = start_end.first; ir < size_t(start_end.second); ++ir) {
const size_t i03 = ir / rows_per_cube;
const size_t i02 = ir / src0->get_ne(1) - i03 * src0->get_ne(2);
const size_t i01 = ir % src0->get_ne(1);
const size_t i12 = i03 % src1->get_ne(2);
const size_t i11 = i02 % src1->get_ne(1);
const size_t i10 = i01;
const size_t i1 = *reinterpret_cast<const idx_t *>(src1_ptr + i10 * src1->get_nb(0) + i11 * src1->get_nb(1) +
i12 * src1->get_nb(2));
from_float(reinterpret_cast<const float *>(src0_ptr + i01 * src0->get_nb(1) + i02 * src0->get_nb(2) +
i03 * src0->get_nb(3)),
dst_ptr + i1 * out->get_nb(1) + i02 * out->get_nb(2) + i03 * out->get_nb(3),
size_t(src0->get_ne(0)));
}
out->release_write_buffer(); // mark the output tensor as modified
}
} // namespace
namespace hexagon {
bool get_rows_f32(tensor * out, compute_params * params) {
// TODO: implement get_rows
return false;
}
bool set_rows_generic(tensor * out, compute_params * params) {
if (!out) {
return false;
}
auto * src0 = out->get_src(0);
auto * src1 = out->get_src(1);
if (!src0 || !src1) {
DEVICE_LOG_ERROR("set_rows_generic: missing src0 or src1\n");
return false;
}
switch (src1->get_type()) {
case NPU_DATA_TYPE_I32:
set_rows_impl<int32_t>(out, params);
break;
case NPU_DATA_TYPE_I64:
set_rows_impl<int64_t>(out, params);
break;
default:
DEVICE_LOG_ERROR("set_rows_generic: unsupported src1 type: %s\n", hexagon::get_type_name(src1->get_type()));
return false;
}
return true;
}
bool is_rows_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_GET_ROWS && op != NPU_OP_SET_ROWS) {
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
return false;
}
if (src_len < 2) {
DEVICE_LOG_DEBUG("[%s]invalid src_len: %zu\n", hexagon::op_get_name(op), src_len);
return false;
}
const auto & src0 = srcs[0];
const auto & src1 = srcs[1];
if (op == NPU_OP_GET_ROWS) {
if (dst->ne[0] != src0.ne[0]) {
DEVICE_LOG_DEBUG("[%s]dst.ne[0] and src0.ne[0] not match: %ld vs %ld\n", hexagon::op_get_name(op),
(long) dst->ne[0], (long) src0.ne[0]);
return false;
}
if (dst->type != src0.type) {
DEVICE_LOG_DEBUG("[%s]dst.type and src0.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type), hexagon::get_type_name(src0.type));
return false;
}
// TODO: remove this limitation
return false;
} else {
// NPU_OP_SET_ROWS
if (dst->ne[0] != src0.ne[0] || dst->ne[2] != src0.ne[2] || dst->ne[3] != src0.ne[3]) {
DEVICE_LOG_DEBUG("[%s]dst.ne[0], src0.ne[0] and src0.ne[2], src0.ne[3] not match: %ld vs %ld, %ld, %ld\n",
hexagon::op_get_name(op), (long) dst->ne[0], (long) src0.ne[0], (long) src0.ne[2],
(long) src0.ne[3]);
return false;
}
if (src0.type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]src0.type is not F32: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type));
return false;
}
if (src1.type != NPU_DATA_TYPE_I32 && src1.type != NPU_DATA_TYPE_I64) {
DEVICE_LOG_DEBUG("[%s]src1.type is not I32 or I64: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src1.type));
return false;
}
if (dst->type != src0.type && !get_type_traits(dst->type).from_float) {
DEVICE_LOG_DEBUG("[%s]dst.from_float is null: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
return false;
}
}
return true;
}
bool is_rows_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
// TODO: implement is_rows_required_sync
return false;
}
} // namespace hexagon
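
For readers unfamiliar with the op, the following is a plain scalar reference for the row scatter that set_rows_generic performs, assuming contiguous float tensors and a 1-D int64 index per source row; the real kernel additionally handles broadcast index tensors, strided layouts, and quantized destinations via from_float:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // dst[idx[i1], i2, i3] = src[i1, i2, i3] for every source row (ne0 elements per row).
    void set_rows_ref(float * dst, const float * src, const int64_t * idx,
                      int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, int64_t dst_rows) {
        for (int64_t i3 = 0; i3 < ne3; ++i3) {
            for (int64_t i2 = 0; i2 < ne2; ++i2) {
                for (int64_t i1 = 0; i1 < ne1; ++i1) {
                    const int64_t row = idx[i1];  // destination row chosen by the index tensor
                    const float * s = src + ((i3 * ne2 + i2) * ne1 + i1) * ne0;
                    float *       d = dst + ((i3 * ne2 + i2) * dst_rows + row) * ne0;
                    std::memcpy(d, s, static_cast<std::size_t>(ne0) * sizeof(float));
                }
            }
        }
    }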


@ -0,0 +1,19 @@
#pragma once
#include "op_types.hpp"
namespace hexagon {
bool get_rows_f32(tensor * out, compute_params * params);
bool set_rows_generic(tensor * out, compute_params * params);
bool is_rows_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
bool is_rows_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne);
} // namespace hexagon


@ -339,13 +339,13 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * d
const int nb = count / qk;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
hexagon::dequant_output_type * dst_ptr = dst; // TODO: opt for aligned access
int i = 0;
hexagon::dequant_output_type * dst_ptr = dst;
int i = 0;
for (; i + 5 < nb; i += 6) {
auto qs = load_hexa_block_generic(src_ptr + i, qs_indices, scale_indices);
auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
auto res2 = dequantize_vec_q40_qf16_2blocks(qs.val[3], qs.val[4], table);
auto qs = load_hexa_block_generic(src_ptr + i, qs_indices, scale_indices);
auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
HVX_Vector block45 = Q6_V_vror_VR(qs.val[0], kSizeOfQs * 4);
auto res2 = dequantize_vec_q40_qf16_2blocks(block45, qs.val[3], table);
if constexpr (_IsDstAligned) {
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]);
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]);
@ -372,7 +372,8 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * d
}
for (; i + 1 < nb; i += 2) {
auto res = load_dequant_vec_q40_qf16_2blocks(src_ptr + i, qs_indices, scale_indices, table);
auto qs = load_dual_block_generic(src_ptr + i, qs_indices, scale_indices);
auto res = dequantize_vec_q40_qf16_2blocks(qs.val[0], qs.val[1], table);
if constexpr (_IsDstAligned) {
*reinterpret_cast<HVX_Vector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(res);
} else {
@ -469,8 +470,14 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_output_type * dst, s
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qv, 4);
HVX_VectorPair qp = Q6_W_vshuff_VVR(q_hi, q_lo, kQuantSubBlockSize * 3);
dual_pair.p[0] = Q6_Wh_vlut16_VbVhR_nomatch(Q6_Vb_vshuff_Vb(Q6_V_lo_W(qp)), table, 0);
dual_pair.p[1] = Q6_Wh_vlut16_VbVhR_nomatch(Q6_Vb_vshuff_Vb(Q6_V_hi_W(qp)), table, 0);
q_lo = Q6_V_lo_W(qp);
q_hi = Q6_V_hi_W(qp);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
q_hi = Q6_Vb_vshuff_Vb(q_hi);
dual_pair.p[0] = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
dual_pair.p[1] = Q6_Wh_vlut16_VbVhR_nomatch(q_hi, table, 0);
const __fp16 d = reinterpret_cast<const __fp16 &>(src_ptr[i].d);
const __fp16 min = reinterpret_cast<const __fp16 &>(src_ptr[i].dmin);
@ -533,13 +540,14 @@ void copy_row_f16(const void * src, hexagon::dequant_output_type * dst, size_t c
hexagon::vec_cpy_f16(reinterpret_cast<const npu_device_fp16_t *>(src), dst, count);
}
void copy_row_f32(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector) {
template <typename _TSrc, typename _TDst, typename... _TExtArgs>
void copy_row_f32(const _TSrc * src, _TDst * dst, size_t count, _TExtArgs...) {
hexagon::vec_cpy_f32(reinterpret_cast<const float *>(src), reinterpret_cast<float *>(dst), count);
}
constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr,
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32<void, hexagon::dequant_output_type, HVX_Vector>,
copy_row_f32<float, void>, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f32_f32>,
hexagon::type_erase_dot_func<hexagon::is_f32_f32_dot_product_aligned> },
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16,
@ -547,6 +555,7 @@ constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f16_f16>,
hexagon::type_erase_dot_func<hexagon::is_f16_f16_dot_product_aligned> },
{ NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false },
{ NPU_DATA_TYPE_I64, "I64", 1, sizeof(int64_t), false },
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
quantize_row_q8_0 },
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
@ -563,6 +572,8 @@ static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F16].type == NPU_DATA_TYPE_F16,
"kDeviceTypeTraits F16 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_I32].type == NPU_DATA_TYPE_I32,
"kDeviceTypeTraits I32 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_I64].type == NPU_DATA_TYPE_I64,
"kDeviceTypeTraits I64 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q8_0].type == NPU_DATA_TYPE_Q8_0,
"kDeviceTypeTraits Q8_0 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_0].type == NPU_DATA_TYPE_Q4_0,
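
As a reference for what the LUT-based HVX path above produces, a scalar Q4_0 dequantization sketch; the block layout follows ggml (a per-block scale plus 32 packed 4-bit quants), with the scale kept as float here for brevity even though the real block stores it as fp16:

    #include <cstddef>
    #include <cstdint>

    struct block_q4_0_demo {
        float   d;        // per-block scale (fp16 in the real npu_device_block_q4_0)
        uint8_t qs[16];   // 32 quants, 4 bits each, two per byte
    };

    void dequantize_row_q4_0_ref(const block_q4_0_demo * x, float * y, std::size_t count) {
        const std::size_t nb = count / 32;
        for (std::size_t i = 0; i < nb; ++i) {
            for (int j = 0; j < 16; ++j) {
                const int lo = (x[i].qs[j] & 0x0F) - 8;  // low nibble: first half of the block
                const int hi = (x[i].qs[j] >> 4) - 8;    // high nibble: second half of the block
                y[i * 32 + j]      = lo * x[i].d;
                y[i * 32 + j + 16] = hi * x[i].d;
            }
        }
    }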


@ -72,6 +72,12 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) {
return "ROPE";
case NPU_OP_GLU:
return "GLU";
case NPU_OP_GET_ROWS:
return "GET_ROWS";
case NPU_OP_SET_ROWS:
return "SET_ROWS";
case NPU_OP_CPY:
return "CPY";
default:
return "UNKNOWN";
}
@ -283,68 +289,44 @@ template <size_t _buffer_count> class npu_scoped_timer {
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus\n",
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration,
_sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count,
(unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix,
(unsigned long long) _sub_proc_data[2].proc_count,
(unsigned long long) sub_proc2_duration,
_sub_proc_data[3].log_prefix,
(unsigned long long) _sub_proc_data[3].proc_count,
(unsigned long long) sub_proc3_duration);
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
(unsigned long long) sub_proc2_duration, _sub_proc_data[3].log_prefix,
(unsigned long long) _sub_proc_data[3].proc_count, (unsigned long long) sub_proc3_duration);
break;
case 3:
DEVICE_LOG_WARN(
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration,
_sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count,
(unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix,
(unsigned long long) _sub_proc_data[2].proc_count,
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
(unsigned long long) sub_proc2_duration);
break;
case 2:
DEVICE_LOG_WARN(
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus\n",
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration,
_sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count,
(unsigned long long) sub_proc1_duration);
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration);
break;
case 1:
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", _log_prefix,
(unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration);
break;
default:
case 0:
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n",
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration);
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix,
(unsigned long long) total_pcycles, (unsigned long long) duration);
break;
}
}
@ -372,8 +354,8 @@ template <size_t _buffer_count, size_t _sub_idx> class npu_sub_process_scoped_ti
}
~npu_sub_process_scoped_timer() {
_timer.add_sub_proc_cycles(
_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles, HAP_perf_get_pcycles() - _begin_pcycles);
_timer.add_sub_proc_cycles(_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles,
HAP_perf_get_pcycles() - _begin_pcycles);
}
private:


@ -20,6 +20,7 @@ using HVX_Vector_x3 = HEXAGON_pack<HVX_Vector, 3>;
using HVX_Vector_x4 = HEXAGON_pack<HVX_Vector, 4>;
using HVX_Vector_x5 = HEXAGON_pack<HVX_Vector, 5>;
using HVX_VectorPair_x2 = HEXAGON_pack<HVX_VectorPair, 2>;
using HVX_VectorPair_x3 = HEXAGON_pack<HVX_VectorPair, 3>;
using HVX_VectorPair_x4 = HEXAGON_pack<HVX_VectorPair, 4>;
using HVX_VectorPred_x3 = HEXAGON_pack<HVX_VectorPred, 3>;
@ -363,9 +364,10 @@ inline HVX_Vector vec_dot_product_vqf32_q40_f32(const npu_device_block_q4_0 * sr
using namespace hexagon::vec::math;
using namespace hexagon::vec::quant;
alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask<npu_device_block_q4_0>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices =
make_qs_load_mask<npu_device_block_q4_0, q4_qs_shuff_idx>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices =
make_scale_load_mask<npu_device_block_q4_0>();
Q6_Vh_vshuff_Vh(make_scale_load_mask<npu_device_block_q4_0>());
return vec_dot_product_quant_impl<npu_device_block_q4_0, float, HVX_Vector, load_dequant_vec_q40_qf32_4blocks,
load_dequant_vec_q40_qf32_2blocks, load_dequant_vec_q40_qf32_1block,


@ -380,6 +380,31 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
return _ReduceFunc(_AddFunc(sum0, sum1));
}
inline HVX_Vector_x2 vec_dot_accum_pair(HVX_VectorPair s0,
HVX_Vector curr10,
HVX_Vector curr11,
HVX_Vector prev1,
HVX_Vector_x2 sums,
size_t offset,
HVX_Vector zero) {
HVX_Vector l0 = Q6_V_lo_W(s0);
HVX_Vector l1 = Q6_V_valign_VVR(curr10, prev1, offset);
HVX_Vector h0 = Q6_V_hi_W(s0);
HVX_Vector h1 = Q6_V_valign_VVR(curr11, curr10, offset);
l1 = Q6_Vqf32_vadd_VsfVsf(zero, l1);
h1 = Q6_Vqf32_vadd_VsfVsf(zero, h1);
HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l0, l1);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(h0, h1);
HVX_Vector_x2 result;
result.val[0] = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sums.val[0]);
result.val[1] = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sums.val[1]);
return result;
}
template <typename _TQuantElem0,
typename _TElem1,
typename _TRet,
@ -422,8 +447,7 @@ inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0,
HVX_Vector sum = kZeroV;
if (src1_vec_ptr_end - src1_vec_ptr > 1) {
HVX_Vector sum0 = kZeroV;
HVX_Vector sum1 = kZeroV;
HVX_Vector_x2 sums = { kZeroV, kZeroV };
while (src1_vec_ptr_end - src1_vec_ptr > 3) {
HVX_VectorPair_x2 s01 = _DequantQuadFunc(src0_ptr, qs_indices, scale_indices, table);
@ -432,38 +456,10 @@ inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0,
HVX_Vector curr110 = src1_vec_ptr[2];
HVX_Vector curr111 = src1_vec_ptr[3];
HVX_Vector l00 = Q6_V_lo_W(s01.val[0]);
HVX_Vector l10 = Q6_V_valign_VVR(curr100, prev1, (size_t) src1);
HVX_Vector l01 = Q6_V_lo_W(s01.val[1]);
HVX_Vector l11 = Q6_V_valign_VVR(curr110, curr101, (size_t) src1);
HVX_Vector h00 = Q6_V_hi_W(s01.val[0]);
HVX_Vector h10 = Q6_V_valign_VVR(curr101, curr100, (size_t) src1);
HVX_Vector h01 = Q6_V_hi_W(s01.val[1]);
HVX_Vector h11 = Q6_V_valign_VVR(curr111, curr110, (size_t) src1);
l10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l10);
l11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l11);
HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l00, l10);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(l01, l11);
h10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h10);
h11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h11);
HVX_Vector mpy2 = Q6_Vqf32_vmpy_Vqf32Vqf32(h00, h10);
HVX_Vector mpy3 = Q6_Vqf32_vmpy_Vqf32Vqf32(h01, h11);
sums = vec_dot_accum_pair(s01.val[0], curr100, curr101, prev1, sums, (size_t) src1, kZeroV);
sums = vec_dot_accum_pair(s01.val[1], curr110, curr111, curr101, sums, (size_t) src1, kZeroV);
prev1 = curr111;
sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1);
sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy2, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy3, sum1);
src0_ptr += 4;
src1_vec_ptr += 4;
}
@@ -473,28 +469,14 @@ inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0,
HVX_Vector curr10 = src1_vec_ptr[0];
HVX_Vector curr11 = src1_vec_ptr[1];
HVX_Vector l0 = Q6_V_lo_W(s0);
HVX_Vector l1 = Q6_V_valign_VVR(curr10, prev1, (size_t) src1);
HVX_Vector h0 = Q6_V_hi_W(s0);
HVX_Vector h1 = Q6_V_valign_VVR(curr11, curr10, (size_t) src1);
l1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l1);
h1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h1);
HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l0, l1);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(h0, h1);
sums = vec_dot_accum_pair(s0, curr10, curr11, prev1, sums, (size_t) src1, kZeroV);
prev1 = curr11;
sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1);
src0_ptr += 2;
src1_vec_ptr += 2;
}
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum0, sum1);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sums.val[0], sums.val[1]);
}
if (src1_vec_ptr_end - src1_vec_ptr > 0) {
@@ -614,8 +596,16 @@ template <typename _TData> inline void vec_zero_impl(_TData * src, size_t count)
}
}
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector), typename _TyData>
inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData * dst, size_t count) {
template <auto * _OpBinaryTransform, typename _TyData, typename... _TyParams>
inline void vec_trans_impl(const _TyData * src0,
const _TyData * src1,
_TyData * dst,
size_t count,
_TyParams... params) {
static_assert(std::is_same_v<decltype(_OpBinaryTransform), HVX_Vector (*)(HVX_Vector, HVX_Vector, _TyParams...)>,
"Function type mismatch: _OpBinaryTransform must be of type HVX_Vector (*)(HVX_Vector, HVX_Vector, "
"_TyParams...)");
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
@@ -632,11 +622,11 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1);
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1, params...);
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1);
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1, params...);
prev0 = Q6_V_hi_W(curr0);
prev1 = Q6_V_hi_W(curr1);
@@ -653,7 +643,7 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, params...);
prev0 = curr0;
prev1 = curr1;
@@ -675,7 +665,7 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, params...);
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
@@ -697,98 +687,98 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1));
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1, params...));
}
}
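// Usage sketch for the variadic form (hypothetical op and call, shown only to
// illustrate the new template shape): an element-wise op that takes a runtime
// vector parameter can now be passed directly instead of requiring a separate
// *_with_param_impl overload.
inline HVX_Vector vec_add_scaled_f32(HVX_Vector a, HVX_Vector b, HVX_Vector scale) {
    // (a + b) * scale; inputs arrive as IEEE f32, the math runs in the qf32 domain
    HVX_Vector sum      = Q6_Vqf32_vadd_VsfVsf(a, b);
    HVX_Vector qf_scale = Q6_Vqf32_vadd_VsfVsf(scale, Q6_V_vzero());  // sf -> qf32
    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(sum, qf_scale));
}
// vec_trans_impl<vec_add_scaled_f32>(src0, src1, dst, count, scale_vec);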
template <typename _TyData, typename _TyParam, HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector, _TyParam)>
inline void vec_trans_with_param_impl(const _TyData * src0,
const _TyData * src1,
_TyData * dst,
size_t count,
_TyParam param) {
template <auto * _OpUnaryTransform, typename _TyData, typename _TyDataRet, typename... _TyParams>
inline void vec_trans_with_half_ret_impl(const _TyData * src0, _TyDataRet * dst, size_t count, _TyParams... params) {
static_assert(std::is_same_v<decltype(_OpUnaryTransform), HVX_Vector (*)(HVX_VectorPair, _TyParams...)>,
"Function type mismatch: _OpUnaryTransform must be of type HVX_Vector (*)(HVX_Vector, HVX_Vector, "
"_TyParams...)");
static_assert(sizeof(_TyData) / sizeof(_TyDataRet) == 2,
"Element size mismatch: _TyData must be twice the size of _TyDataRet");
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
const HVX_Vector kZero = Q6_V_vzero();
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
HVX_Vector prev0 = *src0_vec_ptr++;
HVX_Vector prev1 = *src1_vec_ptr++;
{
while (src0_vec_ptr_end - src0_vec_ptr > 1) {
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1, param);
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1, param);
dst_vec_ptr[0] = _OpUnaryTransform(Q6_W_vcombine_VV(h0, l0), params...);
prev0 = Q6_V_hi_W(curr0);
prev1 = Q6_V_hi_W(curr1);
src0_vec_ptr += 2;
src1_vec_ptr += 2;
dst_vec_ptr += 2;
dst_vec_ptr++;
}
}
HVX_Vector result;
uint32_t processed_bytes = 0;
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
HVX_Vector curr0 = *src0_vec_ptr++;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);
prev0 = curr0;
prev1 = curr1;
dst_vec_ptr++;
prev0 = curr0;
result = _OpUnaryTransform(Q6_W_vcombine_VV(kZero, s0), params...);
processed_bytes = kElementsPerVector * sizeof(_TyDataRet);
}
const size_t leftover = count % kElementsPerVector;
static const HVX_VectorPred mask = Q6_Q_vsetq_R(hexagon::kBytesPerVector / 2);
const size_t src_leftover = count % kElementsPerVector;
if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
bool should_fetch_src0 = src_leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);
if (processed_bytes) {
s0 = _OpUnaryTransform(Q6_W_vcombine_VV(s0, kZero), params...);
dst_vec_ptr[0] = Q6_V_vmux_QVV(mask, result, s0); // only update the lower half of the result vector
dst_vec_ptr++;
} else {
result = _OpUnaryTransform(Q6_W_vcombine_VV(kZero, s0), params...);
}
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
prev0 = curr0;
prev1 = curr1;
dst_vec_ptr++;
processed_bytes += kElementsPerVector * sizeof(_TyDataRet);
}
if (leftover > 0) {
if (src_leftover > 0) {
// handle the leftover elements
const size_t leftover_bytes = leftover * sizeof(_TyData);
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
const size_t src_leftover_bytes = src_leftover * sizeof(_TyData);
HVX_Vector curr0 = (src_leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
*src0_vec_ptr :
prev0;
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
*src1_vec_ptr :
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
if (processed_bytes % hexagon::kBytesPerVector) {
curr0 = _OpUnaryTransform(Q6_W_vcombine_VV(curr0, kZero), params...);
curr0 = Q6_V_vmux_QVV(mask, result, curr0);
} else {
curr0 = _OpUnaryTransform(Q6_W_vcombine_VV(kZero, curr0), params...);
}
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1, param));
processed_bytes += src_leftover * sizeof(_TyDataRet);
q6op_vstu_variable_ARV(dst_vec_ptr, processed_bytes % hexagon::kBytesPerVector, curr0);
} else if (processed_bytes % hexagon::kBytesPerVector) {
// TODO: This conditional write-back is suboptimal because it may result in an extra memory write.
q6op_vstu_variable_ARV(dst_vec_ptr, processed_bytes % hexagon::kBytesPerVector, result);
}
}
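// Usage sketch (hypothetical op; assumes the v68 qfloat conversion
// Q6_Vhf_equals_Wqf32 is available): an op that narrows a pair of f32 vectors
// into a single f16 vector matches the HVX_Vector (*)(HVX_VectorPair) signature
// this template expects, e.g. for a f32 -> f16 copy path.
inline HVX_Vector vec_pair_f32_to_f16(HVX_VectorPair p) {
    const HVX_Vector zero = Q6_V_vzero();
    HVX_Vector lo = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(p), zero);  // sf -> qf32
    HVX_Vector hi = Q6_Vqf32_vadd_VsfVsf(Q6_V_hi_W(p), zero);
    return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(hi, lo));      // qf32 pair -> hf
}
// vec_trans_with_half_ret_impl<vec_pair_f32_to_f16>(src_f32, dst_f16, count);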

View File

@@ -47,7 +47,26 @@ template <typename _TBlock> inline HVX_Vector make_scale_load_mask() {
return ret.v;
}
template <typename _TBlock> inline HVX_Vector make_qs_load_mask() {
inline size_t default_qs_shuff_idx(size_t idx) {
return idx;
}
inline size_t q4_qs_shuff_idx(size_t idx) {
// TODO: The current mask (kIndexShuffle) is hardcoded for the Q4 quantization block layout, where data is arranged in a specific interleaved pattern.
// A more general solution would need to programmatically generate the shuffle mask based on the quantization block's structure.
constexpr const size_t kIndexShuffle[] = {
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 2, 6, 10, 14, 18, 22,
26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45,
49, 53, 57, 61, 3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
};
return kIndexShuffle[idx];
}
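// A possible programmatic equivalent of the table above (illustrative sketch;
// assumes the Q4 interleave visits byte offsets 0, 2, 1, 3 within each group of
// four and parks every index past the first 64 at lane 127):
inline size_t q4_qs_shuff_idx_generated(size_t idx) {
    constexpr const size_t kGroupOffset[] = { 0, 2, 1, 3 };
    if (idx >= 64) {
        return 127;
    }
    return (idx % 16) * 4 + kGroupOffset[idx / 16];
}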
template <typename _TBlock, size_t (*_FuncGetShuffIdx)(size_t) = default_qs_shuff_idx>
inline HVX_Vector make_qs_load_mask() {
static_assert(sizeof(_TBlock) < hexagon::kBytesPerVector, "wrong block size/padding");
const size_t qs_start_offset = offsetof(_TBlock, qs);
@@ -58,23 +77,15 @@ template <typename _TBlock> inline HVX_Vector make_qs_load_mask() {
for (size_t i = 0; i < hexagon::kBytesPerVector; ++i) {
auto offset = i % sizeof(_TBlock);
if (offset >= qs_start_offset && offset < qs_end_offset) {
ret.u8[ret_idx++] = (i & 1) ? (i / 2 + 64) : (i / 2);
size_t idx = _FuncGetShuffIdx(ret_idx);
ret.u8[idx] = ((i & 1) ? (i / 2 + 64) : (i / 2));
ret_idx++;
}
}
return ret.v;
}
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock * srcs, HVX_VectorPred mask) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;
HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs);
HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale);
return Q6_V_vmux_QVV(mask, blocks, block1);
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs,
const HVX_Vector qs_indices,
@@ -84,10 +95,10 @@ inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs,
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs);
HVX_Vector block01 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0);
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 2);
HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2);
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 2);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2);
if constexpr (sizeof(_TBlock) * 4 > hexagon::kBytesPerVector) {
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 1);
@@ -139,15 +150,14 @@ inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock * srcs,
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,
inline hexagon::HVX_Vector_x4 load_hexa_block_generic(const _TBlock * srcs,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs);
hexagon::HVX_Vector_x5 result;
hexagon::HVX_Vector_x4 result;
{
HVX_Vector block012345 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0);
block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 1);
@@ -155,7 +165,6 @@ inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,
block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 3);
result.val[0] = block012345;
result.val[3] = Q6_V_vror_VR(block012345, kSizeOfQs * 4); // block45
}
{
@@ -173,7 +182,7 @@ inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,
result.val[1] = scale01;
result.val[2] = scale23;
result.val[4] = scale45;
result.val[3] = scale45;
}
return result;
@@ -198,15 +207,12 @@ inline HVX_VectorPair dequantize_vec_q40_qf32_2blocks(HVX_Vector qs, HVX_Vector
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * 4);
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
q_lo = Q6_V_lo_W(qp0);
scale01 = Q6_Vh_vshuff_Vh(scale01);
q_lo = Q6_Vh_vshuff_Vh(q_lo); // TODO: avoid vshuff here
q_lo = Q6_V_lo_W(qp0);
return Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
}
@@ -247,27 +253,49 @@ inline HVX_VectorPair_x2 dequantize_vec_q40_qf32_4blocks(HVX_Vector qs,
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * 4);
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);
q_lo = Q6_Vh_vshuff_Vh(q_lo);
scale01 = Q6_Vh_vshuff_Vh(scale01);
q_hi = Q6_Vh_vshuff_Vh(q_hi);
scale23 = Q6_Vh_vshuff_Vh(scale23); // TODO: avoid vshuff here
hexagon::HVX_VectorPair_x2 result;
result.val[0] = Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
result.val[1] = Q6_Wqf32_vmpy_VhfVhf(q_hi, scale23);
return result;
}
inline HVX_VectorPair_x3 dequantize_vec_q40_qf32_6blocks(HVX_Vector qs,
HVX_Vector scale01,
HVX_Vector scale23,
HVX_Vector scale45,
HVX_Vector table) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * 4);
q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
HVX_VectorPair qp1 = Q6_Wh_vlut16_VbVhR_nomatch(q_hi, table, 0);
q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);
HVX_Vector q2 = Q6_V_lo_W(qp1);
hexagon::HVX_VectorPair_x3 result;
result.val[0] = Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
result.val[1] = Q6_Wqf32_vmpy_VhfVhf(q_hi, scale23);
result.val[2] = Q6_Wqf32_vmpy_VhfVhf(q2, scale45);
return result;
}
inline HVX_Vector load_dequant_vec_q40_qf32_1block(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
@@ -277,14 +305,6 @@ inline HVX_Vector load_dequant_vec_q40_qf32_1block(const npu_device_block_q4_0 *
return Q6_V_lo_W(dequantize_vec_q40_qf32_2blocks(qs.val[0], qs.val[1], table));
}
inline HVX_Vector load_dequant_vec_q40_qf16_2blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
auto qs = load_dual_block_generic(src, qs_indices, scale_indices);
return dequantize_vec_q40_qf16_2blocks(qs.val[0], qs.val[1], table);
}
inline HVX_VectorPair load_dequant_vec_q40_qf32_2blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
@@ -301,4 +321,12 @@ inline HVX_VectorPair_x2 load_dequant_vec_q40_qf32_4blocks(const npu_device_bloc
return dequantize_vec_q40_qf32_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
}
inline HVX_VectorPair_x3 load_dequant_vec_q40_qf32_6blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
auto qs = load_hexa_block_generic(src, qs_indices, scale_indices);
return dequantize_vec_q40_qf32_6blocks(qs.val[0], qs.val[1], qs.val[2], qs.val[3], table);
}
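// Scalar reference for what the q4_0 load + dequant helpers above compute per
// block (sketch under the standard ggml Q4_0 layout: one fp16 scale plus 32
// 4-bit codes packed into 16 bytes, value = (code - 8) * scale; the HVX path
// obtains the (code - 8) term from the `table` lut instead of subtracting):
inline void dequantize_block_q4_0_ref(const uint8_t * qs, float d, float * dst) {
    for (int i = 0; i < 16; ++i) {
        dst[i]      = (float) ((qs[i] & 0x0F) - 8) * d;  // low nibbles -> first 16 values
        dst[i + 16] = (float) ((qs[i] >> 4) - 8) * d;    // high nibbles -> last 16 values
    }
}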
} // namespace hexagon::vec::quant

View File

@@ -67,6 +67,8 @@ void backend_buffer_set_tensor(ggml_backend_buffer_t buffer,
size_t size) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_set_tensor.size.%zu",
(void *) get_buffer_object(buffer), size);
// TODO: use DMA instead of memcpy?
memcpy((char *) tensor->data + offset, data, size);
}
@@ -76,12 +78,15 @@ void backend_buffer_get_tensor(ggml_backend_buffer_t buffer,
size_t offset,
size_t size) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_get_tensor", (void *) get_buffer_object(buffer));
// TODO: use DMA instead of memcpy?
memcpy(data, (const char *) tensor->data + offset, size);
}
bool backend_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_cpy_tensor", (void *) get_buffer_object(buffer));
if (ggml_backend_buffer_is_host(src->buffer)) {
// TODO: use DMA instead of memcpy?
memcpy(dst->data, src->data, ggml_nbytes(src));
return true;
}

View File

@@ -37,6 +37,12 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) {
return NPU_OP_ROPE;
case GGML_OP_GLU:
return NPU_OP_GLU;
case GGML_OP_GET_ROWS:
return NPU_OP_GET_ROWS;
case GGML_OP_SET_ROWS:
return NPU_OP_SET_ROWS;
case GGML_OP_CPY:
return NPU_OP_CPY;
default:
return NPU_OP_COUNT;
}
@@ -60,6 +66,12 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) {
return ggml_op_name(GGML_OP_ROPE);
case NPU_OP_GLU:
return ggml_op_name(GGML_OP_GLU);
case NPU_OP_GET_ROWS:
return ggml_op_name(GGML_OP_GET_ROWS);
case NPU_OP_SET_ROWS:
return ggml_op_name(GGML_OP_SET_ROWS);
case NPU_OP_CPY:
return ggml_op_name(GGML_OP_CPY);
default:
return "UNKNOWN";
}
@@ -73,6 +85,8 @@ enum npu_device_tensor_data_type type_to_npu_type(ggml_type type) {
return NPU_DATA_TYPE_F16;
case GGML_TYPE_I32:
return NPU_DATA_TYPE_I32;
case GGML_TYPE_I64:
return NPU_DATA_TYPE_I64;
case GGML_TYPE_Q4_K:
return NPU_DATA_TYPE_Q4_K;
case GGML_TYPE_Q4_0:
@@ -178,30 +192,15 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) {
switch (dims) {
default:
case 4:
snprintf(out,
max_len,
"%s[%ldx%ldx%ldx%ld]",
ggml_type_name(tensor->type),
(long) tensor->ne[0],
(long) tensor->ne[1],
(long) tensor->ne[2],
(long) tensor->ne[3]);
snprintf(out, max_len, "%s[%ldx%ldx%ldx%ld]", ggml_type_name(tensor->type), (long) tensor->ne[0],
(long) tensor->ne[1], (long) tensor->ne[2], (long) tensor->ne[3]);
break;
case 3:
snprintf(out,
max_len,
"%s[%ldx%ldx%ld]",
ggml_type_name(tensor->type),
(long) tensor->ne[0],
(long) tensor->ne[1],
(long) tensor->ne[2]);
snprintf(out, max_len, "%s[%ldx%ldx%ld]", ggml_type_name(tensor->type), (long) tensor->ne[0],
(long) tensor->ne[1], (long) tensor->ne[2]);
break;
case 2:
snprintf(out,
max_len,
"%s[%ldx%ld]",
ggml_type_name(tensor->type),
(long) tensor->ne[0],
snprintf(out, max_len, "%s[%ldx%ld]", ggml_type_name(tensor->type), (long) tensor->ne[0],
(long) tensor->ne[1]);
break;
case 1:
@@ -233,14 +232,8 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) {
print_tensor(dst->src[2], src2_desc, sizeof(src2_desc));
char src3_desc[256];
print_tensor(dst->src[3], src3_desc, sizeof(src3_desc));
snprintf(out,
max_len,
"dst: %s, src0: %s, src1: %s, src2: %s, src3: %s",
dst_desc,
src0_desc,
src1_desc,
src2_desc,
src3_desc);
snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s, src3: %s", dst_desc, src0_desc,
src1_desc, src2_desc, src3_desc);
return;
}
case 3:
@@ -251,8 +244,8 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) {
print_tensor(dst->src[1], src1_desc, sizeof(src1_desc));
char src2_desc[256];
print_tensor(dst->src[2], src2_desc, sizeof(src2_desc));
snprintf(
out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s", dst_desc, src0_desc, src1_desc, src2_desc);
snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s", dst_desc, src0_desc, src1_desc,
src2_desc);
return;
}
case 2:

View File

@@ -54,6 +54,9 @@ interface npu_device : remote_handle64{
NPU_OP_FLASH_ATTN,
NPU_OP_ROPE,
NPU_OP_GLU,
NPU_OP_GET_ROWS,
NPU_OP_SET_ROWS,
NPU_OP_CPY,
NPU_OP_COUNT
};
@@ -70,6 +73,7 @@ interface npu_device : remote_handle64{
NPU_DATA_TYPE_F32,
NPU_DATA_TYPE_F16,
NPU_DATA_TYPE_I32,
NPU_DATA_TYPE_I64,
NPU_DATA_TYPE_Q8_0,
NPU_DATA_TYPE_Q4_0,
NPU_DATA_TYPE_Q4_K,