Merge remote-tracking branch 'upstream/master' into backend-sampling
This commit is contained in:
commit
7816f0bb56
|
|
@ -1,9 +1,7 @@
|
|||
ARG UBUNTU_VERSION=25.10
|
||||
ARG UBUNTU_VERSION=26.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
||||
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||
|
||||
|
|
|
|||
|
|
@ -351,16 +351,10 @@ jobs:
|
|||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
|
|
@ -374,13 +368,6 @@ jobs:
|
|||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Copy Libcurl
|
||||
id: prepare_libcurl
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_tests
|
||||
if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }}
|
||||
|
|
|
|||
|
|
@ -4183,6 +4183,21 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
|||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("RND1")
|
||||
class RND1Model(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.RND1
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# RND1 specific parameters
|
||||
# RND1 uses bidirectional attention
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
|
||||
if (mask_token_id := self.hparams.get("mask_token_id")) is not None:
|
||||
self.gguf_writer.add_mask_token_id(mask_token_id)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
|
||||
class Qwen3VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -6,8 +6,54 @@ More Info:
|
|||
- https://github.com/ggml-org/llama.cpp/pull/14644
|
||||
- https://github.com/ggml-org/llama.cpp/pull/14771
|
||||
|
||||
## Parameters
|
||||
The diffusion CLI supports various parameters to control the generation process:
|
||||
|
||||
Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
|
||||
### Core Diffusion Parameters
|
||||
- `--diffusion-steps`: Number of diffusion steps (default: 256)
|
||||
- `--diffusion-algorithm`: Algorithm for token selection
|
||||
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: MARGIN_BASED - Margin-based selection
|
||||
- `3`: RANDOM - Random selection
|
||||
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- More documentation here https://github.com/DreamLM/Dream
|
||||
- `--diffusion-visual`: Enable live visualization during generation
|
||||
|
||||
Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
|
||||
### Scheduling Parameters
|
||||
Choose one of the following scheduling methods:
|
||||
|
||||
**Timestep-based scheduling:**
|
||||
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
|
||||
|
||||
**Block-based scheduling:**
|
||||
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
|
||||
|
||||
### Sampling Parameters
|
||||
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
|
||||
- `--top-k`: Top-k filtering for sampling
|
||||
- `--top-p`: Top-p (nucleus) filtering for sampling
|
||||
- `--seed`: Random seed for reproducibility
|
||||
|
||||
### Model Parameters
|
||||
- `-m`: Path to the GGUF model file
|
||||
- `-p`: Input prompt text
|
||||
- `-ub`: Maximum sequence length (ubatch size)
|
||||
- `-c`: Context size
|
||||
- `-b`: Batch size
|
||||
|
||||
### Examples
|
||||
#### Dream architechture:
|
||||
```
|
||||
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### LLaDA architechture:
|
||||
```
|
||||
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### RND1 architecture:
|
||||
```
|
||||
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
|
||||
```
|
||||
|
|
|
|||
|
|
@ -2303,9 +2303,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
|||
// calculate rope cache for fist layer in current device.
|
||||
cann_ctx->rope_cache.cached = false;
|
||||
|
||||
bool cann_graph_update_required = false;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
bool cann_graph_update_required = false;
|
||||
|
||||
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
|
|
@ -2336,7 +2336,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
|||
}
|
||||
#else
|
||||
bool use_cann_graph = false;
|
||||
bool cann_graph_update_required = false;
|
||||
#endif // USE_ACL_GRAPH
|
||||
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
|
||||
|
||||
|
|
|
|||
|
|
@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
#define AMD_MFMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(RDNA4)
|
||||
#define AMD_WMMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
|
||||
|
||||
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#define VOLTA_MMA_AVAILABLE
|
||||
|
|
@ -283,6 +287,10 @@ static bool amd_mfma_available(const int cc) {
|
|||
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
}
|
||||
|
||||
static bool amd_wmma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_RDNA4(cc);
|
||||
}
|
||||
|
||||
static bool volta_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
|
|||
return __float2bfloat16(float(x));
|
||||
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
||||
return __bfloat162float(x);
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
|
||||
return __float22half2_rn(x);
|
||||
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
|
||||
// bypass compile error on cuda 12.0.1
|
||||
#ifdef GGML_USE_HIP
|
||||
return __float22bfloat162_rn(x);
|
||||
#else
|
||||
return {x.x, x.y};
|
||||
#endif // GGML_USE_HIP
|
||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||
return int32_t(x);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
|||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
|
||||
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
||||
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
|
|||
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
|
|
@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
|
|
@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
|||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
|
|
@ -180,17 +180,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in
|
|||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_contiguous_cuda(
|
||||
static void ggml_cpy_scalar_contiguous_cuda(
|
||||
const char * cx, char * cdst, const int64_t ne,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t, bool transposed = false>
|
||||
static void ggml_cpy_flt_cuda(
|
||||
static void ggml_cpy_scalar_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
|
@ -212,11 +212,11 @@ static void ggml_cpy_flt_cuda(
|
|||
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
|
||||
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
} else {
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
}
|
||||
|
|
@ -399,94 +399,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, float, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<float, half>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q8_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q8_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q4_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_1_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q4_1_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_0_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q5_0_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_iq4_nl_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_1_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_q5_1_f32_cuda
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, half, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<half, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<half, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, half>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<nv_bfloat16, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<int32_t, int32_t, true>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
ggml_cpy_scalar_cuda<int32_t, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<float, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<float, int32_t>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
ggml_cpy_scalar_contiguous_cuda<int32_t, float>
|
||||
(src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_scalar_cuda<int32_t, float>
|
||||
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
|
|
|
|||
|
|
@ -4119,6 +4119,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
|
||||
return true;
|
||||
}
|
||||
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,6 +74,33 @@ namespace ggml_cuda_mma {
|
|||
static constexpr int J = J_;
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
#if defined(RDNA4)
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 16) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 16) {
|
||||
return 8 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 16) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / 64;
|
||||
T x[ne] = {0};
|
||||
|
||||
|
|
@ -119,6 +146,7 @@ namespace ggml_cuda_mma {
|
|||
return -1;
|
||||
}
|
||||
}
|
||||
#endif // defined(RDNA4)
|
||||
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
|
|
@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
|
|||
return -1;
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
|
@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
|
|||
struct tile<I_, J_, nv_bfloat162> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 16 && J == 8) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
|
|
@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
|
|||
return -1;
|
||||
}
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
};
|
||||
|
||||
template <int I, int J>
|
||||
|
|
@ -353,6 +436,8 @@ namespace ggml_cuda_mma {
|
|||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
|
|
@ -639,12 +724,34 @@ namespace ggml_cuda_mma {
|
|||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
||||
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
||||
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
|
||||
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
|
||||
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMPERE_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
return false;
|
||||
}
|
||||
} else {
|
||||
if (src1_ncols > 16) {
|
||||
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return volta_mma_available(cc) || turing_mma_available(cc);
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||
case GGML_TYPE_BF16:
|
||||
return ampere_mma_available(cc);
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include "mma.cuh"
|
||||
#include "common.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
|
|
@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
|
|||
const int stride_col_id, const int stride_row_id,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
|
|
@ -161,11 +177,11 @@ static __global__ void mul_mat_f(
|
|||
|
||||
if constexpr (!has_ids) {
|
||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
} else {
|
||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
|
|||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
|
|
@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
|
|||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
|
|
@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
|
|||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const float2 tmp = vals_buf[curr_buf][j0];
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
|
|
@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
|
|||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
|
|
@ -554,7 +585,8 @@ void mul_mat_f_cuda(
|
|||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
typedef tile<16, 8, T> tile_A_16;
|
||||
typedef tile<32, 8, T> tile_A_32;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, T> tile_B_16;
|
||||
typedef tile< 8, 8, T> tile_B_8;
|
||||
|
||||
GGML_ASSERT(ncols_x % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
|
|
@ -581,7 +613,8 @@ void mul_mat_f_cuda(
|
|||
|
||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
|
|
|
|||
|
|
@ -43,6 +43,14 @@ set(HTP_CMAKE_ARGS
|
|||
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG})
|
||||
|
||||
ExternalProject_Add(htp-v68
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68")
|
||||
|
||||
ExternalProject_Add(htp-v69
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69")
|
||||
|
||||
ExternalProject_Add(htp-v73
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
|
||||
|
|
@ -61,6 +69,8 @@ ExternalProject_Add(htp-v81
|
|||
|
||||
# Install Hexagon skels required at runtime
|
||||
install(FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
#ifdef _WIN32
|
||||
# include <sal.h>
|
||||
|
|
@ -240,6 +241,23 @@ struct ggml_hexagon_session {
|
|||
uint32_t prof_pkts;
|
||||
};
|
||||
|
||||
static inline void hex_print_op_info(const ggml_tensor * op, ggml_hexagon_session * sess, const uint32_t req_flags) {
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char strides[64 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
hex_format_op_dims(dims, op);
|
||||
hex_format_op_strides(strides, op);
|
||||
hex_format_op_types(types, op);
|
||||
hex_format_op_buffs(buffs, op);
|
||||
hex_format_op_names(names, op);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
|
||||
names, dims, types, strides, buffs, req_flags);
|
||||
}
|
||||
|
||||
void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
|
||||
// Bump pending flag (cleared in the session::flush once we get the responce)
|
||||
this->op_pending++; // atomic inc
|
||||
|
|
@ -1912,6 +1930,15 @@ static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_t
|
|||
return true;
|
||||
}
|
||||
|
||||
template <typename... _TTensor>
|
||||
static inline bool hex_supported_buffer(const struct ggml_hexagon_session * sess, _TTensor... tensors) {
|
||||
return ([&]() -> bool {
|
||||
return !tensors || !tensors->buffer ||
|
||||
(ggml_backend_buffer_is_hexagon(tensors->buffer) &&
|
||||
ggml_backend_hexagon_buffer_get_sess(tensors->buffer) == sess);
|
||||
}() && ...);
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
|
@ -1959,16 +1986,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
|||
}
|
||||
|
||||
// src0 & src1 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src1->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, src1, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2016,20 +2034,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
|
|||
|
||||
// src0 (weights) must be repacked and mapped to the same session
|
||||
// src1 & sr2 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src1->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src2->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, src1, src2, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2063,16 +2068,7 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
|||
}
|
||||
|
||||
// src0, src1 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src1->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, src1, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2104,20 +2100,7 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
|
|||
}
|
||||
|
||||
// src0, src1 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src1->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src2->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, src1, src2, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2144,12 +2127,7 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
|||
}
|
||||
|
||||
// src0 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2186,16 +2164,7 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
|
|||
}
|
||||
|
||||
// src0, src1 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src1 && src1->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, src1, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2248,16 +2217,7 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
|
|||
}
|
||||
|
||||
// src0, src1 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src1 && src1->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, src1, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2269,7 +2229,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
|
|||
|
||||
int mode = op_params[2];
|
||||
|
||||
if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
|
||||
if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
|
||||
return false;
|
||||
}
|
||||
if (mode & 1) {
|
||||
|
|
@ -2312,20 +2272,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
|
|||
}
|
||||
|
||||
// src0, src1, src2 & dst must be mapped to the same session
|
||||
if (src0->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src1->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (src2 && src2->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) {
|
||||
return false;
|
||||
}
|
||||
if (dst->buffer &&
|
||||
(!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) {
|
||||
if (!hex_supported_buffer(sess, src0, src1, src2, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2346,6 +2293,26 @@ static void init_htp_tensor(htp_tensor * h, const ggml_tensor * t) {
|
|||
h->nb[3] = t->nb[3];
|
||||
}
|
||||
|
||||
static size_t dspqueue_buffers_init(dspqueue_buffer * buf, const ggml_tensor * t, bool flush_host, bool flush_htp) {
|
||||
if (!t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
memset(buf, 0, sizeof(*buf));
|
||||
auto tensor_buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
|
||||
buf->fd = tensor_buf->fd;
|
||||
buf->ptr = t->data;
|
||||
buf->offset = (uint8_t *) t->data - tensor_buf->base;
|
||||
buf->size = ggml_nbytes(t);
|
||||
buf->flags = (flush_host ? DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER : 0); // Flush CPU
|
||||
buf->flags |= (flush_htp ? DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT : 0); // Invalidate DSP
|
||||
return 1;
|
||||
}
|
||||
|
||||
static ggml_hexagon_session * get_session_from_tensor(const ggml_tensor * t) {
|
||||
return static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context)->sess;
|
||||
}
|
||||
|
||||
static void hex_dump_dspbuf(const struct ggml_tensor * t, const dspqueue_buffer * d) {
|
||||
auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
|
||||
auto sess = buf->sess;
|
||||
|
|
@ -2360,10 +2327,6 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
|||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
|
||||
auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
|
||||
auto dst_buf = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = ggml_time_us();
|
||||
|
||||
|
|
@ -2385,55 +2348,27 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
|||
}
|
||||
|
||||
dspqueue_buffer bufs[3];
|
||||
memset(bufs, 0, sizeof(bufs));
|
||||
|
||||
// First buffer Weights.
|
||||
// The content is static, there is no need to do any cache management
|
||||
bufs[0].fd = src0_buf->fd;
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = 0;
|
||||
dspqueue_buffers_init(bufs, src0, false, false);
|
||||
|
||||
// Second buffer Input Activations. This is a buffer that the CPU
|
||||
// writes and the DSP reads, so we'll need to flush CPU caches and
|
||||
// invalidate DSP ones. On platforms with I/O coherency support the
|
||||
// framework will automatically skip cache operations where possible.
|
||||
bufs[1].fd = src1_buf->fd;
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
dspqueue_buffers_init(&bufs[1], src1, true, true);
|
||||
|
||||
// Third buffer Output Activations. We'll handle DSP
|
||||
// cache maintenance in the response message but need to flush
|
||||
// CPU caches to ensure any previously written dirty lines are
|
||||
// written out before writes from the DSP start.
|
||||
bufs[2].fd = dst_buf->fd;
|
||||
bufs[2].ptr = dst->data;
|
||||
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[2].size = ggml_nbytes(dst);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
dspqueue_buffers_init(&bufs[2], dst, true, false);
|
||||
|
||||
// Primary DSP session from the src0 (normally weight) tensor
|
||||
auto sess = src0_buf->sess;
|
||||
auto * sess = get_session_from_tensor(src0);
|
||||
|
||||
if (opt_verbose) {
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char strides[64 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
hex_format_op_dims(dims, op);
|
||||
hex_format_op_strides(strides, op);
|
||||
hex_format_op_types(types, op);
|
||||
hex_format_op_buffs(buffs, op);
|
||||
hex_format_op_names(names, op);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
|
||||
names, dims, types, strides, buffs, req.flags);
|
||||
hex_print_op_info(op, sess, req.flags);
|
||||
if (opt_verbose > 1) {
|
||||
hex_dump_dspbuf(src0, &bufs[0]);
|
||||
hex_dump_dspbuf(src1, &bufs[1]);
|
||||
|
|
@ -2463,11 +2398,6 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
|||
const struct ggml_tensor * src2 = op->src[2];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
|
||||
auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
|
||||
auto src2_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src2->buffer->context);
|
||||
auto dst_buf = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = ggml_time_us();
|
||||
|
||||
|
|
@ -2490,66 +2420,32 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
|||
}
|
||||
|
||||
dspqueue_buffer bufs[4];
|
||||
memset(bufs, 0, sizeof(bufs));
|
||||
|
||||
// First buffer Weights.
|
||||
// The content is static, there is no need to do any cache management
|
||||
bufs[0].fd = src0_buf->fd;
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = 0;
|
||||
dspqueue_buffers_init(bufs, src0, false, false);
|
||||
|
||||
// Second buffer Input Activations. This is a buffer that the CPU
|
||||
// writes and the DSP reads, so we'll need to flush CPU caches and
|
||||
// invalidate DSP ones. On platforms with I/O coherency support the
|
||||
// framework will automatically skip cache operations where possible.
|
||||
bufs[1].fd = src1_buf->fd;
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
dspqueue_buffers_init(&bufs[1], src1, true, true);
|
||||
|
||||
// Third buffer expert IDs. This is a buffer that the CPU
|
||||
// writes and the DSP reads, so we'll need to flush CPU caches and
|
||||
// invalidate DSP ones. On platforms with I/O coherency support the
|
||||
// framework will automatically skip cache operations where possible.
|
||||
bufs[2].fd = src2_buf->fd;
|
||||
bufs[2].ptr = src2->data;
|
||||
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[2].size = ggml_nbytes(src2);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
dspqueue_buffers_init(&bufs[2], src2, true, true);
|
||||
|
||||
// Forth buffer Output Activations. We'll handle DSP
|
||||
// cache maintenance in the response message but need to flush
|
||||
// CPU caches to ensure any previously written dirty lines are
|
||||
// written out before writes from the DSP start.
|
||||
bufs[3].fd = dst_buf->fd;
|
||||
bufs[3].ptr = dst->data;
|
||||
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[3].size = ggml_nbytes(dst);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
dspqueue_buffers_init(&bufs[3], dst, true, false);
|
||||
|
||||
// Primary DSP session from the src0 (normally weight) tensor
|
||||
auto sess = src0_buf->sess;
|
||||
auto * sess = get_session_from_tensor(src0);
|
||||
|
||||
if (opt_verbose) {
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char strides[64 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
hex_format_op_dims(dims, op);
|
||||
hex_format_op_types(types, op);
|
||||
hex_format_op_buffs(buffs, op);
|
||||
hex_format_op_names(names, op);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
|
||||
names, dims, types, strides, buffs, req.flags);
|
||||
|
||||
hex_print_op_info(op, sess, req.flags);
|
||||
if (opt_verbose > 1) {
|
||||
hex_dump_dspbuf(src0, &bufs[0]);
|
||||
hex_dump_dspbuf(src1, &bufs[1]);
|
||||
|
|
@ -2581,10 +2477,6 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
|||
const struct ggml_tensor * src1 = node->src[1];
|
||||
const struct ggml_tensor * dst = node;
|
||||
|
||||
auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
|
||||
auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
|
||||
auto dst_buf = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
|
||||
|
||||
uint64_t t1 = 0;
|
||||
uint64_t t2 = 0;
|
||||
|
||||
|
|
@ -2621,60 +2513,30 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
|||
init_htp_tensor(&req.dst, dst);
|
||||
|
||||
dspqueue_buffer bufs[3];
|
||||
memset(bufs, 0, sizeof(bufs));
|
||||
|
||||
// First buffer = First Operand of Binary op
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
bufs[0].fd = src0_buf->fd;
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
dspqueue_buffers_init(bufs, src0, true, true);
|
||||
|
||||
// Second buffer = Second Operand of Binary op
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
bufs[1].fd = src1_buf->fd;
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
dspqueue_buffers_init(&bufs[1], src1, true, true);
|
||||
|
||||
// Third buffer = Output Activations. We'll handle DSP
|
||||
// cache maintenance in the response message but need to flush
|
||||
// CPU caches to ensure any previously written dirty lines are
|
||||
// written out before writes from the DSP start.
|
||||
bufs[2].fd = dst_buf->fd;
|
||||
bufs[2].ptr = dst->data;
|
||||
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[2].size = ggml_nbytes(dst);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
dspqueue_buffers_init(&bufs[2], dst, true, false);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
auto * sess = get_session_from_tensor(src0);
|
||||
|
||||
if (opt_verbose) {
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char strides[16 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
hex_format_op_dims(dims, op);
|
||||
hex_format_op_strides(strides, op);
|
||||
hex_format_op_types(types, op);
|
||||
hex_format_op_buffs(buffs, op);
|
||||
hex_format_op_names(names, op);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(),
|
||||
ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags);
|
||||
hex_print_op_info(op, sess, req.flags);
|
||||
if (opt_verbose > 1) {
|
||||
hex_dump_dspbuf(src0, &bufs[0]);
|
||||
hex_dump_dspbuf(src1, &bufs[1]);
|
||||
|
|
@ -2705,11 +2567,6 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
|||
const struct ggml_tensor * src2 = node->src[2];
|
||||
const struct ggml_tensor * dst = node;
|
||||
|
||||
auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
|
||||
auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
|
||||
auto src2_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src2->buffer->context);
|
||||
auto dst_buf = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
|
||||
|
||||
uint64_t t1 = 0;
|
||||
uint64_t t2 = 0;
|
||||
|
||||
|
|
@ -2741,58 +2598,19 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
|||
init_htp_tensor(&req.dst, dst);
|
||||
|
||||
dspqueue_buffer bufs[4];
|
||||
memset(bufs, 0, sizeof(bufs));
|
||||
|
||||
// First buffer = input activations
|
||||
bufs[0].fd = src0_buf->fd;
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
|
||||
dspqueue_buffers_init(bufs, src0, true, true);
|
||||
// Second buffer = experts bias
|
||||
bufs[1].fd = src1_buf->fd;
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
dspqueue_buffers_init(&bufs[1], src1, true, true);
|
||||
// Third buffer = activated experts
|
||||
bufs[2].fd = src2_buf->fd;
|
||||
bufs[2].ptr = src2->data;
|
||||
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[2].size = ggml_nbytes(src2);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
dspqueue_buffers_init(&bufs[2], src2, true, true);
|
||||
// Forth buffer = output activations
|
||||
bufs[3].fd = dst_buf->fd;
|
||||
bufs[3].ptr = dst->data;
|
||||
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[3].size = ggml_nbytes(dst);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
dspqueue_buffers_init(&bufs[3], dst, true, true);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
auto * sess = get_session_from_tensor(src0);
|
||||
|
||||
if (opt_verbose) {
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char strides[16 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
hex_format_op_dims(dims, op);
|
||||
hex_format_op_strides(strides, op);
|
||||
hex_format_op_types(types, op);
|
||||
hex_format_op_buffs(buffs, op);
|
||||
hex_format_op_names(names, op);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(),
|
||||
ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags);
|
||||
|
||||
hex_print_op_info(op, sess, req.flags);
|
||||
if (opt_verbose > 1) {
|
||||
hex_dump_dspbuf(src0, &bufs[0]);
|
||||
hex_dump_dspbuf(src1, &bufs[1]);
|
||||
|
|
@ -2886,71 +2704,33 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
|||
}
|
||||
|
||||
dspqueue_buffer bufs[3];
|
||||
int n_bufs = 0;
|
||||
|
||||
memset(bufs, 0, sizeof(bufs));
|
||||
|
||||
// First buffer = Only Operand of Unary op
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
|
||||
bufs[n_bufs].fd = src0_buf->fd;
|
||||
bufs[n_bufs].ptr = src0->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src0);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
++n_bufs;
|
||||
size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true);
|
||||
|
||||
if (src1) {
|
||||
// Second buffer = Second Operand of Binary op
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
|
||||
bufs[n_bufs].fd = src1_buf->fd;
|
||||
bufs[n_bufs].ptr = src1->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src1);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
}
|
||||
// Second buffer(nullable) = Second Operand of Binary op
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true);
|
||||
|
||||
// Second or third buffer = Output Activations. We'll handle DSP
|
||||
// Second buffer = Output Activations. We'll handle DSP
|
||||
// cache maintenance in the response message but need to flush
|
||||
// CPU caches to ensure any previously written dirty lines are
|
||||
// written out before writes from the DSP start.
|
||||
auto dst_buf = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
|
||||
bufs[n_bufs].fd = dst_buf->fd;
|
||||
bufs[n_bufs].ptr = dst->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(dst);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
++n_bufs;
|
||||
n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
auto * sess = get_session_from_tensor(src0);
|
||||
|
||||
if (opt_verbose) {
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char strides[64 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
hex_format_op_dims(dims, op);
|
||||
hex_format_op_strides(strides, op);
|
||||
hex_format_op_types(types, op);
|
||||
hex_format_op_buffs(buffs, op);
|
||||
hex_format_op_names(names, op);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
|
||||
names, dims, types, strides, buffs, req.flags);
|
||||
hex_print_op_info(op, sess, req.flags);
|
||||
if (opt_verbose > 1) {
|
||||
hex_dump_dspbuf(src0, &bufs[0]);
|
||||
if (src1) {
|
||||
|
|
@ -3023,85 +2803,40 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
|||
}
|
||||
|
||||
dspqueue_buffer bufs[4];
|
||||
int n_bufs = 0;
|
||||
|
||||
memset(bufs, 0, sizeof(bufs));
|
||||
|
||||
// First buffer
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
auto src0_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src0->buffer->context);
|
||||
bufs[n_bufs].fd = src0_buf->fd;
|
||||
bufs[n_bufs].ptr = src0->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src0);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
++n_bufs;
|
||||
size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true);
|
||||
|
||||
// Second buffer
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
auto src1_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src1->buffer->context);
|
||||
bufs[n_bufs].fd = src1_buf->fd;
|
||||
bufs[n_bufs].ptr = src1->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src1);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true);
|
||||
|
||||
if (src2) {
|
||||
// Third buffer
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
auto src2_buf = static_cast<ggml_backend_hexagon_buffer_context *>(src2->buffer->context);
|
||||
bufs[n_bufs].fd = src2_buf->fd;
|
||||
bufs[n_bufs].ptr = src2->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src2);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
}
|
||||
// Third buffer(nullable)
|
||||
// This is a buffer that the CPU writes and the DSP reads, so we'll
|
||||
// need to flush CPU caches and invalidate DSP ones. On platforms
|
||||
// with I/O coherency support the framework will automatically skip
|
||||
// cache operations where possible.
|
||||
n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src2, true, true);
|
||||
|
||||
// Final buffer = Output Activations. We'll handle DSP
|
||||
// Second buffer = Output Activations. We'll handle DSP
|
||||
// cache maintenance in the response message but need to flush
|
||||
// CPU caches to ensure any previously written dirty lines are
|
||||
// written out before writes from the DSP start.
|
||||
auto dst_buf = static_cast<ggml_backend_hexagon_buffer_context *>(dst->buffer->context);
|
||||
bufs[n_bufs].fd = dst_buf->fd;
|
||||
bufs[n_bufs].ptr = dst->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(dst);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
++n_bufs;
|
||||
n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
auto * sess = get_session_from_tensor(src0);
|
||||
|
||||
if (opt_verbose) {
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char strides[64 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
hex_format_op_dims(dims, op);
|
||||
hex_format_op_strides(strides, op);
|
||||
hex_format_op_types(types, op);
|
||||
hex_format_op_buffs(buffs, op);
|
||||
hex_format_op_names(names, op);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op),
|
||||
names, dims, types, strides, buffs, req.flags);
|
||||
hex_print_op_info(op, sess, req.flags);
|
||||
if (opt_verbose > 1) {
|
||||
hex_dump_dspbuf(src0, &bufs[0]);
|
||||
if (src1) {
|
||||
|
|
|
|||
|
|
@ -390,6 +390,12 @@ int get_hex_arch_ver(int domain, int * arch) {
|
|||
}
|
||||
|
||||
switch (arch_ver.capability & 0xff) {
|
||||
case 0x68:
|
||||
*arch = 68;
|
||||
return 0;
|
||||
case 0x69:
|
||||
*arch = 69;
|
||||
return 0;
|
||||
case 0x73:
|
||||
*arch = 73;
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -66,6 +66,13 @@ static inline bool dma_queue_push(dma_queue * q,
|
|||
desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
|
||||
desc->dstbypass = 1;
|
||||
desc->srcbypass = 1;
|
||||
#if __HVX_ARCH__ >= 73
|
||||
desc->dstbypass = 1;
|
||||
desc->srcbypass = 1;
|
||||
#else
|
||||
desc->dstbypass = 0;
|
||||
desc->srcbypass = 1;
|
||||
#endif
|
||||
desc->order = 0;
|
||||
desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
|
||||
desc->src = (void *) src;
|
||||
|
|
|
|||
|
|
@ -16,13 +16,8 @@
|
|||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec) {
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.02f; // log(INF)
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
|
||||
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
|
||||
static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
|
||||
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
|
||||
|
||||
HVX_Vector out = hvx_vec_exp_fp32(in_vec);
|
||||
|
||||
|
|
@ -47,6 +42,12 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
|
|||
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.02f; // log(INF)
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||
|
|
@ -55,9 +56,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
|
|||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
|
||||
*p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in);
|
||||
*p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
*p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++);
|
||||
*p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -67,9 +68,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
|
|||
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -83,9 +84,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
|
|||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||
|
||||
vec_out = hvx_vec_exp_fp32_guard(neg_vec_in);
|
||||
vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
vec_out = hvx_vec_exp_fp32_guard(in);
|
||||
vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf);
|
||||
}
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
|
||||
|
|
|
|||
|
|
@ -16,6 +16,15 @@
|
|||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
|
||||
HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
|
||||
|
||||
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
|
||||
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
|
||||
|
||||
return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
|
||||
}
|
||||
|
||||
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
|
@ -32,19 +41,22 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
|
|||
FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
static const uint32_t kNanInfMask = 0x7f800000;
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * p_vec_in = (HVX_Vector *) src;
|
||||
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
*p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++);
|
||||
*p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -53,7 +65,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
|
|||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
HVX_Vector out = hvx_vec_inverse_fp32_guard(in);
|
||||
HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,26 @@ typedef union {
|
|||
float fp32[VLEN_FP32];
|
||||
} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
|
||||
|
||||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
|
||||
#if __HVX_ARCH__ < 73
|
||||
static inline HVX_Vector int32_to_qfloat(HVX_Vector const in)
|
||||
{
|
||||
HVX_Vector const vzero = Q6_V_vzero();
|
||||
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
|
||||
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
|
||||
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
|
||||
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
|
||||
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
|
||||
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
|
||||
{
|
||||
return Q6_Vsf_equals_Vqf32(int32_to_qfloat(in));
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
|
||||
union {
|
||||
float f;
|
||||
|
|
@ -726,24 +746,6 @@ static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
|
|||
return Q6_Vsf_equals_Vqf32(r_qf);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf) {
|
||||
static const float kInf = INFINITY;
|
||||
static const uint32_t kNanMask = 0x7fffffff;
|
||||
static const uint32_t kNanMin = 0x7f800000;
|
||||
|
||||
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
|
||||
const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(inf, v_sf);
|
||||
|
||||
HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
|
||||
|
||||
const HVX_Vector nan_mask = Q6_V_vsplat_R(kNanMask);
|
||||
const HVX_Vector nan_min = Q6_V_vsplat_R(kNanMin);
|
||||
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_mask);
|
||||
const HVX_VectorPred pred = Q6_Q_vcmp_gtand_QVuwVuw(pred_inf, nan_min, masked_out);
|
||||
|
||||
return Q6_V_vmux_QVV(pred, out, Q6_V_vzero());
|
||||
}
|
||||
|
||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
|
||||
|
|
@ -958,14 +960,16 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
|
|||
return Q6_Vsf_equals_Vqf32(temp);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v) {
|
||||
static const float kMaxExp = -88.02f; // log(INF)
|
||||
|
||||
const HVX_Vector max_exp = Q6_V_vsplat_R(*((uint32_t *) &kMaxExp));
|
||||
const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(v, max_exp);
|
||||
static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
|
||||
HVX_Vector one,
|
||||
HVX_Vector max_exp,
|
||||
HVX_Vector min_exp) {
|
||||
const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
|
||||
const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
|
||||
|
||||
HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v);
|
||||
return Q6_V_vmux_QVV(pred_inf, out, Q6_V_vzero());
|
||||
out = Q6_V_vmux_QVV(pred_max, out, one);
|
||||
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
|
||||
}
|
||||
|
||||
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||
|
|
@ -977,9 +981,16 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
|
|||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
static const float kMinExp = -87.f; // 0
|
||||
static const float kMaxExp = 87.f; // 1
|
||||
|
||||
const HVX_Vector one = hvx_vec_splat_fp32(1.f);
|
||||
const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
|
||||
const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i]);
|
||||
v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -143,16 +143,25 @@ AEEResult htp_iface_disable_etm(remote_handle64 handle) {
|
|||
}
|
||||
|
||||
static int vtcm_acquire(struct htp_context * ctx) {
|
||||
int err;
|
||||
if (!ctx->vtcm_valid) {
|
||||
// Temporarily bump thread priority to make sure it's higher than other sessions.
|
||||
// This way the resource manager will notify the other thread to release VTCM.
|
||||
// Note that we need to reaquire VTCM at normal priority for this to work next time.
|
||||
qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10);
|
||||
HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||
err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||
if (err != 0) {
|
||||
FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err);
|
||||
abort();
|
||||
}
|
||||
HAP_compute_res_release_cached(ctx->vtcm_rctx);
|
||||
qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio);
|
||||
|
||||
HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||
err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||
if (err != 0) {
|
||||
FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err);
|
||||
abort();
|
||||
}
|
||||
ctx->vtcm_valid = true;
|
||||
}
|
||||
|
||||
|
|
@ -201,7 +210,7 @@ static int vtcm_alloc(struct htp_context * ctx) {
|
|||
HAP_compute_res_attr_init(&attr);
|
||||
HAP_compute_res_attr_set_serialize(&attr, 0);
|
||||
HAP_compute_res_attr_set_cache_mode(&attr, 1);
|
||||
HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size);
|
||||
HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size);
|
||||
HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
|
||||
HAP_compute_res_attr_set_hmx_param(&attr, 1);
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@
|
|||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h
|
||||
#define HTP_ROPE_TYPE_NORMAL 0
|
||||
#define HTP_ROPE_TYPE_NEOX 2
|
||||
|
||||
#define htp_rope_preamble \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
|
|
@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
|
|||
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[num_elems/2];
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
//src += 1;
|
||||
//dst += 1;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
int half_size = (sizeof(float) * (num_elems / 2));
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
|
||||
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
|
||||
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
|
||||
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
|
||||
|
||||
src0_curr += VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += VLEN;
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
|
|
@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
|||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const int32_t mode = rope_ctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
|
@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
|||
float * dst_data_loc = dst_data;
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
if (is_neox) {
|
||||
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
} else {
|
||||
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
}
|
||||
} else {
|
||||
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
||||
const float cos_theta = wp0[i0 + 0];
|
||||
const float sin_theta = wp0[i0 + 1];
|
||||
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[1];
|
||||
if (is_neox) {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[rope_ctx->n_dims/2];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 2;
|
||||
dst_data_loc += 2;
|
||||
src_loc += 1;
|
||||
dst_data_loc += 1;
|
||||
} else {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[1];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 2;
|
||||
dst_data_loc += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6895,9 +6895,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
cl_context context = backend_ctx->context;
|
||||
|
||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
|
||||
// For KQ
|
||||
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
|
||||
nb00 <= nb02 &&
|
||||
nb02 <= nb01 &&
|
||||
nb01 <= nb03 &&
|
||||
nb10 <= nb12 &&
|
||||
nb12 <= nb11 &&
|
||||
nb11 <= nb13) {
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
// For KQV
|
||||
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11381,13 +11381,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
|
||||
static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
||||
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
|
||||
ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
if (ggml_is_empty(node) || !node->buffer) {
|
||||
if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -11399,132 +11399,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
ggml_tensor * src2 = node->src[2];
|
||||
ggml_tensor * src3 = node->src[3];
|
||||
|
||||
switch (node->op) {
|
||||
// Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NONE:
|
||||
return false;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_STEP:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
switch (ggml_get_glu_op(node)) {
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
|
||||
if (next_node_idx < cgraph->n_nodes &&
|
||||
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
|
||||
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
|
||||
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
|
||||
ctx->device->add_rms_fusion) {
|
||||
uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
|
||||
ctx->do_add_rms_partials_offset_calculation = true;
|
||||
if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
|
||||
ctx->do_add_rms_partials = true;
|
||||
}
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
|
||||
if (next_node_idx < cgraph->n_nodes &&
|
||||
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
|
||||
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
|
||||
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
|
||||
ctx->device->add_rms_fusion) {
|
||||
uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
|
||||
ctx->do_add_rms_partials_offset_calculation = true;
|
||||
if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
|
||||
ctx->do_add_rms_partials = true;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SILU_BACK:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
break;
|
||||
default:
|
||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
vk_context compute_ctx;
|
||||
|
|
@ -11961,145 +11848,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
|
||||
ctx->compute_ctx.reset();
|
||||
|
||||
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
|
||||
if (!ok) {
|
||||
if (node->op == GGML_OP_UNARY) {
|
||||
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
||||
} else if (node->op == GGML_OP_GLU) {
|
||||
std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
|
||||
} else {
|
||||
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
|
||||
static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
|
||||
GGML_UNUSED(cgraph);
|
||||
ggml_backend_buffer * buf = nullptr;
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SILU_BACK:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
buf = tensor->buffer;
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_STEP:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
buf = tensor->buffer;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
switch (ggml_get_glu_op(tensor)) {
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
buf = tensor->buffer;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
buf = tensor->buffer;
|
||||
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (buf == nullptr) {
|
||||
return false;
|
||||
}
|
||||
GGML_UNUSED(tensor);
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
|
||||
|
||||
|
|
@ -12143,8 +11899,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
subctx->out_memcpys.clear();
|
||||
subctx->memsets.clear();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Clean up after graph processing is done
|
||||
|
|
|
|||
|
|
@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum):
|
|||
APERTUS = auto()
|
||||
COGVLM = auto()
|
||||
MINIMAXM2 = auto()
|
||||
RND1 = auto()
|
||||
PANGU_EMBED = auto()
|
||||
|
||||
|
||||
|
|
@ -797,6 +798,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.APERTUS: "apertus",
|
||||
MODEL_ARCH.MINIMAXM2: "minimax-m2",
|
||||
MODEL_ARCH.COGVLM: "cogvlm",
|
||||
MODEL_ARCH.RND1: "rnd1",
|
||||
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
|
||||
}
|
||||
|
||||
|
|
@ -2991,6 +2993,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.VISEXP_UP,
|
||||
MODEL_TENSOR.VISEXP_DOWN,
|
||||
],
|
||||
MODEL_ARCH.RND1: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.PANGU_EMBED: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ add_library(llama
|
|||
models/qwen3vl-moe.cpp
|
||||
models/qwen3moe.cpp
|
||||
models/refact.cpp
|
||||
models/rnd1.cpp
|
||||
models/rwkv6-base.cpp
|
||||
models/rwkv6.cpp
|
||||
models/rwkv6qwen2.cpp
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_APERTUS, "apertus" },
|
||||
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
|
||||
{ LLM_ARCH_COGVLM, "cogvlm" },
|
||||
{ LLM_ARCH_RND1, "rnd1" },
|
||||
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
|
@ -2446,6 +2447,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RND1,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
|
@ -2722,6 +2743,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
|
|||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
case LLM_ARCH_RND1:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ enum llm_arch {
|
|||
LLM_ARCH_APERTUS,
|
||||
LLM_ARCH_MINIMAX_M2,
|
||||
LLM_ARCH_COGVLM,
|
||||
LLM_ARCH_RND1,
|
||||
LLM_ARCH_PANGU_EMBED,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1036,6 +1036,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_RND1:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 48: type = LLM_TYPE_30B_A3B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
// Set non-causal attention for diffusion models
|
||||
hparams.causal_attn = false;
|
||||
} break;
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
|
|
@ -3402,6 +3414,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
} break;
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
case LLM_ARCH_QWEN3VLMOE:
|
||||
case LLM_ARCH_RND1:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
|
|
@ -6720,7 +6733,7 @@ void llama_model::print_info() const {
|
|||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) {
|
||||
if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
}
|
||||
|
||||
|
|
@ -6882,6 +6895,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
case LLM_ARCH_RND1:
|
||||
{
|
||||
res = nullptr;
|
||||
} break;
|
||||
|
|
@ -7075,6 +7089,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
llm = std::make_unique<llm_build_llada_moe>(*this, params);
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_RND1:
|
||||
{
|
||||
llm = std::make_unique<llm_build_rnd1>(*this, params);
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen2vl>(*this, params);
|
||||
|
|
@ -7598,6 +7617,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_QWEN3:
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
case LLM_ARCH_RND1:
|
||||
case LLM_ARCH_OLMO2:
|
||||
case LLM_ARCH_OLMOE:
|
||||
case LLM_ARCH_PHI2:
|
||||
|
|
|
|||
|
|
@ -431,6 +431,10 @@ struct llm_build_refact : public llm_graph_context {
|
|||
llm_build_refact(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_rnd1 : public llm_graph_context {
|
||||
llm_build_rnd1(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
||||
llm_build_rwkv6(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,126 @@
|
|||
#include "models.h"
|
||||
|
||||
// RND1 is a Qwen3Moe AR model converted to diffusion model.
|
||||
llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// Non-causal attention for diffusion
|
||||
auto * inp_attn = build_attn_inp_no_cache();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self_attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// MoE branch
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
ggml_tensor * moe_out =
|
||||
build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
cur = moe_out;
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
|
@ -6953,9 +6953,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||
|
||||
for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
|
||||
for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
|
||||
for (bool use_view_slice : { true, false }) {
|
||||
for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},
|
||||
{2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {
|
||||
|
|
@ -7819,6 +7821,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048));
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
||||
}
|
||||
}
|
||||
|
|
@ -7827,6 +7830,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048));
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
|
||||
}
|
||||
}
|
||||
|
|
@ -7837,6 +7841,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
for (int bs : {1, 4, 8, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880));
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -29,6 +29,7 @@
|
|||
sendMessage,
|
||||
stopGeneration
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import {
|
||||
supportsVision,
|
||||
supportsAudio,
|
||||
|
|
@ -47,6 +48,7 @@
|
|||
|
||||
let { showCenteredEmpty = false } = $props();
|
||||
|
||||
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
|
||||
let autoScrollEnabled = $state(true);
|
||||
let chatScrollContainer: HTMLDivElement | undefined = $state();
|
||||
let dragCounter = $state(0);
|
||||
|
|
@ -149,7 +151,7 @@
|
|||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (!chatScrollContainer) return;
|
||||
if (disableAutoScroll || !chatScrollContainer) return;
|
||||
|
||||
const { scrollTop, scrollHeight, clientHeight } = chatScrollContainer;
|
||||
const distanceFromBottom = scrollHeight - scrollTop - clientHeight;
|
||||
|
|
@ -194,8 +196,10 @@
|
|||
const extras = result?.extras;
|
||||
|
||||
// Enable autoscroll for user-initiated message sending
|
||||
userScrolledUp = false;
|
||||
autoScrollEnabled = true;
|
||||
if (!disableAutoScroll) {
|
||||
userScrolledUp = false;
|
||||
autoScrollEnabled = true;
|
||||
}
|
||||
await sendMessage(message, extras);
|
||||
scrollChatToBottom();
|
||||
|
||||
|
|
@ -241,6 +245,8 @@
|
|||
}
|
||||
|
||||
function scrollChatToBottom(behavior: ScrollBehavior = 'smooth') {
|
||||
if (disableAutoScroll) return;
|
||||
|
||||
chatScrollContainer?.scrollTo({
|
||||
top: chatScrollContainer?.scrollHeight,
|
||||
behavior
|
||||
|
|
@ -248,14 +254,27 @@
|
|||
}
|
||||
|
||||
afterNavigate(() => {
|
||||
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
if (!disableAutoScroll) {
|
||||
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
if (!disableAutoScroll) {
|
||||
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (disableAutoScroll) {
|
||||
autoScrollEnabled = false;
|
||||
if (scrollInterval) {
|
||||
clearInterval(scrollInterval);
|
||||
scrollInterval = undefined;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (isCurrentConversationLoading && autoScrollEnabled) {
|
||||
scrollInterval = setInterval(scrollChatToBottom, AUTO_SCROLL_INTERVAL);
|
||||
} else if (scrollInterval) {
|
||||
|
|
@ -289,9 +308,11 @@
|
|||
class="mb-16 md:mb-24"
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
userScrolledUp = false;
|
||||
autoScrollEnabled = true;
|
||||
scrollChatToBottom();
|
||||
if (!disableAutoScroll) {
|
||||
userScrolledUp = false;
|
||||
autoScrollEnabled = true;
|
||||
scrollChatToBottom();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
Settings,
|
||||
Funnel,
|
||||
AlertTriangle,
|
||||
Brain,
|
||||
Code,
|
||||
Monitor,
|
||||
Sun,
|
||||
|
|
@ -58,6 +57,33 @@
|
|||
label: 'Paste long text to file length',
|
||||
type: 'input'
|
||||
},
|
||||
{
|
||||
key: 'enableContinueGeneration',
|
||||
label: 'Enable "Continue" button',
|
||||
type: 'checkbox',
|
||||
isExperimental: true
|
||||
},
|
||||
{
|
||||
key: 'pdfAsImage',
|
||||
label: 'Parse PDF as image',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'askForTitleConfirmation',
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
title: 'Display',
|
||||
icon: Monitor,
|
||||
fields: [
|
||||
{
|
||||
key: 'showThoughtInProgress',
|
||||
label: 'Show thought in progress',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showMessageStats',
|
||||
label: 'Show message generation statistics',
|
||||
|
|
@ -79,25 +105,14 @@
|
|||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'enableContinueGeneration',
|
||||
label: 'Enable "Continue" button',
|
||||
type: 'checkbox',
|
||||
isExperimental: true
|
||||
},
|
||||
{
|
||||
key: 'pdfAsImage',
|
||||
label: 'Parse PDF as image',
|
||||
key: 'disableAutoScroll',
|
||||
label: 'Disable automatic scroll',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
label: 'Render user content as Markdown',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'askForTitleConfirmation',
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
@ -218,17 +233,6 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
title: 'Reasoning',
|
||||
icon: Brain,
|
||||
fields: [
|
||||
{
|
||||
key: 'showThoughtInProgress',
|
||||
label: 'Show thought in progress',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
title: 'Import/Export',
|
||||
icon: Database,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
|||
pasteLongTextToFileLen: 2500,
|
||||
pdfAsImage: false,
|
||||
showModelInfo: false,
|
||||
disableAutoScroll: false,
|
||||
renderUserContentAsMarkdown: false,
|
||||
modelSelectorEnabled: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
|
|
@ -99,6 +100,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
|||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
|
||||
showModelInfo: 'Display the model name used to generate each message below the message content.',
|
||||
disableAutoScroll:
|
||||
'Disable automatic scrolling while messages stream so you can control the viewport position manually.',
|
||||
renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.',
|
||||
modelSelectorEnabled:
|
||||
'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.',
|
||||
|
|
|
|||
Loading…
Reference in New Issue