feat: perf opt add set rows (#59)
* Add power management utilities to NPU device context and update DCVS settings * Update DCVS settings in power_utils to use v3 API and enhance power management * wip * Enhance dequantization functions by adding load_dequant_table support and updating signatures for improved performance * use lut * wip * fix test failure * wip * Refactor load_qual_block_generic to improve block handling and optimize vector operations * Enhance load_dual_block_generic and load_qual_block_generic to accept a mask parameter for improved block handling * Refactor flash_attn_impl to optimize mask l2 prefetch * wip * wip * wip * wip * add log * link against shared libraries instead of static ones * fix swiglu * wip * refactor expf_fix to handle overflow for different data types * enhance is_glu_op_supported to validate shapes for multiple sources * wip * refactor logging macros to use hexagon namespace and improve formatting * fix printf format error * wip * refactor: update static_assert messages for block size validation and add HVX_VectorPred_x3 type alias * rename * feat: enhance fa with mask * wip * wip * refactor: replace instances of Q6_V_vzero() with kZeroV for consistency * wip * wip * wip * fix: improve address alignment check in HVX_Vector handling * refactor: streamline vector dot product implementations for improved readability * refactor: q4k add hvx intrinsic impl * refactor: enhance dequantize_row_q4_K for clarity and performance * refactor: optimize scale mask usage in dequantization functions for improved performance * refactor: optimize dequantize_row_q4_K for intrinsic usage and performance improvements * refactor: move GLU operation implementation into separated file * sync after swiglu * wip * wip * wip * feat: increase prc main thread stack size * fix: replace hardcoded stack size with NPU_THREAD_STACK_SIZE constant * wip * feat: add optimized vector operations for exponential and division with overflow handling * wip * feat: refactor exponential function to handle overflow and underflow with improved logic * wip * wip * feat: add vector loading and scaling functions for improved performance in block processing * wip * feat: optimize block loading by refactoring scale index handling for improved performance * use Q6_Vb_vlut32_VbVbR_nomatch instead * feat: enhance scale loading by adding static assertion and restructuring block handling * wip * feat: refactor vec_dot_product_mixed_impl for improved clarity and performance * wip * feat: simplify vector loading functions and improve alignment handling * wip * feat: enhance scale loading mask with quantization block size validation * wip * feat: implement make_scale_load_mask function and refactor vector handling in vec_ops * feat: enhance load_dual_block_generic to include scale indices for improved vector loading * revert q8 dequant * wip * feat: optimize dequantization functions by removing unnecessary masking and updating lookup methods * wip * wip * add qurt_mutex * Add DMA transfer class and integrate into thread pool * Enhance DMA transfer functionality by adding support for multiple descriptors and initiating transfers in parallel * fix dma crash * fix failed unit tests * wip * use alignas * Improve DMA transfer error handling and update descriptor completion check * Fix VTCM cache size calculation in element-wise operations * Add cache clean operations before DMA transfers in element-wise operations * reduce cache clean operations * Refactor DMA transfer functions to support 1D operations and rename for clarity * Enhance DMA transfer 
functionality by adding 2D submission support and improving descriptor initialization * Update read buffer method to support forced invalidation and remove unnecessary invalidation calls in element-wise operations * wip * Improve DMA transfer handling in mul_mat_gemv_impl by replacing memcpy with initiate_dma_row_transfer and adding wait_for_dma logic * fix 2d dma * feat: add DMA plane cache * rename * wip * use memcpy for debug * fix cache plane calc * refactor: remove debug logging from mul_mat_impl and optimize cache handling * rename * fix 2d dma type * refactor: enhance DMA transfer handling in mul_mat_gemv_impl and wait functions * refactor: optimize DMA transfer handling in mul_mat_gemv_impl and wait functions * wip * wip * move op impl into sub dir * add log * fix: correct pointer usage in mul_mat_gemv_impl for next plane access * fix: improve DMA transfer error handling in mul_mat_impl and mul_mat_gemv_impl * fix: fix crash by using the entire row bytes * wip * wip * fix: prevent parallelization for scalar src1 in is_mul_mat_supported * fix: add dimension checks for 2D DMA transfers and fallback to 1D if necessary * wip * fix: enable thread barrier for mul multiplication operations * feat: add synchronization checks for tensor operations and update related functions * wip * fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations * Revert "fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations" This reverts commit af3441e67e706b2e5122369dc160353796867dd3. * wip * wip * add comment * fix: improve DMA transfer handling in mul_mat_gemv_impl for quantized source tensors * add log * try fix mulmat gemv * wip * fix: enhance DMA transfer handling in mul_mat_gemv_impl for quantized source tensors * fix: optimize cache offset calculation and remove redundant swap in mul_mat_gemv_impl * fix: refactor DMA transfer handling in mul_mat_gemv_impl for improved clarity and maintainability * wip * wip * wip * fix: enhance mul_mat_impl for improved cache handling and clarity * fix: refactor tensor unflattening and DMA transfer initialization for improved clarity and type safety * fix: improve cache handling of quant * wip * fix: improve cache handling in mul_mat_impl and mul_mat_gemv_impl for better memory efficiency * rename * add load_hexa_block_generic * wip * extract dequant block into separated function * refactor: enhance dequantization functions with table parameter * fix load_dual_block_generic * refactor: rename dequantization functions for clarity and enhance block handling * refactor: simplify dequantization logic by consolidating block handling and removing unused parameters * wip * wip * feat: add make_qs_load_mask function and update load_dual_block_generic to use qs_indices * fix load_dual_block_generic * refactor: update load functions to use qs_indices for improved block loading * wip * fix: update loop indices and boundary checks to use size_t for better efficiency * wip * update make_scale_load_mask, to make it available for q8 * feat: add vec_dot_product_quant_impl for quantized dot product computation * refactoring: move come quant func to dedicated file * refactor: rename dequantization functions for clarity and consistency * wip * feat: enhance vec_dot_product_quant_impl with dual dequantization and improved assertions * add vec_dot_product_vqf32_q40_f32 * wip * wip * wip * wip * implement vec_mpy_qf32_qf32_qf32 function and update vec_dot_product_vqf32_q40_f32 to use 
it * wip * add src0_plane_write_cache_offset * wip * enhance mul_mat_f32 to handle NPU_DATA_TYPE_Q4_0 for quantized matrix multiplication * wip * wip * update test func * refactor mul_mat_gemv_quant_impl to use get_nb for row stride and remove unused test function in init_f16_f32_table * wip * Add support for 4-block dequantization in vec_quant and update dot product implementation * Refactor vec_dot_product_quant_impl to improve variable handling and enhance readability * Refactor vec_dot_product_quant_impl to replace template function with inline vector operations * use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32 * Revert "use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32" This reverts commit 54839166fddbe40a0392adee5863c59070ccdbe4. * wip * improve log print in graph * Refactor batched_row_dot to accept additional arguments and remove batched_row_dot_with_table * Refactor synchronization functions to include previous operation and NE type parameters * Refactor synchronization checks in several operations * Update synchronization checks to include NPU_OP_COUNT in required conditions * Add performance tracking to buffer management functions * add memset * add log * fix: update backend device type from ACCEL to IGPU * fix comment * add get/set rows * feat: implement row operation support checks in is_rows_supported * feat: add support for I64 data type in rows operations * feat: implement set_rows functionality for I32 and I64 data types * wip * fix set_rows * feat: extend is_rows_supported to allow F32 data type in destination * wip * feat: rename set_rows function, add generic to its name * disable q4_k * move ops to separated file * rename: op_impl -> op_registry * refactor: update get_data_type struct to include output type for unary operations * refactor: simplify vec_trans_impl by removing parameterized overload and using variadic templates * add vec_trans_with_half_ret_impl * add NPU_OP_CPY * refactor: enhance is_unary_op_supported to handle non-continuous rows and add type support logging * refactor: update vec_trans_with_half_ret_impl to use processed_bytes for clarity and accuracy * wip * refactor: optimize dequantize_vec_q40_qf32_4blocks by improving shuffling logic and reducing redundancy * refactor: improve performance of vec_dot_product and dequantize functions by optimizing shuffling logic * wip * add dequantize_vec_q40_qf32_6blocks * feat: add load_dequant_vec_q40_qf32_6blocks function for 6-block dequantization * feat: enhance vec_dot_product_quant_impl with 6-element processing loop for improved performance * Revert "feat: enhance vec_dot_product_quant_impl with 6-element processing loop for improved performance" This reverts commit a5c8fa3e4d9a2d89c8c0821c936c0466e0af7869. since there's a performance degradation * fix: correct load_hexa_block_generic return type and update dequantization logic * wip * wip * feat: add make_q40_qs_load_mask function and update vec_dot_product_vqf32_q40_f32 * fix dequant load * add debug log * wip * wip * fix shuffle index array * refactor: simplify load mask generation and improve index shuffling for q4 blocks * wip * wip * fix comment * wip * update ops.md * update ops.md by create_ops_docs.py # Conflicts: # docs/ops.md
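The headline change here is SET_ROWS (plus a stubbed GET_ROWS) support in the Hexagon NPU backend. For orientation, SET_ROWS scatters rows of a float source tensor into the destination at positions taken from an integer index tensor, converting each row to the destination type on the way. A minimal scalar sketch of that semantics (simplified to contiguous 2-D f32 tensors; names are illustrative, not the backend's actual API):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Minimal sketch of SET_ROWS semantics, assuming contiguous rows and an
// f32 destination. The real backend also handles 4-D tensors, i32/i64
// indices and quantized destination types via a from_float converter.
void set_rows_ref(float * dst, size_t dst_rows, size_t row_len,
                  const float * src, size_t src_rows,
                  const int64_t * idx /* one destination row index per source row */) {
    for (size_t r = 0; r < src_rows; ++r) {
        const int64_t i1 = idx[r];
        if (i1 < 0 || (size_t) i1 >= dst_rows) {
            continue;  // sketch only: the real code assumes valid indices
        }
        std::memcpy(dst + (size_t) i1 * row_len, src + r * row_len, row_len * sizeof(float));
    }
}
```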
This commit is contained in:
parent 38ae191c55
commit e6a5f7baa6
docs/ops.md (204)
@@ -12,105 +12,105 @@ Legend:
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend

| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | qualcomm | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
File diff suppressed because it is too large
@@ -1,7 +1,7 @@
|
|||
|
||||
#include "graph.hpp"
|
||||
#include "hexagon_npu.h"
|
||||
#include "op_impl.hpp"
|
||||
#include "op_registry.hpp"
|
||||
#include "remote.h"
|
||||
#include "tensor.hpp"
|
||||
#include "thread_pool.hpp"
@@ -1,7 +1,7 @@
|
|||
|
||||
#include "graph.hpp"
|
||||
|
||||
#include "op_impl.hpp"
|
||||
#include "op_registry.hpp"
|
||||
#include "util.hpp"
|
||||
#include "vtcm_mem.hpp"
|
||||
@@ -1,18 +1,10 @@
|
|||
#pragma once
|
||||
|
||||
|
||||
#include "op_impl.hpp"
|
||||
|
||||
#include "op_flash_attn.hpp"
|
||||
#include "op_glu.hpp"
|
||||
#include "op_mul_mat.hpp"
|
||||
#include "op_rope.hpp"
|
||||
#include "op_types.hpp"
|
||||
#include "type_traits.hpp"
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
#include <cmath>
|
||||
#include <type_traits>
|
||||
|
||||
namespace {
|
||||
namespace hexagon {
|
||||
|
||||
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
|
||||
inline void vec_op_f32_f32(const float * src0, const float * src1, float * dst, size_t count) {
|
||||
|
|
@@ -41,6 +33,14 @@ inline void vec_op_f16_f16(const npu_device_fp16_t * src0,
|
|||
vec_trans_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, dst, count);
|
||||
}
|
||||
|
||||
template <HVX_Vector (*_OpUnaryTransform)(HVX_VectorPair)>
|
||||
inline void unary_vec_op_f16_f32(const float * src, npu_device_fp16_t * dst, size_t count, size_t) {
|
||||
// TODO: remove the unused param
|
||||
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_with_half_ret_impl<_OpUnaryTransform, float, npu_device_fp16_t>(src, dst, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
|
||||
// TODO: fix this since qf16 has less precision than fp16
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_VhfVhf(a, b));
|
||||
|
|
@@ -55,16 +55,25 @@ inline HVX_Vector vmul_f16_f16(HVX_Vector a, HVX_Vector b) {
|
|||
return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
|
||||
}
|
||||
|
||||
inline HVX_Vector vequals_f16_f32(HVX_VectorPair a) {
|
||||
const HVX_Vector kZeroV = Q6_V_vzero();
|
||||
HVX_Vector lo = Q6_Vqf32_vadd_Vqf32Vsf(kZeroV, Q6_V_lo_W(a));
|
||||
HVX_Vector hi = Q6_Vqf32_vadd_Vqf32Vsf(kZeroV, Q6_V_hi_W(a));
|
||||
a = Q6_W_vcombine_VV(hi, lo);
|
||||
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(a));
|
||||
}
|
||||
|
||||
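The new vequals_f16_f32 helper above narrows a pair of f32 vectors into one f16 vector for the F32→F16 copy path; the qf32 add-with-zero converts sf to qf32 so Q6_Vhf_equals_Wqf32 can be applied, and the final vdeal fixes up the element order after the pair-to-single conversion. Element-wise it is just a float-to-half conversion; a scalar reference sketch (assumes an __fp16-capable compiler, as on Hexagon):

```cpp
#include <cstddef>

// Scalar equivalent of the F32 -> F16 row copy used by NPU_OP_CPY: each
// float is rounded to the nearest representable half-precision value.
// The HVX version does this 64 elements at a time.
void copy_row_f32_to_f16_ref(const float * src, __fp16 * dst, size_t count) {
    for (size_t i = 0; i < count; ++i) {
        dst[i] = (__fp16) src[i];
    }
}
```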
template <typename T> struct get_data_type {};
|
||||
|
||||
template <typename _TyData> struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t)> {
|
||||
using type = _TyData;
|
||||
};
|
||||
|
||||
template <typename _TyData, typename _TyParam>
|
||||
struct get_data_type<void (*)(const _TyData *, _TyData *, size_t, _TyParam)> {
|
||||
using type = _TyData;
|
||||
using param_type = typename std::remove_cv<typename std::remove_reference<_TyParam>::type>::type;
|
||||
template <typename _TyInput, typename _TyOutput, typename _TyParam>
|
||||
struct get_data_type<void (*)(const _TyInput *, _TyOutput *, size_t, _TyParam)> {
|
||||
using type = _TyInput;
|
||||
using output_type = _TyOutput;
|
||||
using param_type = typename std::remove_cv<typename std::remove_reference<_TyParam>::type>::type;
|
||||
};
|
||||
|
||||
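The widened get_data_type specialization is what lets unary_op handle row functions whose input and output types differ (the F32→F16 CPY path) in addition to the in-place RMS_NORM case: it now reports the output type alongside the input and parameter types. A reduced sketch of the pattern, with a hypothetical row function for illustration:

```cpp
#include <cstddef>
#include <type_traits>

template <typename T> struct get_data_type_sketch {};

// Matches row functions of the form: void f(const In * src, Out * dst, size_t count, Param p)
template <typename In, typename Out, typename Param>
struct get_data_type_sketch<void (*)(const In *, Out *, size_t, Param)> {
    using type        = In;
    using output_type = Out;
    using param_type  = typename std::remove_cv<typename std::remove_reference<Param>::type>::type;
};

// Hypothetical row function: convert a row of floats to fp16 (extra param unused).
inline void cpy_row_f32_f16(const float * src, __fp16 * dst, size_t count, size_t) {
    for (size_t i = 0; i < count; ++i) dst[i] = (__fp16) src[i];
}

using traits = get_data_type_sketch<decltype(&cpy_row_f32_f16)>;
static_assert(std::is_same<traits::type, float>::value, "input type comes from the first argument");
static_assert(std::is_same<traits::output_type, __fp16>::value, "output type comes from the second argument");
```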
template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::compute_params * params) {
|
||||
|
|
@@ -280,8 +289,9 @@ void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
|
|||
|
||||
// TODO: merge with element_wise_op?
|
||||
template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_params * params) {
|
||||
using data_type = typename get_data_type<decltype(_RowFunc)>::type;
|
||||
using param_type = typename get_data_type<decltype(_RowFunc)>::param_type;
|
||||
using input_type = typename get_data_type<decltype(_RowFunc)>::type;
|
||||
using output_type = typename get_data_type<decltype(_RowFunc)>::output_type;
|
||||
using param_type = typename get_data_type<decltype(_RowFunc)>::param_type;
|
||||
|
||||
if (!out) {
|
||||
return false;
|
||||
|
|
@@ -311,7 +321,7 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
|
|||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
|
||||
|
||||
const auto param = out->get_op_param<param_type>(0);
|
||||
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
|
||||
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(input_type);
|
||||
for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
|
||||
const auto i03 = ir / rows_per_cube;
|
||||
const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
|
||||
|
|
@@ -323,7 +333,7 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
|
|||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
|
||||
}
|
||||
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row), reinterpret_cast<data_type *>(dst_row),
|
||||
_RowFunc(reinterpret_cast<const input_type *>(src0_row), reinterpret_cast<output_type *>(dst_row),
|
||||
static_cast<size_t>(out->get_ne(0)), param);
|
||||
}
|
||||
|
||||
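Note the structure of the row loop above: the next source row is pushed into L2 with l2fetch_row before the row function runs on the current row, so the fetch overlaps with compute. A self-contained sketch of that prefetch-ahead pattern (using __builtin_prefetch as a stand-in for l2fetch_row):

```cpp
#include <cstddef>

// Placeholder row operation; stands in for _RowFunc in the sketch.
static void process_row(const float * src, float * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        dst[i] = src[i] * 2.0f;
    }
}

// Prefetch-ahead loop: start fetching row r+1 while computing row r.
void rows_with_prefetch(const float * src, float * dst, size_t rows, size_t row_len) {
    for (size_t r = 0; r < rows; ++r) {
        if (r + 1 < rows) {
            __builtin_prefetch(src + (r + 1) * row_len);  // stand-in for hexagon::l2fetch_row
        }
        process_row(src + r * row_len, dst + r * row_len, row_len);
    }
}
```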
|
|
@@ -336,7 +346,7 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_RMS_NORM) {
|
||||
if (op != NPU_OP_RMS_NORM && op != NPU_OP_CPY) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
|
@@ -347,23 +357,38 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
}
|
||||
|
||||
const auto & src0 = srcs[0];
|
||||
if (dst->type != src0.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_same_shape(src0, *dst)) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op == NPU_OP_RMS_NORM) {
|
||||
if (dst->type != src0.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (dst->nb[1] < dst->nb[0] || src0.nb[1] < src0.nb[0]) {
|
||||
// TODO: support non-continuous row
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported non-continuous row\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F16 || src0.type != NPU_DATA_TYPE_F32) {
|
||||
// TODO: support more types
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type src:%s dst:%s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@@ -378,132 +403,4 @@ bool is_unary_op_required_sync(npu_device_tensor_op prev_op,
|
|||
prev_op != NPU_OP_COUNT;
|
||||
}
|
||||
|
||||
struct op_capabilities {
|
||||
npu_device_tensor_op op;
|
||||
hexagon::op_is_supported_func_type is_supported;
|
||||
hexagon::op_required_sync_func_type requires_thread_barrier_func;
|
||||
hexagon::compute_func_type compute_funcs[NPU_DATA_TYPE_COUNT];
|
||||
};
|
||||
|
||||
constexpr const op_capabilities kOpCapabilities[] = {
|
||||
{
|
||||
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
|
||||
hexagon::is_mul_mat_required_sync,
|
||||
{
|
||||
hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_ADD, is_element_wise_op_supported,
|
||||
is_element_wise_op_required_sync, {
|
||||
element_wise_op<vec_op_f32_f32<vadd_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
element_wise_op<vec_op_f16_f16<vadd_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_SUB, is_element_wise_op_supported,
|
||||
is_element_wise_op_required_sync, {
|
||||
element_wise_op<vec_op_f32_f32<vsub_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
element_wise_op<vec_op_f16_f16<vsub_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_MUL, is_element_wise_op_supported,
|
||||
is_element_wise_op_required_sync, {
|
||||
element_wise_op<vec_op_f32_f32<vmul_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
element_wise_op<vec_op_f16_f16<vmul_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_RMS_NORM, is_unary_op_supported,
|
||||
is_unary_op_required_sync, {
|
||||
unary_op<rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
|
||||
hexagon::is_flash_attn_required_sync,
|
||||
{
|
||||
hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_ROPE, hexagon::is_rope_supported,
|
||||
hexagon::is_rope_required_sync,
|
||||
{
|
||||
hexagon::rope_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_GLU, hexagon::is_glu_op_supported,
|
||||
hexagon::is_glu_required_sync,
|
||||
{
|
||||
hexagon::glu_f32, // NPU_DATA_TYPE_F32
|
||||
hexagon::glu_f16, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
};
|
||||
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
|
||||
"kOpArray[NPU_OP_MUL_MAT] != mul_mat_f32");
|
||||
|
||||
static_assert(std::size(kOpCapabilities) == NPU_OP_COUNT);
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].op == NPU_OP_MUL_MAT, "kOpArray[NPU_OP_MUL_MAT].op != NPU_OP_MUL_MAT");
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL].op == NPU_OP_MUL, "kOpArray[NPU_OP_MUL].op != NPU_OP_MUL");
|
||||
static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
|
||||
"kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
|
||||
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
|
||||
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
|
||||
static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
|
||||
static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU");
|
||||
|
||||
hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
|
||||
if (op >= NPU_OP_COUNT) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return kOpCapabilities[op].compute_funcs[type];
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
compute_func_type get_compute_func(tensor * dst) {
|
||||
return get_compute_func_impl(dst->get_op(), dst->get_type());
|
||||
}
|
||||
|
||||
bool requires_thread_barrier(npu_device_tensor_op prev_op,
|
||||
const npu_device_ne_type & prev_ne,
|
||||
npu_device_tensor_op op,
|
||||
const npu_device_ne_type & ne) {
|
||||
if (op >= NPU_OP_COUNT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func;
|
||||
return requires_thread_barrier_func && requires_thread_barrier_func(prev_op, prev_ne, op, ne);
|
||||
}
|
||||
|
||||
bool support_op(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
if (!op_spec) {
|
||||
DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto op = op_spec->op;
|
||||
auto is_supported_func = kOpCapabilities[op].is_supported;
|
||||
if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (get_compute_func_impl(op, dst->type) == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
|
||||
get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@@ -58,10 +58,10 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const auto & k_type_traits = hexagon::get_type_traits(kKvDataType);
|
||||
const auto q_to_vec_dot = k_type_traits.from_float;
|
||||
const auto q_to_kv_type = k_type_traits.from_float;
|
||||
constexpr const auto kq_vec_dot = _IsKvF16 ? hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16> :
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>;
|
||||
if (!q_to_vec_dot) {
|
||||
if (!q_to_kv_type) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
|
||||
return;
|
||||
}
|
||||
|
|
@@ -134,7 +134,7 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
(iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
|
||||
nullptr;
|
||||
|
||||
q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);
|
||||
q_to_kv_type(reinterpret_cast<const float *>(q_data), Q_q, DK);
|
||||
|
||||
if (kHasMask) {
|
||||
hexagon::l2fetch_row(reinterpret_cast<const uint8_t *>(mp), mask->get_nb(1));
|
||||
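The rename from q_to_vec_dot to q_to_kv_type describes what the callback actually does: the f32 query row is converted once into the K/V cache's storage type, so the per-row dot product then runs on two same-typed rows (f16·f16 when the cache is f16). A hedged sketch of that pattern with illustrative names:

```cpp
#include <cstddef>

// Sketch only: convert the query row to the KV storage type once, then
// reuse the converted row for every K row. Assumes head size <= 512.
void score_kv_rows_sketch(const float * q, const __fp16 * k_rows, size_t n_kv, size_t dk,
                          float * scores,
                          void (*q_to_kv_type)(const float *, __fp16 *, size_t),
                          float (*vec_dot_f16_f16)(const __fp16 *, const __fp16 *, size_t)) {
    __fp16 q_converted[512];
    q_to_kv_type(q, q_converted, dk);                                   // one conversion per query row
    for (size_t i = 0; i < n_kv; ++i) {
        scores[i] = vec_dot_f16_f16(q_converted, k_rows + i * dk, dk);  // same-typed dot product
    }
}
```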
|
|
|
|||
|
|
@@ -48,8 +48,7 @@ inline void glu_vec_op_f32_f32(const float * src0,
|
|||
size_t count,
|
||||
hexagon::HVX_VectorPair_x4 coeff) {
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(src0, src1, dst, count,
|
||||
coeff);
|
||||
vec_trans_impl<hexagon::vec_swiglu_f32_f32, float, hexagon::HVX_VectorPair_x4>(src0, src1, dst, count, coeff);
|
||||
}
|
||||
|
||||
template <auto _GluRowFunc, auto _CoeffLoadFunc>
|
||||
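The SwiGLU row op now calls the general vec_trans_impl, which takes the transform as a template argument and forwards any trailing arguments (here the HVX_VectorPair_x4 coefficient tables) to it, replacing the dedicated vec_trans_with_param_impl overload. A minimal scalar sketch of that variadic forwarding (illustrative names only):

```cpp
#include <cstddef>

// Scalar stand-in for the HVX row transform loop: extra arguments are
// forwarded unchanged to the per-element transform.
template <auto RowTransform, typename T, typename... Extra>
void vec_trans_sketch(const T * src0, const T * src1, T * dst, size_t count, Extra... extra) {
    for (size_t i = 0; i < count; ++i) {
        dst[i] = RowTransform(src0[i], src1[i], extra...);
    }
}

// Example transform that consumes one extra coefficient argument.
inline float scaled_mul(float a, float b, float coeff) {
    return a * b * coeff;
}

// Usage: vec_trans_sketch<scaled_mul>(gate, up, out, n, 1.702f);
```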
|
|
|
|||
|
|
@@ -0,0 +1,178 @@
|
|||
#include "op_registry.hpp"
|
||||
|
||||
#include "op_eltwise.hpp"
|
||||
#include "op_flash_attn.hpp"
|
||||
#include "op_glu.hpp"
|
||||
#include "op_mul_mat.hpp"
|
||||
#include "op_rope.hpp"
|
||||
#include "op_rows.hpp"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
|
||||
namespace {
|
||||
|
||||
struct op_capabilities {
|
||||
npu_device_tensor_op op;
|
||||
hexagon::op_is_supported_func_type is_supported;
|
||||
hexagon::op_required_sync_func_type requires_thread_barrier_func;
|
||||
hexagon::compute_func_type compute_funcs[NPU_DATA_TYPE_COUNT];
|
||||
};
|
||||
|
||||
constexpr const op_capabilities kOpCapabilities[] = {
|
||||
{
|
||||
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
|
||||
hexagon::is_mul_mat_required_sync,
|
||||
{
|
||||
hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_ADD, hexagon::is_element_wise_op_supported,
|
||||
hexagon::is_element_wise_op_required_sync,
|
||||
{
|
||||
hexagon::element_wise_op<hexagon::vec_op_f32_f32<hexagon::vadd_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
hexagon::element_wise_op<hexagon::vec_op_f16_f16<hexagon::vadd_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_SUB, hexagon::is_element_wise_op_supported,
|
||||
hexagon::is_element_wise_op_required_sync,
|
||||
{
|
||||
hexagon::element_wise_op<hexagon::vec_op_f32_f32<hexagon::vsub_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
hexagon::element_wise_op<hexagon::vec_op_f16_f16<hexagon::vsub_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_MUL, hexagon::is_element_wise_op_supported,
|
||||
hexagon::is_element_wise_op_required_sync,
|
||||
{
|
||||
hexagon::element_wise_op<hexagon::vec_op_f32_f32<hexagon::vmul_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
hexagon::element_wise_op<hexagon::vec_op_f16_f16<hexagon::vmul_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_RMS_NORM, hexagon::is_unary_op_supported,
|
||||
hexagon::is_unary_op_required_sync,
|
||||
{
|
||||
hexagon::unary_op<hexagon::rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
|
||||
hexagon::is_flash_attn_required_sync,
|
||||
{
|
||||
hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_ROPE, hexagon::is_rope_supported,
|
||||
hexagon::is_rope_required_sync,
|
||||
{
|
||||
hexagon::rope_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_GLU, hexagon::is_glu_op_supported,
|
||||
hexagon::is_glu_required_sync,
|
||||
{
|
||||
hexagon::glu_f32, // NPU_DATA_TYPE_F32
|
||||
hexagon::glu_f16, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_GET_ROWS, hexagon::is_rows_supported,
|
||||
hexagon::is_rows_required_sync,
|
||||
{
|
||||
hexagon::get_rows_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
{
|
||||
NPU_OP_SET_ROWS, hexagon::is_rows_supported,
|
||||
hexagon::is_rows_required_sync,
|
||||
{
|
||||
hexagon::set_rows_generic, // NPU_DATA_TYPE_F32
|
||||
hexagon::set_rows_generic, // NPU_DATA_TYPE_F16
|
||||
nullptr, // NPU_DATA_TYPE_I32
|
||||
nullptr, // NPU_DATA_TYPE_I64
|
||||
hexagon::set_rows_generic, // NPU_DATA_TYPE_Q8_0
|
||||
hexagon::set_rows_generic, // NPU_DATA_TYPE_Q4_0
|
||||
nullptr, // TODO: figure out why failed on NPU_DATA_TYPE_Q4_K
|
||||
}, },
|
||||
{
|
||||
NPU_OP_CPY, hexagon::is_unary_op_supported,
|
||||
hexagon::is_unary_op_required_sync,
|
||||
{
|
||||
nullptr, // NPU_DATA_TYPE_F32
|
||||
hexagon::unary_op<hexagon::unary_vec_op_f16_f32<hexagon::vequals_f16_f32>>, // NPU_DATA_TYPE_F16
|
||||
}, },
|
||||
};
|
||||
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
|
||||
"kOpArray[NPU_OP_MUL_MAT] != mul_mat_f32");
|
||||
|
||||
static_assert(std::size(kOpCapabilities) == NPU_OP_COUNT);
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].op == NPU_OP_MUL_MAT, "kOpArray[NPU_OP_MUL_MAT].op != NPU_OP_MUL_MAT");
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL].op == NPU_OP_MUL, "kOpArray[NPU_OP_MUL].op != NPU_OP_MUL");
|
||||
static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
|
||||
"kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
|
||||
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
|
||||
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
|
||||
static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
|
||||
static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU");
|
||||
static_assert(kOpCapabilities[NPU_OP_GET_ROWS].op == NPU_OP_GET_ROWS,
|
||||
"kOpArray[NPU_OP_GET_ROWS].op != NPU_OP_GET_ROWS");
|
||||
static_assert(kOpCapabilities[NPU_OP_SET_ROWS].op == NPU_OP_SET_ROWS,
|
||||
"kOpArray[NPU_OP_SET_ROWS].op != NPU_OP_SET_ROWS");
|
||||
|
||||
hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
|
||||
if (op >= NPU_OP_COUNT) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return kOpCapabilities[op].compute_funcs[type];
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
compute_func_type get_compute_func(tensor * dst) {
|
||||
return get_compute_func_impl(dst->get_op(), dst->get_type());
|
||||
}
|
||||
|
||||
bool requires_thread_barrier(npu_device_tensor_op prev_op,
|
||||
const npu_device_ne_type & prev_ne,
|
||||
npu_device_tensor_op op,
|
||||
const npu_device_ne_type & ne) {
|
||||
if (op >= NPU_OP_COUNT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func;
|
||||
return requires_thread_barrier_func && requires_thread_barrier_func(prev_op, prev_ne, op, ne);
|
||||
}
|
||||
|
||||
bool support_op(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
if (!op_spec) {
|
||||
DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto op = op_spec->op;
|
||||
auto is_supported_func = kOpCapabilities[op].is_supported;
|
||||
if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (get_compute_func_impl(op, dst->type) == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
|
||||
get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@@ -0,0 +1,155 @@
|
|||
#include "op_rows.hpp"
|
||||
|
||||
#include "type_traits.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename idx_t> void set_rows_impl(hexagon::tensor * out, hexagon::compute_params * params) {
|
||||
auto * src0 = out->get_src(0);
|
||||
auto * src1 = out->get_src(1);
|
||||
|
||||
const auto total_rows = src0->get_ne(3) * src0->get_ne(2) * src0->get_ne(1);
|
||||
const auto start_end = params->get_work_slice(total_rows);
|
||||
if (start_end.first >= start_end.second) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("set_rows_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer(true); // TODO: avoid invalidation
|
||||
const size_t rows_per_cube = src0->get_ne(2) * src0->get_ne(1);
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
|
||||
auto from_float = hexagon::get_type_traits(out->get_type()).from_float;
|
||||
for (size_t ir = start_end.first; ir < size_t(start_end.second); ++ir) {
|
||||
const size_t i03 = ir / rows_per_cube;
|
||||
const size_t i02 = ir / src0->get_ne(1) - i03 * src0->get_ne(2);
|
||||
const size_t i01 = ir % src0->get_ne(1);
|
||||
const size_t i12 = i03 % src1->get_ne(2);
|
||||
const size_t i11 = i02 % src1->get_ne(1);
|
||||
const size_t i10 = i01;
|
||||
|
||||
const size_t i1 = *reinterpret_cast<const idx_t *>(src1_ptr + i10 * src1->get_nb(0) + i11 * src1->get_nb(1) +
|
||||
i12 * src1->get_nb(2));
|
||||
from_float(reinterpret_cast<const float *>(src0_ptr + i01 * src0->get_nb(1) + i02 * src0->get_nb(2) +
|
||||
i03 * src0->get_nb(3)),
|
||||
dst_ptr + i1 * out->get_nb(1) + i02 * out->get_nb(2) + i03 * out->get_nb(3),
|
||||
size_t(src0->get_ne(0)));
|
||||
}
|
||||
|
||||
out->release_write_buffer(); // mark the output tensor as modified
|
||||
}
|
||||
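One detail worth calling out in set_rows_impl above: the i11/i12 indices are taken modulo src1's ne(1)/ne(2), which is how the (possibly size-1) outer dims of the index tensor are broadcast across the source's batch dims. In isolation:

```cpp
#include <cstddef>
#include <cstdint>

// Broadcast rule used when looking up the row index: if the index tensor's
// dim is 1, the same index plane is reused for every batch element.
inline size_t broadcast_index(size_t i, int64_t ne) {
    return (size_t) (i % (size_t) ne);  // ne == 1  ->  always 0
}
// e.g. i11 = broadcast_index(i02, src1_ne1); i12 = broadcast_index(i03, src1_ne2);
```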
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool get_rows_f32(tensor * out, compute_params * params) {
|
||||
// TODO: implement get_rows
|
||||
return false;
|
||||
}
|
||||
|
||||
bool set_rows_generic(tensor * out, compute_params * params) {
|
||||
if (!out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto * src0 = out->get_src(0);
|
||||
auto * src1 = out->get_src(1);
|
||||
if (!src0 || !src1) {
|
||||
DEVICE_LOG_ERROR("set_rows_generic: missing src0 or src1\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (src1->get_type()) {
|
||||
case NPU_DATA_TYPE_I32:
|
||||
set_rows_impl<int32_t>(out, params);
|
||||
break;
|
||||
case NPU_DATA_TYPE_I64:
|
||||
set_rows_impl<int64_t>(out, params);
|
||||
break;
|
||||
default:
|
||||
DEVICE_LOG_ERROR("set_rows_generic: unsupported src1 type: %s\n", hexagon::get_type_name(src1->get_type()));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_rows_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_GET_ROWS && op != NPU_OP_SET_ROWS) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src_len < 2) {
|
||||
DEVICE_LOG_DEBUG("[%s]invalid src_len: %zu\n", hexagon::op_get_name(op), src_len);
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto & src0 = srcs[0];
|
||||
const auto & src1 = srcs[1];
|
||||
if (op == NPU_OP_GET_ROWS) {
|
||||
if (dst->ne[0] != src0.ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst.ne[0] and src0.ne[0] not match: %ld vs %ld\n", hexagon::op_get_name(op),
|
||||
(long) dst->ne[0], (long) src0.ne[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != src0.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst.type and src0.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type), hexagon::get_type_name(src0.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: remove this limitation
|
||||
return false;
|
||||
} else {
|
||||
// NPU_OP_SET_ROWS
|
||||
if (dst->ne[0] != src0.ne[0] || dst->ne[2] != src0.ne[2] || dst->ne[3] != src0.ne[3]) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst.ne[0], src0.ne[0] and src0.ne[2], src0.ne[3] not match: %ld vs %ld, %ld, %ld\n",
|
||||
hexagon::op_get_name(op), (long) dst->ne[0], (long) src0.ne[0], (long) src0.ne[2],
|
||||
(long) src0.ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0.type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type is not F32: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1.type != NPU_DATA_TYPE_I32 && src1.type != NPU_DATA_TYPE_I64) {
|
||||
DEVICE_LOG_DEBUG("[%s]src1.type is not I32 or I64: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src1.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != src0.type && !get_type_traits(dst->type).from_float) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst.from_float is null: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_rows_required_sync(npu_device_tensor_op prev_op,
|
||||
const npu_device_ne_type & prev_ne,
|
||||
npu_device_tensor_op op,
|
||||
const npu_device_ne_type & ne) {
|
||||
// TODO: implement is_rows_required_sync
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@@ -0,0 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include "op_types.hpp"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool get_rows_f32(tensor * out, compute_params * params);
|
||||
bool set_rows_generic(tensor * out, compute_params * params);
|
||||
|
||||
bool is_rows_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
bool is_rows_required_sync(npu_device_tensor_op prev_op,
|
||||
const npu_device_ne_type & prev_ne,
|
||||
npu_device_tensor_op op,
|
||||
const npu_device_ne_type & ne);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@@ -339,13 +339,13 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * d
|
|||
const int nb = count / qk;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
|
||||
|
||||
hexagon::dequant_output_type * dst_ptr = dst; // TODO: opt for aligned access
|
||||
|
||||
int i = 0;
|
||||
hexagon::dequant_output_type * dst_ptr = dst;
|
||||
int i = 0;
|
||||
for (; i + 5 < nb; i += 6) {
|
||||
auto qs = load_hexa_block_generic(src_ptr + i, qs_indices, scale_indices);
|
||||
auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
|
||||
auto res2 = dequantize_vec_q40_qf16_2blocks(qs.val[3], qs.val[4], table);
|
||||
auto qs = load_hexa_block_generic(src_ptr + i, qs_indices, scale_indices);
|
||||
auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
|
||||
HVX_Vector block45 = Q6_V_vror_VR(qs.val[0], kSizeOfQs * 4);
|
||||
auto res2 = dequantize_vec_q40_qf16_2blocks(block45, qs.val[3], table);
|
||||
if constexpr (_IsDstAligned) {
|
||||
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]);
|
||||
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]);
|
||||
|
|
@@ -372,7 +372,8 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * d
|
|||
}
|
||||
|
||||
for (; i + 1 < nb; i += 2) {
|
||||
auto res = load_dequant_vec_q40_qf16_2blocks(src_ptr + i, qs_indices, scale_indices, table);
|
||||
auto qs = load_dual_block_generic(src_ptr + i, qs_indices, scale_indices);
|
||||
auto res = dequantize_vec_q40_qf16_2blocks(qs.val[0], qs.val[1], table);
|
||||
if constexpr (_IsDstAligned) {
|
||||
*reinterpret_cast<HVX_Vector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(res);
|
||||
} else {
|
||||
|
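The unrolled 6-/2-/1-block HVX loops above are all doing plain Q4_0 dequantization, just on several blocks per vector. As a scalar reference for one block (standard ggml Q4_0 layout: an fp16 scale followed by 16 bytes of packed 4-bit codes):

```cpp
#include <cstdint>

// Scalar reference for one Q4_0 block of 32 values: value = d * (code - 8);
// low nibbles give the first 16 values, high nibbles the last 16.
struct block_q4_0_ref {
    __fp16  d;       // per-block scale
    uint8_t qs[16];  // 32 packed 4-bit codes
};

void dequant_q4_0_block_ref(const block_q4_0_ref & b, float * out /* 32 floats */) {
    const float d = (float) b.d;
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * (float) ((b.qs[j] & 0x0F) - 8);
        out[j + 16] = d * (float) ((b.qs[j] >> 4) - 8);
    }
}
```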
|
@@ -469,8 +470,14 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_output_type * dst, s
|
|||
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qv, 4);
|
||||
HVX_VectorPair qp = Q6_W_vshuff_VVR(q_hi, q_lo, kQuantSubBlockSize * 3);
|
||||
|
||||
dual_pair.p[0] = Q6_Wh_vlut16_VbVhR_nomatch(Q6_Vb_vshuff_Vb(Q6_V_lo_W(qp)), table, 0);
|
||||
dual_pair.p[1] = Q6_Wh_vlut16_VbVhR_nomatch(Q6_Vb_vshuff_Vb(Q6_V_hi_W(qp)), table, 0);
|
||||
q_lo = Q6_V_lo_W(qp);
|
||||
q_hi = Q6_V_hi_W(qp);
|
||||
|
||||
q_lo = Q6_Vb_vshuff_Vb(q_lo);
|
||||
q_hi = Q6_Vb_vshuff_Vb(q_hi);
|
||||
|
||||
dual_pair.p[0] = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
|
||||
dual_pair.p[1] = Q6_Wh_vlut16_VbVhR_nomatch(q_hi, table, 0);
|
||||
|
||||
const __fp16 d = reinterpret_cast<const __fp16 &>(src_ptr[i].d);
|
||||
const __fp16 min = reinterpret_cast<const __fp16 &>(src_ptr[i].dmin);
|
||||
|
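For Q4_K the vlut16 table expands the 4-bit codes to fp16 and the per-sub-block scale/min are applied afterwards with d and dmin from the super-block header; per element this is the usual y = d·sc·q − dmin·m form. A scalar sketch for one 32-value sub-block (the 6-bit scale/min packing is omitted here):

```cpp
#include <cstdint>

// Scalar sketch of one Q4_K sub-block: q holds the already-unpacked 4-bit
// codes (0..15), sc/m the sub-block scale and min, d/dmin the super-block
// fp16 scales.
inline void dequant_q4_k_subblock_ref(const uint8_t * q /* 32 codes */, float d, float dmin,
                                      int sc, int m, float * out) {
    const float scale = d * (float) sc;
    const float min   = dmin * (float) m;
    for (int j = 0; j < 32; ++j) {
        out[j] = scale * (float) q[j] - min;
    }
}
```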
|
@@ -533,13 +540,14 @@ void copy_row_f16(const void * src, hexagon::dequant_output_type * dst, size_t c
|
|||
hexagon::vec_cpy_f16(reinterpret_cast<const npu_device_fp16_t *>(src), dst, count);
|
||||
}
|
||||
|
||||
void copy_row_f32(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector) {
|
||||
template <typename _TSrc, typename _TDst, typename... _TExtArgs>
|
||||
void copy_row_f32(const _TSrc * src, _TDst * dst, size_t count, _TExtArgs...) {
|
||||
hexagon::vec_cpy_f32(reinterpret_cast<const float *>(src), reinterpret_cast<float *>(dst), count);
|
||||
}
|
||||
|
||||
constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
|
||||
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
|
||||
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32<void, hexagon::dequant_output_type, HVX_Vector>,
|
||||
copy_row_f32<float, void>, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f32_f32>,
|
||||
hexagon::type_erase_dot_func<hexagon::is_f32_f32_dot_product_aligned> },
|
||||
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16,
|
||||
|
|
@@ -547,6 +555,7 @@ constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
|
|||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f16_f16>,
|
||||
hexagon::type_erase_dot_func<hexagon::is_f16_f16_dot_product_aligned> },
|
||||
{ NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false },
|
||||
{ NPU_DATA_TYPE_I64, "I64", 1, sizeof(int64_t), false },
|
||||
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
|
||||
quantize_row_q8_0 },
|
||||
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
|
||||
|
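The new I32/I64 rows only fill in the name/size fields and leave the conversion and dot-product hooks defaulted, since integer tensors are only consumed as row indices by GET_ROWS/SET_ROWS. Inferred from the aggregate initializers above, a traits entry looks roughly like the following (member names and exact signatures are assumptions; the real to_float/from_float also take an HVX lookup-table argument):

```cpp
#include <cstddef>

// Assumed shape of a device_type_traits entry; illustration only.
struct device_type_traits_sketch {
    int          type;              // NPU_DATA_TYPE_* tag, must match its array index
    const char * name;              // "F32", "I64", "Q4_0", ...
    int          block_size;        // values per quant block (1 for plain types)
    size_t       type_size;         // bytes per element or per block
    bool         is_quantized;
    void       (*to_float)(const void *, void *, size_t)    = nullptr;  // dequantize / copy out
    void       (*from_float)(const float *, void *, size_t) = nullptr;  // quantize / copy in
    // dot-product and alignment hooks omitted in this sketch
};
```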
|
@@ -563,6 +572,8 @@ static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F16].type == NPU_DATA_TYPE_F16,
|
|||
"kDeviceTypeTraits F16 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_I32].type == NPU_DATA_TYPE_I32,
|
||||
"kDeviceTypeTraits I32 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_I64].type == NPU_DATA_TYPE_I64,
|
||||
"kDeviceTypeTraits I64 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q8_0].type == NPU_DATA_TYPE_Q8_0,
|
||||
"kDeviceTypeTraits Q8_0 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_0].type == NPU_DATA_TYPE_Q4_0,
|
||||
|
|
|
|||
|
|
@@ -72,6 +72,12 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) {
|
|||
return "ROPE";
|
||||
case NPU_OP_GLU:
|
||||
return "GLU";
|
||||
case NPU_OP_GET_ROWS:
|
||||
return "GET_ROWS";
|
||||
case NPU_OP_SET_ROWS:
|
||||
return "SET_ROWS";
|
||||
case NPU_OP_CPY:
|
||||
return "CPY";
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
|
|
@@ -283,68 +289,44 @@ template <size_t _buffer_count> class npu_scoped_timer {
|
|||
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration,
|
||||
_sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count,
|
||||
(unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[2].proc_count,
|
||||
(unsigned long long) sub_proc2_duration,
|
||||
_sub_proc_data[3].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[3].proc_count,
|
||||
(unsigned long long) sub_proc3_duration);
|
||||
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
|
||||
(unsigned long long) sub_proc2_duration, _sub_proc_data[3].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[3].proc_count, (unsigned long long) sub_proc3_duration);
|
||||
break;
|
||||
case 3:
|
||||
DEVICE_LOG_WARN(
|
||||
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration,
|
||||
_sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count,
|
||||
(unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[2].proc_count,
|
||||
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
|
||||
(unsigned long long) sub_proc2_duration);
|
||||
break;
|
||||
case 2:
|
||||
DEVICE_LOG_WARN(
|
||||
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration,
|
||||
_sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count,
|
||||
(unsigned long long) sub_proc1_duration);
|
||||
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration);
|
||||
break;
|
||||
case 1:
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", _log_prefix,
|
||||
(unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration);
|
||||
break;
|
||||
default:
|
||||
case 0:
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n",
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration);
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix,
|
||||
(unsigned long long) total_pcycles, (unsigned long long) duration);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -372,8 +354,8 @@ template <size_t _buffer_count, size_t _sub_idx> class npu_sub_process_scoped_ti
|
|||
}
|
||||
|
||||
~npu_sub_process_scoped_timer() {
|
||||
_timer.add_sub_proc_cycles(
|
||||
_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles, HAP_perf_get_pcycles() - _begin_pcycles);
|
||||
_timer.add_sub_proc_cycles(_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles,
|
||||
HAP_perf_get_pcycles() - _begin_pcycles);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@@ -20,6 +20,7 @@ using HVX_Vector_x3 = HEXAGON_pack<HVX_Vector, 3>;
using HVX_Vector_x4 = HEXAGON_pack<HVX_Vector, 4>;
using HVX_Vector_x5 = HEXAGON_pack<HVX_Vector, 5>;
using HVX_VectorPair_x2 = HEXAGON_pack<HVX_VectorPair, 2>;
using HVX_VectorPair_x3 = HEXAGON_pack<HVX_VectorPair, 3>;
using HVX_VectorPair_x4 = HEXAGON_pack<HVX_VectorPair, 4>;
using HVX_VectorPred_x3 = HEXAGON_pack<HVX_VectorPred, 3>;

@@ -363,9 +364,10 @@ inline HVX_Vector vec_dot_product_vqf32_q40_f32(const npu_device_block_q4_0 * sr
using namespace hexagon::vec::math;
using namespace hexagon::vec::quant;

alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask<npu_device_block_q4_0>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices =
make_qs_load_mask<npu_device_block_q4_0, q4_qs_shuff_idx>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices =
make_scale_load_mask<npu_device_block_q4_0>();
Q6_Vh_vshuff_Vh(make_scale_load_mask<npu_device_block_q4_0>());

return vec_dot_product_quant_impl<npu_device_block_q4_0, float, HVX_Vector, load_dequant_vec_q40_qf32_4blocks,
load_dequant_vec_q40_qf32_2blocks, load_dequant_vec_q40_qf32_1block,

@@ -380,6 +380,31 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
return _ReduceFunc(_AddFunc(sum0, sum1));
}

inline HVX_Vector_x2 vec_dot_accum_pair(HVX_VectorPair s0,
HVX_Vector curr10,
HVX_Vector curr11,
HVX_Vector prev1,
HVX_Vector_x2 sums,
size_t offset,
HVX_Vector zero) {
HVX_Vector l0 = Q6_V_lo_W(s0);
HVX_Vector l1 = Q6_V_valign_VVR(curr10, prev1, offset);

HVX_Vector h0 = Q6_V_hi_W(s0);
HVX_Vector h1 = Q6_V_valign_VVR(curr11, curr10, offset);

l1 = Q6_Vqf32_vadd_VsfVsf(zero, l1);
h1 = Q6_Vqf32_vadd_VsfVsf(zero, h1);

HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l0, l1);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(h0, h1);

HVX_Vector_x2 result;
result.val[0] = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sums.val[0]);
result.val[1] = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sums.val[1]);
return result;
}
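
// Reading aid, assuming the qf32 dot-product loop that follows: vec_dot_accum_pair folds the
// valign of two unaligned src1 vectors, their sf->qf32 conversion, the multiply against the
// dequantized pair in `s0`, and the accumulation into `sums` into one call, mirroring the call
// sites further down in this diff:
//
//   HVX_Vector_x2 sums = { kZeroV, kZeroV };
//   sums = vec_dot_accum_pair(s0, curr10, curr11, prev1, sums, (size_t) src1, kZeroV);
//   prev1 = curr11;  // carry the last fetched src1 vector into the next valign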
|
||||
|
||||
template <typename _TQuantElem0,
|
||||
typename _TElem1,
|
||||
typename _TRet,
|
||||
|
|
@@ -422,8 +447,7 @@ inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0,
HVX_Vector sum = kZeroV;

if (src1_vec_ptr_end - src1_vec_ptr > 1) {
HVX_Vector sum0 = kZeroV;
HVX_Vector sum1 = kZeroV;
HVX_Vector_x2 sums = { kZeroV, kZeroV };

while (src1_vec_ptr_end - src1_vec_ptr > 3) {
HVX_VectorPair_x2 s01 = _DequantQuadFunc(src0_ptr, qs_indices, scale_indices, table);

@@ -432,38 +456,10 @@ inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0,
HVX_Vector curr110 = src1_vec_ptr[2];
HVX_Vector curr111 = src1_vec_ptr[3];

HVX_Vector l00 = Q6_V_lo_W(s01.val[0]);
HVX_Vector l10 = Q6_V_valign_VVR(curr100, prev1, (size_t) src1);

HVX_Vector l01 = Q6_V_lo_W(s01.val[1]);
HVX_Vector l11 = Q6_V_valign_VVR(curr110, curr101, (size_t) src1);

HVX_Vector h00 = Q6_V_hi_W(s01.val[0]);
HVX_Vector h10 = Q6_V_valign_VVR(curr101, curr100, (size_t) src1);

HVX_Vector h01 = Q6_V_hi_W(s01.val[1]);
HVX_Vector h11 = Q6_V_valign_VVR(curr111, curr110, (size_t) src1);

l10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l10);
l11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l11);

HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l00, l10);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(l01, l11);

h10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h10);
h11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h11);

HVX_Vector mpy2 = Q6_Vqf32_vmpy_Vqf32Vqf32(h00, h10);
HVX_Vector mpy3 = Q6_Vqf32_vmpy_Vqf32Vqf32(h01, h11);

sums = vec_dot_accum_pair(s01.val[0], curr100, curr101, prev1, sums, (size_t) src1, kZeroV);
sums = vec_dot_accum_pair(s01.val[1], curr110, curr111, curr101, sums, (size_t) src1, kZeroV);
prev1 = curr111;

sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1);

sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy2, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy3, sum1);

src0_ptr += 4;
src1_vec_ptr += 4;
}

@@ -473,28 +469,14 @@ inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0,
HVX_Vector curr10 = src1_vec_ptr[0];
HVX_Vector curr11 = src1_vec_ptr[1];

HVX_Vector l0 = Q6_V_lo_W(s0);
HVX_Vector l1 = Q6_V_valign_VVR(curr10, prev1, (size_t) src1);

HVX_Vector h0 = Q6_V_hi_W(s0);
HVX_Vector h1 = Q6_V_valign_VVR(curr11, curr10, (size_t) src1);

l1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l1);
h1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h1);

HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l0, l1);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(h0, h1);

sums = vec_dot_accum_pair(s0, curr10, curr11, prev1, sums, (size_t) src1, kZeroV);
prev1 = curr11;

sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1);

src0_ptr += 2;
src1_vec_ptr += 2;
}

sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum0, sum1);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sums.val[0], sums.val[1]);
}

if (src1_vec_ptr_end - src1_vec_ptr > 0) {

@@ -614,8 +596,16 @@ template <typename _TData> inline void vec_zero_impl(_TData * src, size_t count)
}
}

template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector), typename _TyData>
inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData * dst, size_t count) {
template <auto * _OpBinaryTransform, typename _TyData, typename... _TyParams>
inline void vec_trans_impl(const _TyData * src0,
const _TyData * src1,
_TyData * dst,
size_t count,
_TyParams... params) {
static_assert(std::is_same_v<decltype(_OpBinaryTransform), HVX_Vector (*)(HVX_Vector, HVX_Vector, _TyParams...)>,
"Function type mismatch: _OpBinaryTransform must be of type HVX_Vector (*)(HVX_Vector, HVX_Vector, "
"_TyParams...)");

constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);

HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);

@@ -632,11 +622,11 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *

HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1);
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1, params...);

HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1);
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1, params...);

prev0 = Q6_V_hi_W(curr0);
prev1 = Q6_V_hi_W(curr1);

@@ -653,7 +643,7 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

dst_vec_ptr[0] = _OpBinaryTransform(s0, s1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, params...);

prev0 = curr0;
prev1 = curr1;

@@ -675,7 +665,7 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, params...);

src0_vec_ptr += should_fetch_src0 ? 1 : 0;
src1_vec_ptr += should_fetch_src1 ? 1 : 0;

@@ -697,98 +687,98 @@ inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData *
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1));
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1, params...));
}
}
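
// Illustrative sketch (not from this patch): with `auto * _OpBinaryTransform` plus the
// `_TyParams...` pack, an op that needs an extra per-call argument can be routed through the
// same helper. The names below are hypothetical; the intrinsics are the ones already used above.
inline HVX_Vector vadd_scaled_f32(HVX_Vector a, HVX_Vector b, HVX_Vector scale_qf32) {
    HVX_Vector sum = Q6_Vqf32_vadd_VsfVsf(a, b);  // a + b, converted to qf32
    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_Vqf32Vqf32(sum, scale_qf32));  // back to IEEE f32
}

inline void scaled_add_rows_f32(const float * src0, const float * src1, float * dst, size_t count,
                                HVX_Vector scale_qf32) {
    // the trailing argument is forwarded into every _OpBinaryTransform(l, r, params...) call
    vec_trans_impl<vadd_scaled_f32, float>(src0, src1, dst, count, scale_qf32);
}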

template <typename _TyData, typename _TyParam, HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector, _TyParam)>
inline void vec_trans_with_param_impl(const _TyData * src0,
const _TyData * src1,
_TyData * dst,
size_t count,
_TyParam param) {
template <auto * _OpUnaryTransform, typename _TyData, typename _TyDataRet, typename... _TyParams>
inline void vec_trans_with_half_ret_impl(const _TyData * src0, _TyDataRet * dst, size_t count, _TyParams... params) {
static_assert(std::is_same_v<decltype(_OpUnaryTransform), HVX_Vector (*)(HVX_VectorPair, _TyParams...)>,
"Function type mismatch: _OpUnaryTransform must be of type HVX_Vector (*)(HVX_VectorPair, "
"_TyParams...)");

static_assert(sizeof(_TyData) / sizeof(_TyDataRet) == 2,
"Element size mismatch: _TyData must be twice the size of _TyDataRet");

constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
const HVX_Vector kZero = Q6_V_vzero();

HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
HVX_Vector prev0 = *src0_vec_ptr++;
HVX_Vector prev1 = *src1_vec_ptr++;

{
while (src0_vec_ptr_end - src0_vec_ptr > 1) {
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];

HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1, param);
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);

HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1, param);
dst_vec_ptr[0] = _OpUnaryTransform(Q6_W_vcombine_VV(h0, l0), params...);

prev0 = Q6_V_hi_W(curr0);
prev1 = Q6_V_hi_W(curr1);
src0_vec_ptr += 2;
src1_vec_ptr += 2;
dst_vec_ptr += 2;
dst_vec_ptr++;
}
}

HVX_Vector result;
uint32_t processed_bytes = 0;
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
HVX_Vector curr0 = *src0_vec_ptr++;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);

prev0 = curr0;
prev1 = curr1;
dst_vec_ptr++;
prev0 = curr0;
result = _OpUnaryTransform(Q6_W_vcombine_VV(kZero, s0), params...);
processed_bytes = kElementsPerVector * sizeof(_TyDataRet);
}

const size_t leftover = count % kElementsPerVector;
static const HVX_VectorPred mask = Q6_Q_vsetq_R(hexagon::kBytesPerVector / 2);

const size_t src_leftover = count % kElementsPerVector;
if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
bool should_fetch_src0 = src_leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);

HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);
if (processed_bytes) {
s0 = _OpUnaryTransform(Q6_W_vcombine_VV(s0, kZero), params...);
dst_vec_ptr[0] = Q6_V_vmux_QVV(mask, result, s0); // only update the lower half of the result vector
dst_vec_ptr++;
} else {
result = _OpUnaryTransform(Q6_W_vcombine_VV(kZero, s0), params...);
}

src0_vec_ptr += should_fetch_src0 ? 1 : 0;
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
prev0 = curr0;
prev1 = curr1;
dst_vec_ptr++;
processed_bytes += kElementsPerVector * sizeof(_TyDataRet);
}

if (leftover > 0) {
if (src_leftover > 0) {
// handle the leftover elements
const size_t leftover_bytes = leftover * sizeof(_TyData);
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
const size_t src_leftover_bytes = src_leftover * sizeof(_TyData);
HVX_Vector curr0 = (src_leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
*src0_vec_ptr :
prev0;
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
*src1_vec_ptr :
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
if (processed_bytes % hexagon::kBytesPerVector) {
curr0 = _OpUnaryTransform(Q6_W_vcombine_VV(curr0, kZero), params...);
curr0 = Q6_V_vmux_QVV(mask, result, curr0);
} else {
curr0 = _OpUnaryTransform(Q6_W_vcombine_VV(kZero, curr0), params...);
}

q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1, param));
processed_bytes += src_leftover * sizeof(_TyDataRet);
q6op_vstu_variable_ARV(dst_vec_ptr, processed_bytes % hexagon::kBytesPerVector, curr0);
} else if (processed_bytes % hexagon::kBytesPerVector) {
// TODO: This conditional write-back is suboptimal because it may result in an extra memory write.
q6op_vstu_variable_ARV(dst_vec_ptr, processed_bytes % hexagon::kBytesPerVector, result);
}
}
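
// A minimal instantiation sketch (not from this patch) for vec_trans_with_half_ret_impl: each
// iteration consumes a full HVX_VectorPair of _TyData and produces one vector of _TyDataRet whose
// elements are half the size, which is why the helper tracks processed_bytes and muxes the lower
// half of `result` on write-back. The narrowing op below is left hypothetical; it only has to
// match the asserted signature HVX_Vector (*)(HVX_VectorPair).
HVX_Vector cvt_pair_f32_to_f16(HVX_VectorPair in);  // e.g. two f32 vectors -> one f16 vector

inline void copy_row_f32_to_f16(const float * src, npu_device_fp16_t * dst, size_t count) {
    vec_trans_with_half_ret_impl<cvt_pair_f32_to_f16, float, npu_device_fp16_t>(src, dst, count);
}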
|
||||
|
||||
|
|
|
|||
|
|
@@ -47,7 +47,26 @@ template <typename _TBlock> inline HVX_Vector make_scale_load_mask() {
return ret.v;
}

template <typename _TBlock> inline HVX_Vector make_qs_load_mask() {
inline size_t default_qs_shuff_idx(size_t idx) {
return idx;
}

inline size_t q4_qs_shuff_idx(size_t idx) {
// TODO: The current mask (kIndexShuffle) is hardcoded for the Q4 quantization block layout, where data is arranged in a specific interleaved pattern.
// A more general solution would need to programmatically generate the shuffle mask based on the quantization block's structure.
constexpr const size_t kIndexShuffle[] = {
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 2, 6, 10, 14, 18, 22,
26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45,
49, 53, 57, 61, 3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127,
};
return kIndexShuffle[idx];
}
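
// Illustrative only (not from this patch): the first 64 entries of kIndexShuffle place the n-th
// qs byte at (n % 16) * 4 + {0, 2, 1, 3}[n / 16]; every remaining entry collapses onto byte 127.
// A hypothetical generator along the lines the TODO above suggests could look like:
constexpr size_t q4_qs_shuff_idx_generated(size_t idx) {
    constexpr size_t kGroupOrder[] = { 0, 2, 1, 3 };
    // first 64 entries: destination = lane * 4 + group; the unused tail is parked at byte 127
    return idx < 64 ? (idx % 16) * 4 + kGroupOrder[idx / 16] : 127;
}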

template <typename _TBlock, size_t (*_FuncGetShuffIdx)(size_t) = default_qs_shuff_idx>
inline HVX_Vector make_qs_load_mask() {
static_assert(sizeof(_TBlock) < hexagon::kBytesPerVector, "wrong block size/padding");

const size_t qs_start_offset = offsetof(_TBlock, qs);

@@ -58,23 +77,15 @@ template <typename _TBlock> inline HVX_Vector make_qs_load_mask() {
for (size_t i = 0; i < hexagon::kBytesPerVector; ++i) {
auto offset = i % sizeof(_TBlock);
if (offset >= qs_start_offset && offset < qs_end_offset) {
ret.u8[ret_idx++] = (i & 1) ? (i / 2 + 64) : (i / 2);
size_t idx = _FuncGetShuffIdx(ret_idx);
ret.u8[idx] = ((i & 1) ? (i / 2 + 64) : (i / 2));
ret_idx++;
}
}

return ret.v;
}

template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock * srcs, HVX_VectorPred mask) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;

HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs);
HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale);
return Q6_V_vmux_QVV(mask, blocks, block1);
}

template <typename _TBlock>
inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs,
const HVX_Vector qs_indices,

@@ -84,10 +95,10 @@ inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs,
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs);

HVX_Vector block01 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0);
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 2);

HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2);

block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 2);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2);

if constexpr (sizeof(_TBlock) * 4 > hexagon::kBytesPerVector) {
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 1);

@@ -139,15 +150,14 @@ inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock * srcs,
}

template <typename _TBlock>
inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,
inline hexagon::HVX_Vector_x4 load_hexa_block_generic(const _TBlock * srcs,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);

const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs);

hexagon::HVX_Vector_x5 result;
hexagon::HVX_Vector_x4 result;
{
HVX_Vector block012345 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0);
block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 1);

@@ -155,7 +165,6 @@ inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,
block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 3);

result.val[0] = block012345;
result.val[3] = Q6_V_vror_VR(block012345, kSizeOfQs * 4); // block45
}

{

@@ -173,7 +182,7 @@ inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,

result.val[1] = scale01;
result.val[2] = scale23;
result.val[4] = scale45;
result.val[3] = scale45;
}

return result;

@@ -198,15 +207,12 @@ inline HVX_VectorPair dequantize_vec_q40_qf32_2blocks(HVX_Vector qs, HVX_Vector

HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * 4);

q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);

q_lo = Q6_V_lo_W(qp0);
scale01 = Q6_Vh_vshuff_Vh(scale01);
q_lo = Q6_Vh_vshuff_Vh(q_lo); // TODO: avoid vshuff here
q_lo = Q6_V_lo_W(qp0);

return Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
}

@@ -247,27 +253,49 @@ inline HVX_VectorPair_x2 dequantize_vec_q40_qf32_4blocks(HVX_Vector qs,
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);

HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * 4);

q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);

q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);

q_lo = Q6_Vh_vshuff_Vh(q_lo);
scale01 = Q6_Vh_vshuff_Vh(scale01);

q_hi = Q6_Vh_vshuff_Vh(q_hi);
scale23 = Q6_Vh_vshuff_Vh(scale23); // TODO: avoid vshuff here

hexagon::HVX_VectorPair_x2 result;
result.val[0] = Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
result.val[1] = Q6_Wqf32_vmpy_VhfVhf(q_hi, scale23);
return result;
}

inline HVX_VectorPair_x3 dequantize_vec_q40_qf32_6blocks(HVX_Vector qs,
HVX_Vector scale01,
HVX_Vector scale23,
HVX_Vector scale45,
HVX_Vector table) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);

HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);

HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * 4);

q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);

qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
HVX_VectorPair qp1 = Q6_Wh_vlut16_VbVhR_nomatch(q_hi, table, 0);

q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);
HVX_Vector q2 = Q6_V_lo_W(qp1);

hexagon::HVX_VectorPair_x3 result;
result.val[0] = Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
result.val[1] = Q6_Wqf32_vmpy_VhfVhf(q_hi, scale23);
result.val[2] = Q6_Wqf32_vmpy_VhfVhf(q2, scale45);
return result;
}
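
// Reading aid (assuming the usual q4_0 layout of a 2-byte fp16 scale plus 16 packed nibbles per
// 32-weight block): six blocks are 108 bytes, so they fit one 128-byte HVX load, and their 192
// dequantized values fill exactly the three qf32 vector pairs returned in val[0..2].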

inline HVX_Vector load_dequant_vec_q40_qf32_1block(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,

@@ -277,14 +305,6 @@ inline HVX_Vector load_dequant_vec_q40_qf32_1block(const npu_device_block_q4_0 *
return Q6_V_lo_W(dequantize_vec_q40_qf32_2blocks(qs.val[0], qs.val[1], table));
}

inline HVX_Vector load_dequant_vec_q40_qf16_2blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
auto qs = load_dual_block_generic(src, qs_indices, scale_indices);
return dequantize_vec_q40_qf16_2blocks(qs.val[0], qs.val[1], table);
}

inline HVX_VectorPair load_dequant_vec_q40_qf32_2blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,

@@ -301,4 +321,12 @@ inline HVX_VectorPair_x2 load_dequant_vec_q40_qf32_4blocks(const npu_device_bloc
return dequantize_vec_q40_qf32_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
}

inline HVX_VectorPair_x3 load_dequant_vec_q40_qf32_6blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
auto qs = load_hexa_block_generic(src, qs_indices, scale_indices);
return dequantize_vec_q40_qf32_6blocks(qs.val[0], qs.val[1], qs.val[2], qs.val[3], table);
}

} // namespace hexagon::vec::quant

@@ -67,6 +67,8 @@ void backend_buffer_set_tensor(ggml_backend_buffer_t buffer,
size_t size) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_set_tensor.size.%zu",
(void *) get_buffer_object(buffer), size);

// TODO: use DMA instead of memcpy?
memcpy((char *) tensor->data + offset, data, size);
}

@@ -76,12 +78,15 @@ void backend_buffer_get_tensor(ggml_backend_buffer_t buffer,
size_t offset,
size_t size) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_get_tensor", (void *) get_buffer_object(buffer));

// TODO: use DMA instead of memcpy?
memcpy(data, (const char *) tensor->data + offset, size);
}

bool backend_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_cpy_tensor", (void *) get_buffer_object(buffer));
if (ggml_backend_buffer_is_host(src->buffer)) {
// TODO: use DMA instead of memcpy?
memcpy(dst->data, src->data, ggml_nbytes(src));
return true;
}

@@ -37,6 +37,12 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) {
return NPU_OP_ROPE;
case GGML_OP_GLU:
return NPU_OP_GLU;
case GGML_OP_GET_ROWS:
return NPU_OP_GET_ROWS;
case GGML_OP_SET_ROWS:
return NPU_OP_SET_ROWS;
case GGML_OP_CPY:
return NPU_OP_CPY;
default:
return NPU_OP_COUNT;
}

@@ -60,6 +66,12 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) {
return ggml_op_name(GGML_OP_ROPE);
case NPU_OP_GLU:
return ggml_op_name(GGML_OP_GLU);
case NPU_OP_GET_ROWS:
return ggml_op_name(GGML_OP_GET_ROWS);
case NPU_OP_SET_ROWS:
return ggml_op_name(GGML_OP_SET_ROWS);
case NPU_OP_CPY:
return ggml_op_name(GGML_OP_CPY);
default:
return "UNKNOWN";
}

@@ -73,6 +85,8 @@ enum npu_device_tensor_data_type type_to_npu_type(ggml_type type) {
return NPU_DATA_TYPE_F16;
case GGML_TYPE_I32:
return NPU_DATA_TYPE_I32;
case GGML_TYPE_I64:
return NPU_DATA_TYPE_I64;
case GGML_TYPE_Q4_K:
return NPU_DATA_TYPE_Q4_K;
case GGML_TYPE_Q4_0:

@@ -178,30 +192,15 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) {
switch (dims) {
default:
case 4:
snprintf(out,
max_len,
"%s[%ldx%ldx%ldx%ld]",
ggml_type_name(tensor->type),
(long) tensor->ne[0],
(long) tensor->ne[1],
(long) tensor->ne[2],
(long) tensor->ne[3]);
snprintf(out, max_len, "%s[%ldx%ldx%ldx%ld]", ggml_type_name(tensor->type), (long) tensor->ne[0],
(long) tensor->ne[1], (long) tensor->ne[2], (long) tensor->ne[3]);
break;
case 3:
snprintf(out,
max_len,
"%s[%ldx%ldx%ld]",
ggml_type_name(tensor->type),
(long) tensor->ne[0],
(long) tensor->ne[1],
(long) tensor->ne[2]);
snprintf(out, max_len, "%s[%ldx%ldx%ld]", ggml_type_name(tensor->type), (long) tensor->ne[0],
(long) tensor->ne[1], (long) tensor->ne[2]);
break;
case 2:
snprintf(out,
max_len,
"%s[%ldx%ld]",
ggml_type_name(tensor->type),
(long) tensor->ne[0],
snprintf(out, max_len, "%s[%ldx%ld]", ggml_type_name(tensor->type), (long) tensor->ne[0],
(long) tensor->ne[1]);
break;
case 1:

@@ -233,14 +232,8 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) {
print_tensor(dst->src[2], src2_desc, sizeof(src2_desc));
char src3_desc[256];
print_tensor(dst->src[3], src3_desc, sizeof(src3_desc));
snprintf(out,
max_len,
"dst: %s, src0: %s, src1: %s, src2: %s, src3: %s",
dst_desc,
src0_desc,
src1_desc,
src2_desc,
src3_desc);
snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s, src3: %s", dst_desc, src0_desc,
src1_desc, src2_desc, src3_desc);
return;
}
case 3:

@@ -251,8 +244,8 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) {
print_tensor(dst->src[1], src1_desc, sizeof(src1_desc));
char src2_desc[256];
print_tensor(dst->src[2], src2_desc, sizeof(src2_desc));
snprintf(
out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s", dst_desc, src0_desc, src1_desc, src2_desc);
snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s", dst_desc, src0_desc, src1_desc,
src2_desc);
return;
}
case 2:

@@ -54,6 +54,9 @@ interface npu_device : remote_handle64{
NPU_OP_FLASH_ATTN,
NPU_OP_ROPE,
NPU_OP_GLU,
NPU_OP_GET_ROWS,
NPU_OP_SET_ROWS,
NPU_OP_CPY,
NPU_OP_COUNT
};

@@ -70,6 +73,7 @@ interface npu_device : remote_handle64{
NPU_DATA_TYPE_F32,
NPU_DATA_TYPE_F16,
NPU_DATA_TYPE_I32,
NPU_DATA_TYPE_I64,
NPU_DATA_TYPE_Q8_0,
NPU_DATA_TYPE_Q4_0,
NPU_DATA_TYPE_Q4_K,