Merge branch 'ggml-org:master' into qwen3_next
This commit is contained in:
commit
c78f9fce68
|
|
@ -52,3 +52,11 @@ insert_final_newline = unset
|
|||
[vendor/miniaudio/miniaudio.h]
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
[tools/server/webui/**]
|
||||
indent_style = unset
|
||||
indent_size = unset
|
||||
end_of_line = unset
|
||||
charset = unset
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
|
|
|||
|
|
@ -76,51 +76,206 @@ jobs:
|
|||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
# Setup nodejs (to be used for verifying bundled index.html)
|
||||
- uses: actions/setup-node@v4
|
||||
webui-setup:
|
||||
name: WebUI Setup
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
node-version: '22.11.0'
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: WebUI - Install dependencies
|
||||
id: webui_lint
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
cache-dependency-path: "tools/server/webui/package-lock.json"
|
||||
|
||||
- name: Cache node_modules
|
||||
uses: actions/cache@v4
|
||||
id: cache-node-modules
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||
run: npm ci
|
||||
working-directory: tools/server/webui
|
||||
|
||||
webui-check:
|
||||
needs: webui-setup
|
||||
name: WebUI Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Restore node_modules cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Run type checking
|
||||
run: npm run check
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run linting
|
||||
run: npm run lint
|
||||
working-directory: tools/server/webui
|
||||
|
||||
webui-build:
|
||||
needs: webui-check
|
||||
name: WebUI Build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Restore node_modules cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Build application
|
||||
run: npm run build
|
||||
working-directory: tools/server/webui
|
||||
|
||||
webui-tests:
|
||||
needs: webui-build
|
||||
name: Run WebUI tests
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Restore node_modules cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Install Playwright browsers
|
||||
run: npx playwright install --with-deps
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Build Storybook
|
||||
run: npm run build-storybook
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run Client tests
|
||||
run: npm run test:client
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run Server tests
|
||||
run: npm run test:server
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run UI tests
|
||||
run: npm run test:ui
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run E2E tests
|
||||
run: npm run test:e2e
|
||||
working-directory: tools/server/webui
|
||||
|
||||
server-build:
|
||||
needs: [webui-tests]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
|
||||
build_type: [RelWithDebInfo]
|
||||
include:
|
||||
- build_type: Release
|
||||
sanitizer: ""
|
||||
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
|
||||
|
||||
steps:
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
cd tools/server/webui
|
||||
npm ci
|
||||
sudo apt-get update
|
||||
sudo apt-get -y install \
|
||||
build-essential \
|
||||
xxd \
|
||||
git \
|
||||
cmake \
|
||||
curl \
|
||||
wget \
|
||||
language-pack-en \
|
||||
libcurl4-openssl-dev
|
||||
|
||||
- name: WebUI - Check code format
|
||||
id: webui_format
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Tests dependencies
|
||||
id: test_dependencies
|
||||
run: |
|
||||
git config --global --add safe.directory $(realpath .)
|
||||
cd tools/server/webui
|
||||
git status
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
npm run format
|
||||
git status
|
||||
modified_files="$(git status -s)"
|
||||
echo "Modified files: ${modified_files}"
|
||||
if [ -n "${modified_files}" ]; then
|
||||
echo "Files do not follow coding style. To fix: npm run format"
|
||||
echo "${modified_files}"
|
||||
exit 1
|
||||
fi
|
||||
- name: Setup Node.js for WebUI
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
cache-dependency-path: "tools/server/webui/package-lock.json"
|
||||
|
||||
- name: Verify bundled index.html
|
||||
id: verify_server_index_html
|
||||
run: |
|
||||
git config --global --add safe.directory $(realpath .)
|
||||
cd tools/server/webui
|
||||
git status
|
||||
- name: Install WebUI dependencies
|
||||
run: npm ci
|
||||
working-directory: tools/server/webui
|
||||
|
||||
npm run build
|
||||
git status
|
||||
modified_files="$(git status -s)"
|
||||
echo "Modified files: ${modified_files}"
|
||||
if [ -n "${modified_files}" ]; then
|
||||
echo "Repository is dirty or server/webui is not built as expected"
|
||||
echo "Hint: You may need to follow Web UI build guide in server/README.md"
|
||||
echo "${modified_files}"
|
||||
exit 1
|
||||
fi
|
||||
- name: Build WebUI
|
||||
run: npm run build
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Build (no OpenMP)
|
||||
id: cmake_build_no_openmp
|
||||
|
|
|
|||
|
|
@ -148,3 +148,7 @@ poetry.toml
|
|||
/run-vim.sh
|
||||
/run-chat.sh
|
||||
.ccache/
|
||||
|
||||
# Code Workspace
|
||||
*.code-workspace
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
#### Tailwind & CSS
|
||||
|
||||
- We are using Tailwind v4 which uses oklch colors so we now want to refer to the CSS vars directly, without wrapping it with any color function like `hsla/hsl`, `rgba` etc.
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
# Coding rules
|
||||
|
||||
## Svelte & SvelteKit
|
||||
|
||||
### Services vs Stores Separation Pattern
|
||||
|
||||
#### `lib/services/` - Pure Business Logic
|
||||
|
||||
- **Purpose**: Stateless business logic and external communication
|
||||
- **Contains**:
|
||||
- API calls to external services (ApiService)
|
||||
- Pure business logic functions (ChatService, etc.)
|
||||
- **Rules**:
|
||||
- NO Svelte runes ($state, $derived, $effect)
|
||||
- NO reactive state management
|
||||
- Pure functions and classes only
|
||||
- Can import types but not stores
|
||||
- Focus on "how" - implementation details
|
||||
|
||||
#### `lib/stores/` - Reactive State Management
|
||||
|
||||
- **Purpose**: Svelte-specific reactive state with runes
|
||||
- **Contains**:
|
||||
- Reactive state classes with $state, $derived, $effect
|
||||
- Database operations (DatabaseStore)
|
||||
- UI-focused state management
|
||||
- Store orchestration logic
|
||||
- **Rules**:
|
||||
- USE Svelte runes for reactivity
|
||||
- Import and use services for business logic
|
||||
- NO direct database operations
|
||||
- NO direct API calls (use services)
|
||||
- Focus on "what" - reactive state for UI
|
||||
|
||||
#### Enforcement
|
||||
|
||||
- Services should be testable without Svelte
|
||||
- Stores should leverage Svelte's reactivity system
|
||||
- Clear separation: services handle data, stores handle state
|
||||
- Services can be reused across multiple stores
|
||||
|
||||
#### Misc
|
||||
|
||||
- Always use `let` for $derived state variables
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
# Automated Tests
|
||||
|
||||
## General rules
|
||||
|
||||
- NEVER include any test code in the production code - we should always have it in a separate dedicated files
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
## TypeScript
|
||||
|
||||
- Add JSDocs for functions
|
||||
|
|
@ -45,7 +45,7 @@ SRC=`pwd`
|
|||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
|
||||
|
||||
if [ ! -z ${GG_BUILD_METAL} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_CUDA} ]; then
|
||||
|
|
|
|||
|
|
@ -2393,7 +2393,10 @@ class SmolVLMModel(MmprojModel):
|
|||
return [] # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("Llama4ForConditionalGeneration")
|
||||
@ModelBase.register(
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Llama4ForCausalLM",
|
||||
)
|
||||
class Llama4Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA4
|
||||
undo_permute = False
|
||||
|
|
@ -2411,6 +2414,10 @@ class Llama4Model(LlamaModel):
|
|||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
||||
if "layer_types" in self.hparams:
|
||||
if all(lt == "full_attention" for lt in self.hparams["layer_types"]):
|
||||
# all layers are full attention (for MobileLLM), disable swa
|
||||
self.gguf_writer.add_sliding_window(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.startswith("language_model."):
|
||||
|
|
|
|||
|
|
@ -190,7 +190,6 @@ option(GGML_WEBGPU "ggml: use WebGPU"
|
|||
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
|
||||
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ extern "C" {
|
|||
// user-code should use only these functions
|
||||
//
|
||||
|
||||
// TODO: remove in the future
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
|
||||
|
||||
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
||||
|
|
|
|||
|
|
@ -284,19 +284,19 @@ __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexc
|
|||
// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
|
||||
//
|
||||
#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
|
||||
const type prefix##0 = (pointer)->array[0]; \
|
||||
const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
|
||||
GGML_UNUSED(prefix##0);
|
||||
#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
|
||||
GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
|
||||
const type prefix##1 = (pointer)->array[1]; \
|
||||
const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
|
||||
GGML_UNUSED(prefix##1);
|
||||
#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
|
||||
GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
|
||||
const type prefix##2 = (pointer)->array[2]; \
|
||||
const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
|
||||
GGML_UNUSED(prefix##2);
|
||||
#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
|
||||
GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
|
||||
const type prefix##3 = (pointer)->array[3]; \
|
||||
const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
|
||||
GGML_UNUSED(prefix##3);
|
||||
|
||||
#define GGML_TENSOR_UNARY_OP_LOCALS \
|
||||
|
|
|
|||
|
|
@ -1728,7 +1728,6 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
|||
ggml_cann_get_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
std::cout << "lcg GGML_OP_SET_ROWS"<< std::endl;
|
||||
ggml_cann_set_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
|
|
|
|||
|
|
@ -75,6 +75,8 @@
|
|||
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
|
||||
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
|
||||
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
// Moore Threads
|
||||
|
|
@ -325,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
|||
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
|
||||
}
|
||||
|
||||
// Maximum number of bytes that can be copied in a single instruction.
|
||||
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
|
||||
#ifdef GGML_USE_HIP
|
||||
return 16;
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
return 16;
|
||||
#else
|
||||
return 8;
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
#endif // GGML_USE_HIP
|
||||
}
|
||||
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void no_device_code(
|
||||
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
|
||||
|
|
|
|||
|
|
@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
}
|
||||
|
||||
template<int D> // D == head size
|
||||
#if !defined(GGML_USE_HIP)
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP)
|
||||
static __global__ void flash_attn_combine_results(
|
||||
const float * __restrict__ VKQ_parts,
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
|
|
@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
|
|||
float VKQ_numerator = 0.0f;
|
||||
float VKQ_denominator = 0.0f;
|
||||
for (int l = 0; l < parallel_blocks; ++l) {
|
||||
const float diff = meta[l].x - kqmax;
|
||||
float KQ_max_scale = expf(diff);
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
const float KQ_max_scale = expf(meta[l].x - kqmax);
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
|
|
@ -836,11 +831,10 @@ void launch_fattn(
|
|||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
int parallel_blocks = 1;
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
int parallel_blocks = max_blocks_per_sm;
|
||||
|
||||
dim3 blocks_num;
|
||||
if (stream_k) {
|
||||
|
|
@ -862,9 +856,6 @@ void launch_fattn(
|
|||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
||||
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,20 +2,30 @@
|
|||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile.cuh"
|
||||
|
||||
#define FATTN_TILE_NTHREADS 256
|
||||
// kq_stride == number of KQ rows to process per iteration
|
||||
// kq_nbatch == number of K columns to load in parallel for KQ calculation
|
||||
|
||||
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
if (GGML_CUDA_CC_IS_RDNA(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
} else {
|
||||
return 64;
|
||||
}
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
|
|
@ -49,24 +59,28 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
|
|||
|
||||
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
#ifdef RDNA
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
return 128;
|
||||
case 128:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
#else
|
||||
return 64;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
case 256:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
#else
|
||||
return 64;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // RDNA
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
|
|
@ -100,17 +114,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
|
|||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
#else
|
||||
return 64;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
case 256:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
#else
|
||||
return ncols <= 16 ? 64 : 256;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -120,9 +125,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
|
|||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 256:
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -142,12 +146,27 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
|
|||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
#ifdef GGML_USE_HIP
|
||||
__launch_bounds__(FATTN_TILE_NTHREADS, 1)
|
||||
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED_VARS(cc, ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
|
||||
#ifdef RDNA
|
||||
return 3;
|
||||
#else
|
||||
__launch_bounds__(FATTN_TILE_NTHREADS, 2)
|
||||
#endif // GGML_USE_HIP
|
||||
return ncols <= 16 ? 3 : 2;
|
||||
#endif // RDNA
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
|
||||
static __global__ void flash_attn_tile(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
|
|
@ -193,7 +212,7 @@ static __global__ void flash_attn_tile(
|
|||
}
|
||||
|
||||
constexpr int warp_size = 32;
|
||||
constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size;
|
||||
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
|
||||
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
|
||||
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
|
||||
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
|
||||
|
|
@ -206,90 +225,126 @@ static __global__ void flash_attn_tile(
|
|||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
constexpr int cpy_nb = 16;
|
||||
#else
|
||||
constexpr int cpy_nb = 8;
|
||||
#endif // defined(GGML_USE_HIP) && defined(GCN)
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
__shared__ float KQ[ncols][kq_stride];
|
||||
constexpr int cpw = ncols/nwarps; // cols per warp
|
||||
|
||||
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
||||
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
||||
|
||||
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ half2 Q_tmp[ncols][D/2];
|
||||
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#else
|
||||
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
||||
|
||||
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ float Q_tmp[ncols][D];
|
||||
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
|
||||
float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
||||
|
||||
|
||||
float kqmax[ncols/nwarps];
|
||||
float KQ_max[cpw];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float kqsum[ncols/nwarps] = {0.0f};
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
// Load Q data, convert to FP16 if fast.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
const int j = j0 + threadIdx.y*cpw;
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp_f[i1] *= scale;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
||||
#else
|
||||
Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale;
|
||||
Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float kqmax_new[ncols/nwarps];
|
||||
float KQ_max_new[cpw];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
kqmax_new[j] = kqmax[j];
|
||||
for (int j = 0; j < cpw; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
|
||||
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
||||
|
||||
// KQ = K @ Q matrix multiplication:
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
|
||||
const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
|
||||
#else
|
||||
const float2 tmp_f2 = __half22float2(tmp_h2);
|
||||
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
|
||||
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
||||
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
half2 tmp_h2[cpy_ne_kqnb/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_kqnb/2];
|
||||
#pragma unroll
|
||||
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
||||
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
|
@ -298,12 +353,12 @@ static __global__ void flash_attn_tile(
|
|||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
|
||||
half2 K_k[kq_stride/warp_size][cpy_ne];
|
||||
half2 Q_k[ncols/nwarps][cpy_ne];
|
||||
half2 Q_k[cpw][cpy_ne];
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
|
||||
float K_k[kq_stride/warp_size][cpy_ne];
|
||||
float Q_k[ncols/nwarps][cpy_ne];
|
||||
float Q_k[cpw][cpy_ne];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
|
|
@ -311,29 +366,29 @@ static __global__ void flash_attn_tile(
|
|||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < cpy_ne; ++k) {
|
||||
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
|
||||
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -344,104 +399,77 @@ static __global__ void flash_attn_tile(
|
|||
}
|
||||
}
|
||||
|
||||
// Apply logit softcap, mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
|
||||
|
||||
KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
|
||||
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
|
||||
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
|
||||
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
|
||||
|
||||
float kqsum_add = 0.0f;
|
||||
if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
|
||||
const int i = i0 + 4*threadIdx.x;
|
||||
|
||||
float4 val = *(const float4 *) &KQ[j][i];
|
||||
val.x = expf(val.x - kqmax[j0/nwarps]);
|
||||
val.y = expf(val.y - kqmax[j0/nwarps]);
|
||||
val.z = expf(val.z - kqmax[j0/nwarps]);
|
||||
val.w = expf(val.w - kqmax[j0/nwarps]);
|
||||
kqsum_add += val.x + val.y + val.z + val.w;
|
||||
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
|
||||
half tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
|
||||
float tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
} else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
|
||||
const int i = i0 + 2*threadIdx.x;
|
||||
|
||||
float2 val = *(const float2 *) &KQ[j][i];
|
||||
val.x = expf(val.x - kqmax[j0/nwarps]);
|
||||
val.y = expf(val.y - kqmax[j0/nwarps]);
|
||||
kqsum_add += val.x + val.y;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 tmp = make_half2(val.x, val.y);
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
|
||||
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
|
||||
KQ_max[j0+j1] = KQ_max_new[j0+j1];
|
||||
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float diff = KQ[j][i] - kqmax[j0/nwarps];
|
||||
const float val = expf(diff);
|
||||
kqsum_add += val;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
((half *) KQ[j])[i] = val;
|
||||
#else
|
||||
KQ[j][i] = val;
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/warp_size][j1] = val;
|
||||
}
|
||||
}
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
|
||||
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
|
||||
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
|
||||
// VKQ = V @ KQ matrix multiplication:
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
|
||||
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
|
||||
|
|
@ -449,65 +477,96 @@ static __global__ void flash_attn_tile(
|
|||
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
|
||||
const int k_tile = k1 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KV_tmp_h2[k_tile*(D/2) + i] = tmp;
|
||||
#else
|
||||
KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
|
||||
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 V_k[(D/2)/warp_size];
|
||||
half2 KQ_k[ncols/nwarps];
|
||||
#else
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[ncols/nwarps];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
half2 KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i];
|
||||
#else
|
||||
V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
|
||||
#else
|
||||
KQ_k[j0/nwarps] = KQ[j][k0 + k1];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
half tmp[softmax_iter_j];
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
|
||||
&tmp, KQ[j][k0 + k1]);
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_k[j0+j1] = __half2half2(tmp[j1]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps];
|
||||
#else
|
||||
VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps];
|
||||
VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
|
||||
&KQ_k[j0], KQ[j][k0 + k1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
|
||||
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
|
@ -519,69 +578,92 @@ static __global__ void flash_attn_tile(
|
|||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
|
||||
kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
|
||||
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
|
||||
kqmax[j0/nwarps] = kqmax_new_j;
|
||||
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
|
||||
KQ_max[j0] = KQ_max_new_j;
|
||||
|
||||
const float val = expf(sink - kqmax[j0/nwarps]);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
|
||||
const float val = expf(sink - KQ_max[j0]);
|
||||
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum[j0/nwarps] += val;
|
||||
KQ_sum[j0] += val;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
|
||||
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
|
||||
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||
kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += warp_size) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
float2 tmp[cpy_ne_D];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
}
|
||||
#else
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
if (gridDim.y == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
|
@ -602,15 +684,29 @@ template <int D, bool use_logit_softcap>
|
|||
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
const int nwarps = FATTN_TILE_NTHREADS / warp_size;
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (D <= 128) {
|
||||
if (Q->ne[1] > 32) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
if (Q->ne[1] > 16) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
|
|
@ -619,6 +715,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml
|
|||
}
|
||||
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
|
|
|
|||
|
|
@ -158,41 +158,41 @@
|
|||
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
||||
#define GCN
|
||||
#endif
|
||||
|
||||
#if defined(__gfx900__) || defined(__gfx906__)
|
||||
#define GCN5
|
||||
#endif
|
||||
#endif // defined(__gfx900__) || defined(__gfx906__)
|
||||
|
||||
#if defined(__gfx803__)
|
||||
#define GCN4
|
||||
#endif
|
||||
#endif // defined(__gfx803__)
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
|
||||
#define CDNA // For the entire family
|
||||
#endif
|
||||
#if defined(GCN5) || defined(GCN4)
|
||||
#define GCN
|
||||
#endif // defined(GCN5) || defined(GCN4)
|
||||
|
||||
#if defined(__gfx942__)
|
||||
#define CDNA3
|
||||
#endif
|
||||
#endif // defined(__gfx942__)
|
||||
|
||||
#if defined(__gfx90a__)
|
||||
#define CDNA2
|
||||
#endif
|
||||
#endif // defined(__gfx90a__)
|
||||
|
||||
#if defined(__gfx908__)
|
||||
#define CDNA1
|
||||
#endif
|
||||
#endif // defined(__gfx908__)
|
||||
|
||||
#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
||||
#define CDNA // For the entire family
|
||||
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
||||
|
||||
#if defined(__GFX12__)
|
||||
#define RDNA4
|
||||
#endif
|
||||
#endif // defined(__GFX12__)
|
||||
|
||||
#if defined(__GFX11__)
|
||||
#define RDNA3
|
||||
#endif
|
||||
#endif // defined(__GFX11__)
|
||||
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
||||
|
|
@ -201,7 +201,11 @@
|
|||
|
||||
#if defined(__gfx1010__) || defined(__gfx1012__)
|
||||
#define RDNA1
|
||||
#endif
|
||||
#endif // defined(__gfx1010__) || defined(__gfx1012__)
|
||||
|
||||
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
||||
#define RDNA // For the entire family
|
||||
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
||||
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
|
|
|
|||
|
|
@ -5,8 +5,12 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
|||
message(STATUS "Metal framework found")
|
||||
|
||||
ggml_add_backend_library(ggml-metal
|
||||
ggml-metal.m
|
||||
ggml-metal.cpp
|
||||
ggml-metal-device.m
|
||||
ggml-metal-device.cpp
|
||||
ggml-metal-common.cpp
|
||||
ggml-metal-context.m
|
||||
ggml-metal-ops.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(ggml-metal PRIVATE
|
||||
|
|
@ -19,10 +23,6 @@ if (GGML_METAL_NDEBUG)
|
|||
add_compile_definitions(GGML_METAL_NDEBUG)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL_USE_BF16)
|
||||
add_compile_definitions(GGML_METAL_USE_BF16)
|
||||
endif()
|
||||
|
||||
# copy metal files to bin directory
|
||||
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
|
||||
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ struct ggml_mem_ranges {
|
|||
int debug = 0;
|
||||
};
|
||||
|
||||
struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
|
||||
ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
|
||||
auto * res = new ggml_mem_ranges;
|
||||
|
||||
res->ranges.reserve(256);
|
||||
|
|
@ -31,15 +31,15 @@ struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
|
|||
return res;
|
||||
}
|
||||
|
||||
void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
|
||||
void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {
|
||||
delete mrs;
|
||||
}
|
||||
|
||||
void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
|
||||
void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {
|
||||
mrs->ranges.clear();
|
||||
}
|
||||
|
||||
static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
|
||||
static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
|
||||
mrs->ranges.push_back(mr);
|
||||
|
||||
return true;
|
||||
|
|
@ -87,7 +87,7 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
|
|||
return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
|
||||
}
|
||||
|
||||
static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
|
||||
static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor);
|
||||
|
||||
ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
|
||||
|
|
@ -99,7 +99,7 @@ static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * t
|
|||
return ggml_mem_ranges_add(mrs, mr);
|
||||
}
|
||||
|
||||
static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
|
||||
static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor);
|
||||
|
||||
ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
|
||||
|
|
@ -111,7 +111,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * t
|
|||
return ggml_mem_ranges_add(mrs, mr);
|
||||
}
|
||||
|
||||
bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
|
||||
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (tensor->src[i]) {
|
||||
ggml_mem_ranges_add_src(mrs, tensor->src[i]);
|
||||
|
|
@ -121,7 +121,7 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
|
|||
return ggml_mem_ranges_add_dst(mrs, tensor);
|
||||
}
|
||||
|
||||
static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
|
||||
static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
|
||||
for (size_t i = 0; i < mrs->ranges.size(); i++) {
|
||||
const auto & cmp = mrs->ranges[i];
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
|
||||
static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor);
|
||||
|
||||
ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
|
||||
|
|
@ -162,7 +162,7 @@ static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_te
|
|||
return res;
|
||||
}
|
||||
|
||||
static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
|
||||
static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor);
|
||||
|
||||
ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
|
||||
|
|
@ -172,7 +172,7 @@ static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_te
|
|||
return res;
|
||||
}
|
||||
|
||||
bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
|
||||
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (tensor->src[i]) {
|
||||
if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
|
||||
|
|
@ -222,7 +222,7 @@ struct node_info {
|
|||
|
||||
static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
|
||||
// helper to add node src and dst ranges
|
||||
const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
|
||||
const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node.node->src[i]) {
|
||||
if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
|
||||
|
|
@ -246,7 +246,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
|||
};
|
||||
|
||||
// helper to check if a node can run concurrently with the existing set of nodes
|
||||
const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
|
||||
const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node.node->src[i]) {
|
||||
if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
|
||||
|
|
@ -301,10 +301,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
|||
std::vector<bool> used(n, false);
|
||||
|
||||
// the memory ranges for the set of currently concurrent nodes
|
||||
ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
|
||||
ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);
|
||||
|
||||
// the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
|
||||
ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
|
||||
ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);
|
||||
|
||||
for (int i0 = 0; i0 < n; i0++) {
|
||||
if (used[i0]) {
|
||||
|
|
@ -375,7 +375,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
|||
return res;
|
||||
}
|
||||
|
||||
void ggml_metal_graph_optimize(ggml_cgraph * gf) {
|
||||
void ggml_graph_optimize(ggml_cgraph * gf) {
|
||||
constexpr int MAX_FUSE = 16;
|
||||
|
||||
const int n = gf->n_nodes;
|
||||
|
|
|
|||
|
|
@ -25,27 +25,27 @@ enum ggml_mem_range_type {
|
|||
// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
|
||||
// tasks already in the set)
|
||||
//
|
||||
struct ggml_mem_ranges;
|
||||
typedef struct ggml_mem_ranges * ggml_mem_ranges_t;
|
||||
|
||||
struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
|
||||
void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);
|
||||
ggml_mem_ranges_t ggml_mem_ranges_init(int debug);
|
||||
void ggml_mem_ranges_free(ggml_mem_ranges_t mrs);
|
||||
|
||||
// remove all ranges from the set
|
||||
void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);
|
||||
void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs);
|
||||
|
||||
// add src or dst ranges to track
|
||||
bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
|
||||
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);
|
||||
|
||||
// return false if:
|
||||
// - new src range overlaps with any existing dst range
|
||||
// - new dst range overlaps with any existing range (src or dst)
|
||||
bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
|
||||
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);
|
||||
|
||||
// reorder the nodes in the graph to improve concurrency, while respecting fusion
|
||||
//
|
||||
// note: this implementation is generic and not specific to metal
|
||||
// if it proves to work well, we can start using it for other backends in the future
|
||||
void ggml_metal_graph_optimize(struct ggml_cgraph * gf);
|
||||
void ggml_graph_optimize(struct ggml_cgraph * gf);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml-metal-device.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
//
|
||||
// backend context
|
||||
//
|
||||
|
||||
typedef struct ggml_metal * ggml_metal_t;
|
||||
|
||||
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev);
|
||||
void ggml_metal_free(ggml_metal_t ctx);
|
||||
|
||||
void ggml_metal_synchronize(ggml_metal_t ctx);
|
||||
|
||||
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
|
||||
enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf);
|
||||
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);
|
||||
|
||||
void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb);
|
||||
void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data);
|
||||
bool ggml_metal_supports_family (ggml_metal_t ctx, int family);
|
||||
void ggml_metal_capture_next_compute(ggml_metal_t ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -0,0 +1,575 @@
|
|||
#import "ggml-metal-context.h"
|
||||
|
||||
#import "ggml-impl.h"
|
||||
#import "ggml-backend-impl.h"
|
||||
|
||||
#import "ggml-metal-impl.h"
|
||||
#import "ggml-metal-common.h"
|
||||
#import "ggml-metal-ops.h"
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#undef MIN
|
||||
#undef MAX
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
// max number of MTLCommandBuffer used to submit a graph for processing
|
||||
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
|
||||
|
||||
struct ggml_metal_command_buffer {
|
||||
id<MTLCommandBuffer> obj;
|
||||
};
|
||||
|
||||
struct ggml_metal {
|
||||
id<MTLDevice> device;
|
||||
id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]
|
||||
|
||||
ggml_metal_device_t dev;
|
||||
ggml_metal_library_t lib;
|
||||
|
||||
dispatch_queue_t d_queue;
|
||||
|
||||
// additional, inference-time compiled pipelines
|
||||
ggml_metal_pipelines_t pipelines_ext;
|
||||
|
||||
bool use_bfloat;
|
||||
bool use_fusion;
|
||||
bool use_concurrency;
|
||||
bool use_graph_optimize;
|
||||
|
||||
int debug_graph;
|
||||
int debug_fusion;
|
||||
|
||||
// how many times a given op was fused
|
||||
uint64_t fuse_cnt[GGML_OP_COUNT];
|
||||
|
||||
// capture state
|
||||
bool capture_next_compute;
|
||||
bool capture_started;
|
||||
|
||||
id<MTLCaptureScope> capture_scope;
|
||||
|
||||
// command buffer state
|
||||
int n_cb; // number of extra threads used to submit the command buffers
|
||||
int n_nodes_0; // number of nodes submitted by the main thread
|
||||
int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
|
||||
int n_nodes_per_cb;
|
||||
|
||||
struct ggml_cgraph * gf;
|
||||
|
||||
// the callback given to the thread pool
|
||||
void (^encode_async)(size_t ith);
|
||||
|
||||
// n_cb command buffers + 1 used by the main thread
|
||||
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
||||
|
||||
// extra command buffers for things like getting, setting and copying tensors
|
||||
NSMutableArray * cmd_bufs_ext;
|
||||
|
||||
// the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
|
||||
id<MTLCommandBuffer> cmd_buf_last;
|
||||
|
||||
// abort ggml_metal_graph_compute if callback returns true
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_data;
|
||||
};
|
||||
|
||||
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||
GGML_LOG_INFO("%s: allocating\n", __func__);
|
||||
|
||||
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
|
||||
// Show all the Metal device instances in the system
|
||||
NSArray * devices = MTLCopyAllDevices();
|
||||
for (id<MTLDevice> device in devices) {
|
||||
GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
|
||||
}
|
||||
[devices release]; // since it was created by a *Copy* C method
|
||||
#endif
|
||||
|
||||
// init context
|
||||
ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));
|
||||
|
||||
res->device = ggml_metal_device_get_obj(dev);
|
||||
|
||||
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]);
|
||||
|
||||
// TODO: would it be better to have one queue for the backend and one queue for the device?
|
||||
// the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
|
||||
//res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
|
||||
res->queue = ggml_metal_device_get_queue(dev);
|
||||
if (res->queue == nil) {
|
||||
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
res->dev = dev;
|
||||
res->lib = ggml_metal_device_get_library(dev);
|
||||
if (res->lib == NULL) {
|
||||
GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
|
||||
GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);
|
||||
|
||||
res->lib = ggml_metal_library_init(dev);
|
||||
if (res->lib == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);
|
||||
|
||||
free(res);
|
||||
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
|
||||
|
||||
res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||
|
||||
res->use_bfloat = props_dev->has_bfloat;
|
||||
res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
|
||||
res->debug_graph = val ? atoi(val) : 0;
|
||||
}
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
||||
res->debug_fusion = val ? atoi(val) : 0;
|
||||
}
|
||||
|
||||
res->use_graph_optimize = true;
|
||||
|
||||
if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
|
||||
res->use_graph_optimize = false;
|
||||
}
|
||||
|
||||
memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));
|
||||
|
||||
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, res->use_bfloat ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
|
||||
|
||||
res->capture_next_compute = false;
|
||||
res->capture_started = false;
|
||||
res->capture_scope = nil;
|
||||
|
||||
res->gf = nil;
|
||||
res->encode_async = nil;
|
||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||
res->cmd_bufs[i].obj = nil;
|
||||
}
|
||||
|
||||
res->cmd_bufs_ext = [[NSMutableArray alloc] init];
|
||||
|
||||
res->cmd_buf_last = nil;
|
||||
|
||||
res->pipelines_ext = ggml_metal_pipelines_init();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void ggml_metal_free(ggml_metal_t ctx) {
|
||||
GGML_LOG_INFO("%s: deallocating\n", __func__);
|
||||
|
||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||
if (ctx->cmd_bufs[i].obj) {
|
||||
[ctx->cmd_bufs[i].obj release];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {
|
||||
if (ctx->cmd_bufs_ext[i]) {
|
||||
[ctx->cmd_bufs_ext[i] release];
|
||||
}
|
||||
}
|
||||
|
||||
[ctx->cmd_bufs_ext removeAllObjects];
|
||||
[ctx->cmd_bufs_ext release];
|
||||
|
||||
if (ctx->pipelines_ext) {
|
||||
ggml_metal_pipelines_free(ctx->pipelines_ext);
|
||||
ctx->pipelines_ext = nil;
|
||||
}
|
||||
|
||||
if (ctx->debug_fusion > 0) {
|
||||
GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
|
||||
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
||||
if (ctx->fuse_cnt[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// note: cannot use ggml_log here
|
||||
GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Block_release(ctx->encode_async);
|
||||
|
||||
//[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]
|
||||
|
||||
dispatch_release(ctx->d_queue);
|
||||
|
||||
free(ctx);
|
||||
}
|
||||
|
||||
void ggml_metal_synchronize(ggml_metal_t ctx) {
|
||||
// wait for any backend operations to finish
|
||||
if (ctx->cmd_buf_last) {
|
||||
[ctx->cmd_buf_last waitUntilCompleted];
|
||||
ctx->cmd_buf_last = nil;
|
||||
}
|
||||
|
||||
// release any completed command buffers
|
||||
if (ctx->cmd_bufs_ext.count > 0) {
|
||||
for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];
|
||||
|
||||
MTLCommandBufferStatus status = [cmd_buf status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
[cmd_buf release];
|
||||
}
|
||||
|
||||
[ctx->cmd_bufs_ext removeAllObjects];
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {
|
||||
if (!t) {
|
||||
return (struct ggml_metal_buffer_id) { nil, 0 };
|
||||
}
|
||||
|
||||
ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
|
||||
|
||||
return ggml_metal_buffer_get_id(buffer->context, t);
|
||||
}
|
||||
|
||||
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
@autoreleasepool {
|
||||
// wrap the source data into a Metal buffer
|
||||
id<MTLBuffer> buf_src = [ctx->device newBufferWithBytes:data
|
||||
length:size
|
||||
options:MTLResourceStorageModeShared];
|
||||
|
||||
struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);
|
||||
if (bid_dst.metal == nil) {
|
||||
GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
|
||||
}
|
||||
|
||||
bid_dst.offs += offset;
|
||||
|
||||
// queue the copy operation into the queue of the Metal context
|
||||
// this will be queued at the end, after any currently ongoing GPU operations
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||
|
||||
[encoder copyFromBuffer:buf_src
|
||||
sourceOffset:0
|
||||
toBuffer:bid_dst.metal
|
||||
destinationOffset:bid_dst.offs
|
||||
size:size];
|
||||
|
||||
[encoder endEncoding];
|
||||
[cmd_buf commit];
|
||||
|
||||
// do not wait here for completion
|
||||
//[cmd_buf waitUntilCompleted];
|
||||
|
||||
// instead, remember a reference to the command buffer and wait for it later if needed
|
||||
[ctx->cmd_bufs_ext addObject:cmd_buf];
|
||||
ctx->cmd_buf_last = cmd_buf;
|
||||
|
||||
[cmd_buf retain];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
@autoreleasepool {
|
||||
id<MTLBuffer> buf_dst = [ctx->device newBufferWithBytesNoCopy:data
|
||||
length:size
|
||||
options:MTLResourceStorageModeShared
|
||||
deallocator:nil];
|
||||
|
||||
struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);
|
||||
if (bid_src.metal == nil) {
|
||||
GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
|
||||
}
|
||||
|
||||
bid_src.offs += offset;
|
||||
|
||||
// queue the copy operation into the queue of the Metal context
|
||||
// this will be queued at the end, after any currently ongoing GPU operations
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||
|
||||
[encoder copyFromBuffer:bid_src.metal
|
||||
sourceOffset:bid_src.offs
|
||||
toBuffer:buf_dst
|
||||
destinationOffset:0
|
||||
size:size];
|
||||
|
||||
[encoder endEncoding];
|
||||
[cmd_buf commit];
|
||||
|
||||
// do not wait here for completion
|
||||
//[cmd_buf waitUntilCompleted];
|
||||
|
||||
// instead, remember a reference to the command buffer and wait for it later if needed
|
||||
[ctx->cmd_bufs_ext addObject:cmd_buf];
|
||||
ctx->cmd_buf_last = cmd_buf;
|
||||
|
||||
[cmd_buf retain];
|
||||
}
|
||||
}
|
||||
|
||||
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
|
||||
// number of nodes encoded by the main thread (empirically determined)
|
||||
const int n_main = 64;
|
||||
|
||||
// number of threads in addition to the main thread
|
||||
const int n_cb = ctx->n_cb;
|
||||
|
||||
// submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
|
||||
// the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
|
||||
// while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
|
||||
// each thread creates it's own command buffer and enqueues the ops in parallel
|
||||
//
|
||||
// tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2
|
||||
|
||||
@autoreleasepool {
|
||||
ctx->gf = gf;
|
||||
|
||||
ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
|
||||
ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
|
||||
|
||||
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
|
||||
|
||||
const bool use_capture = ctx->capture_next_compute;
|
||||
if (use_capture) {
|
||||
ctx->capture_next_compute = false;
|
||||
|
||||
// make sure all previous computations have finished before starting the capture
|
||||
if (ctx->cmd_buf_last) {
|
||||
[ctx->cmd_buf_last waitUntilCompleted];
|
||||
ctx->cmd_buf_last = nil;
|
||||
}
|
||||
|
||||
if (!ctx->capture_started) {
|
||||
// create capture scope
|
||||
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
|
||||
|
||||
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
||||
descriptor.captureObject = ctx->capture_scope;
|
||||
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
|
||||
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
|
||||
|
||||
NSError * error = nil;
|
||||
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
||||
GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
|
||||
} else {
|
||||
[ctx->capture_scope beginScope];
|
||||
ctx->capture_started = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// the main thread commits the first few commands immediately
|
||||
// cmd_buf[n_cb]
|
||||
{
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
[cmd_buf retain];
|
||||
|
||||
if (ctx->cmd_bufs[n_cb].obj) {
|
||||
[ctx->cmd_bufs[n_cb].obj release];
|
||||
}
|
||||
ctx->cmd_bufs[n_cb].obj = cmd_buf;
|
||||
|
||||
[cmd_buf enqueue];
|
||||
|
||||
ctx->encode_async(n_cb);
|
||||
}
|
||||
|
||||
// remember the command buffer for the next iteration
|
||||
ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;
|
||||
|
||||
// prepare the rest of the command buffers asynchronously (optional)
|
||||
// cmd_buf[0.. n_cb)
|
||||
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
[cmd_buf retain];
|
||||
|
||||
if (ctx->cmd_bufs[cb_idx].obj) {
|
||||
[ctx->cmd_bufs[cb_idx].obj release];
|
||||
}
|
||||
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
|
||||
|
||||
// always enqueue the first two command buffers
|
||||
// enqueue all of the command buffers if we don't need to abort
|
||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||
[cmd_buf enqueue];
|
||||
|
||||
// update the pointer to the last queued command buffer
|
||||
// this is needed to implement synchronize()
|
||||
ctx->cmd_buf_last = cmd_buf;
|
||||
}
|
||||
}
|
||||
|
||||
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
|
||||
|
||||
// for debugging: block until graph is computed
|
||||
//[ctx->cmd_buf_last waitUntilCompleted];
|
||||
|
||||
// enter here only when capturing in order to wait for all computation to finish
|
||||
// otherwise, we leave the graph to compute asynchronously
|
||||
if (!use_capture && ctx->capture_started) {
|
||||
// wait for completion and check status of each command buffer
|
||||
// needed to detect if the device ran out-of-memory for example (#1881)
|
||||
{
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
|
||||
[cmd_buf waitUntilCompleted];
|
||||
|
||||
MTLCommandBufferStatus status = [cmd_buf status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_cb; ++i) {
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
|
||||
[cmd_buf waitUntilCompleted];
|
||||
|
||||
MTLCommandBufferStatus status = [cmd_buf status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
|
||||
if (!next_buffer) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
|
||||
if (next_queued) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
|
||||
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
|
||||
return GGML_STATUS_ABORTED;
|
||||
}
|
||||
|
||||
[next_buffer commit];
|
||||
}
|
||||
|
||||
[ctx->capture_scope endScope];
|
||||
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
||||
}
|
||||
}
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
if (ctx->use_graph_optimize) {
|
||||
ggml_graph_optimize(gf);
|
||||
}
|
||||
|
||||
//printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
|
||||
void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
|
||||
if (ctx->n_cb != n_cb) {
|
||||
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
|
||||
|
||||
if (ctx->n_cb > 2) {
|
||||
GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx->encode_async) {
|
||||
Block_release(ctx->encode_async);
|
||||
}
|
||||
|
||||
ctx->encode_async = Block_copy(^(size_t iter) {
|
||||
const int cb_idx = iter;
|
||||
const int n_cb_l = ctx->n_cb;
|
||||
|
||||
const int n_nodes_0 = ctx->n_nodes_0;
|
||||
const int n_nodes_1 = ctx->n_nodes_1;
|
||||
|
||||
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
||||
|
||||
int idx_start = 0;
|
||||
int idx_end = n_nodes_0;
|
||||
|
||||
if (cb_idx < n_cb_l) {
|
||||
idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
||||
idx_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
|
||||
|
||||
ggml_metal_op_t ctx_op = ggml_metal_op_init(
|
||||
ctx->dev,
|
||||
cmd_buf,
|
||||
ctx->gf,
|
||||
idx_start,
|
||||
idx_end,
|
||||
ctx->use_fusion,
|
||||
ctx->use_concurrency,
|
||||
ctx->capture_next_compute,
|
||||
ctx->debug_graph,
|
||||
ctx->debug_fusion);
|
||||
|
||||
for (int idx = idx_start; idx < idx_end;) {
|
||||
const int res = ggml_metal_op_encode(ctx_op, idx);
|
||||
if (res == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
idx += res;
|
||||
}
|
||||
|
||||
ggml_metal_op_free(ctx_op);
|
||||
|
||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||
[cmd_buf commit];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {
|
||||
ctx->abort_callback = abort_callback;
|
||||
ctx->abort_callback_data = user_data;
|
||||
}
|
||||
|
||||
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
|
||||
GGML_ASSERT(ctx->device != nil);
|
||||
|
||||
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
||||
}
|
||||
|
||||
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
|
||||
ctx->capture_next_compute = true;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,227 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct ggml_metal_buffer_id {
|
||||
void * metal; // id<MTLBuffer>
|
||||
size_t offs;
|
||||
};
|
||||
|
||||
typedef struct ggml_metal_device * ggml_metal_device_t;
|
||||
|
||||
//
|
||||
// MTLFunctionConstantValues wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_cv * ggml_metal_cv_t;
|
||||
|
||||
ggml_metal_cv_t ggml_metal_cv_init(void);
|
||||
void ggml_metal_cv_free(ggml_metal_cv_t cv);
|
||||
|
||||
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
|
||||
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
|
||||
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
|
||||
|
||||
//
|
||||
// MTLComputePipelineState wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
|
||||
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
|
||||
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
|
||||
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
|
||||
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
|
||||
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
// a collection of pipelines
|
||||
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
|
||||
|
||||
ggml_metal_pipelines_t ggml_metal_pipelines_init(void);
|
||||
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
|
||||
|
||||
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
|
||||
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
|
||||
|
||||
//
|
||||
// MTLCommandBuffer wrapper
|
||||
//
|
||||
|
||||
typedef void * ggml_metal_cmd_buf_t;
|
||||
|
||||
//
|
||||
// MTLComputeCommandEncoder wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_encoder * ggml_metal_encoder_t;
|
||||
|
||||
ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent);
|
||||
void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
|
||||
void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
|
||||
void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
|
||||
|
||||
void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx);
|
||||
|
||||
void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2);
|
||||
|
||||
void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
|
||||
|
||||
//
|
||||
// MTLLibrary wrapper
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_library * ggml_metal_library_t;
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev);
|
||||
void ggml_metal_library_free(ggml_metal_library_t lib);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
int32_t nsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
int32_t nsg,
|
||||
int32_t nwg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t dv,
|
||||
int32_t nwg);
|
||||
|
||||
//
|
||||
// device
|
||||
//
|
||||
|
||||
struct ggml_metal_device_props {
|
||||
char name[128];
|
||||
|
||||
size_t max_buffer_size;
|
||||
size_t max_working_set_size;
|
||||
size_t max_theadgroup_memory_size;
|
||||
|
||||
bool has_simdgroup_reduction;
|
||||
bool has_simdgroup_mm;
|
||||
bool has_unified_memory;
|
||||
bool has_bfloat;
|
||||
bool use_residency_sets;
|
||||
bool use_shared_buffers;
|
||||
|
||||
bool supports_gpu_family_apple7;
|
||||
};
|
||||
|
||||
ggml_metal_device_t ggml_metal_device_init(void);
|
||||
void ggml_metal_device_free(ggml_metal_device_t dev);
|
||||
|
||||
// return a singleton that is automatically destroyed when the program exits
|
||||
ggml_metal_device_t ggml_metal_device_get(void);
|
||||
|
||||
void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id<MTLDevice>
|
||||
void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue>
|
||||
|
||||
ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev);
|
||||
|
||||
void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
|
||||
bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
|
||||
|
||||
const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev);
|
||||
|
||||
//
|
||||
// device buffers
|
||||
//
|
||||
|
||||
typedef struct ggml_metal_buffer * ggml_metal_buffer_t;
|
||||
|
||||
ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared);
|
||||
ggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size);
|
||||
|
||||
void ggml_metal_buffer_free (ggml_metal_buffer_t buf);
|
||||
void * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf);
|
||||
bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
|
||||
|
||||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value);
|
||||
|
||||
// finds the Metal buffer that contains the tensor data on the GPU device
|
||||
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
||||
// Metal buffer based on the host memory pointer
|
||||
//
|
||||
struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -8,6 +8,9 @@
|
|||
//
|
||||
// TODO: for optimal performance, become function of the device and work size
|
||||
|
||||
#define N_R0_F 2
|
||||
#define N_SG_F 4
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
|
|
@ -72,6 +75,7 @@
|
|||
#define FC_FLASH_ATTN_EXT 100
|
||||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
#define FC_MUL_MV 400
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
|
|
@ -165,6 +169,16 @@ typedef struct {
|
|||
uint64_t nb3;
|
||||
} ggml_metal_kargs_repeat;
|
||||
|
||||
typedef struct {
|
||||
float scale;
|
||||
float bias;
|
||||
} ggml_metal_kargs_scale;
|
||||
|
||||
typedef struct {
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_clamp;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
|
@ -453,7 +467,7 @@ typedef struct {
|
|||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t n_groups;
|
||||
int32_t ngrp;
|
||||
float eps;
|
||||
} ggml_metal_kargs_group_norm;
|
||||
|
||||
|
|
@ -506,14 +520,6 @@ typedef struct {
|
|||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne10;
|
||||
int64_t ne11;
|
||||
int64_t ne12;
|
||||
int64_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
|
|
@ -547,12 +553,6 @@ typedef struct {
|
|||
int32_t n_head_log2;
|
||||
} ggml_metal_kargs_soft_max;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int n_past;
|
||||
} ggml_metal_kargs_diag_mask_inf;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
|
@ -579,7 +579,7 @@ typedef struct {
|
|||
int64_t n_group;
|
||||
int64_t n_seq_tokens;
|
||||
int64_t n_seqs;
|
||||
int64_t s_off;
|
||||
uint64_t s_off;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
|
|
@ -719,7 +719,12 @@ typedef struct {
|
|||
int64_t IW;
|
||||
int64_t OH;
|
||||
int64_t OW;
|
||||
int64_t parallel_elements;
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_pool_2d;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
uint64_t nb01;
|
||||
} ggml_metal_kargs_argmax;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,81 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml-metal-device.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct ggml_metal_op * ggml_metal_op_t;
|
||||
|
||||
ggml_metal_op_t ggml_metal_op_init(
|
||||
ggml_metal_device_t dev,
|
||||
ggml_metal_cmd_buf_t cmd_buf,
|
||||
struct ggml_cgraph * gf,
|
||||
int idx_start,
|
||||
int idx_end,
|
||||
bool use_fusion,
|
||||
bool use_concurrency,
|
||||
bool use_capture,
|
||||
int debug_graph,
|
||||
int debug_fusion);
|
||||
|
||||
void ggml_metal_op_free(ggml_metal_op_t ctx);
|
||||
|
||||
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
|
||||
|
||||
//
|
||||
// available ops:
|
||||
//
|
||||
|
||||
// tokens per expert
|
||||
size_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op);
|
||||
|
||||
// id map [n_tokens, n_expert]
|
||||
size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
|
||||
|
||||
// return true if we should use the FA vector kernel for this op
|
||||
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
|
||||
|
||||
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -0,0 +1,718 @@
|
|||
#include "ggml-metal.h"
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include "ggml-metal-device.h"
|
||||
#include "ggml-metal-context.h"
|
||||
#include "ggml-metal-ops.h"
|
||||
|
||||
// globals
|
||||
|
||||
// initialized in ggml_backend_metal_reg
|
||||
static ggml_backend_reg g_ggml_metal_reg;
|
||||
static ggml_backend_device g_ggml_metal_device;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// backend interface
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// shared buffer
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_free(ctx);
|
||||
}
|
||||
|
||||
static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
return ggml_metal_buffer_get_base(ctx);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_clear(ctx, value);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
// private buffer
|
||||
|
||||
static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_free(ctx);
|
||||
}
|
||||
|
||||
static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
return ggml_metal_buffer_get_base(ctx);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
ggml_metal_buffer_clear(ctx, value);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_private_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
//
|
||||
// buffer types
|
||||
//
|
||||
|
||||
// common method for allocating shread or private Metal buffers
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared);
|
||||
|
||||
ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res)
|
||||
? ggml_backend_metal_buffer_shared_i
|
||||
: ggml_backend_metal_buffer_private_i;
|
||||
|
||||
return ggml_backend_buffer_init(buft, buf_i, res, size);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
size_t res = ggml_nbytes(tensor);
|
||||
|
||||
// some operations require additional memory for fleeting data:
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
res += ggml_metal_op_mul_mat_id_extra_tpe(tensor);
|
||||
res += ggml_metal_op_mul_mat_id_extra_ids(tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
// default (shared) buffer type
|
||||
|
||||
static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
|
||||
static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
|
||||
},
|
||||
/* .device = */ &g_ggml_metal_device,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_buffer_type_metal;
|
||||
}
|
||||
|
||||
// default (private) buffer type
|
||||
|
||||
static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal_Private";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
|
||||
static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_metal_buffer_type_private_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
|
||||
},
|
||||
/* .device = */ &g_ggml_metal_device,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_buffer_type_metal;
|
||||
}
|
||||
|
||||
// mapped buffer type
|
||||
|
||||
static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal_Mapped";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
// for mapped buffers, prefer shared memory
|
||||
return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
|
||||
// note: not obvious, but this buffer type still needs to implement .alloc_buffer:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
|
||||
static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
|
||||
},
|
||||
/* .device = */ &g_ggml_metal_device,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_buffer_type_mapped_metal;
|
||||
}
|
||||
|
||||
// backend
|
||||
|
||||
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
// wait for any ongoing async operations to finish
|
||||
ggml_metal_synchronize(ctx);
|
||||
|
||||
ggml_metal_free(ctx);
|
||||
|
||||
free(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_synchronize(ctx);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_set_tensor_async(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_get_tensor_async(ctx, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(backend_src);
|
||||
GGML_UNUSED(backend_dst);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
return ggml_metal_graph_compute(ctx, cgraph);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_graph_optimize(ctx, cgraph);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_set_n_cb(ctx, n_cb);
|
||||
|
||||
}
|
||||
|
||||
static ggml_backend_i ggml_backend_metal_i = {
|
||||
/* .get_name = */ ggml_backend_metal_name,
|
||||
/* .free = */ ggml_backend_metal_free,
|
||||
/* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
|
||||
/* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
|
||||
/* .synchronize = */ ggml_backend_metal_synchronize,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
/* .graph_plan_free = */ NULL,
|
||||
/* .graph_plan_update = */ NULL,
|
||||
/* .graph_plan_compute = */ NULL,
|
||||
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
||||
|
||||
// the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal
|
||||
// in any case, these docs seem relevant if we ever decide to implement it:
|
||||
// https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ ggml_backend_metal_graph_optimize,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_metal_guid(void) {
|
||||
static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
|
||||
return &guid;
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_metal_init(void) {
|
||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_t ctx = ggml_metal_init(ctx_dev);
|
||||
if (ctx == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));
|
||||
|
||||
*backend = {
|
||||
/* .guid = */ ggml_backend_metal_guid(),
|
||||
/* .interface = */ ggml_backend_metal_i,
|
||||
/* .device = */ dev,
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
ggml_backend_metal_set_n_cb(backend, 1);
|
||||
|
||||
return backend;
|
||||
}
|
||||
|
||||
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
||||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
||||
}
|
||||
|
||||
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_set_abort_callback(ctx, abort_callback, user_data);
|
||||
}
|
||||
|
||||
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
return ggml_metal_supports_family(ctx, family);
|
||||
}
|
||||
|
||||
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
ggml_metal_t ctx = (ggml_metal_t)backend->context;
|
||||
|
||||
ggml_metal_capture_next_compute(ctx);
|
||||
}
|
||||
|
||||
// backend device
|
||||
|
||||
static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
return ggml_metal_device_get_props(ctx_dev)->name;
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_device_get_memory(ctx_dev, free, total);
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_metal_device_get_name(dev);
|
||||
props->description = ggml_backend_metal_device_get_description(dev);
|
||||
props->type = ggml_backend_metal_device_get_type(dev);
|
||||
|
||||
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
props->caps = {
|
||||
/* .async = */ true,
|
||||
/* .host_buffer = */ false,
|
||||
/* .buffer_from_host_ptr = */ true,
|
||||
/* .events = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_t ctx = ggml_metal_init(ctx_dev);
|
||||
if (ctx == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));
|
||||
|
||||
*backend = {
|
||||
/* .guid = */ ggml_backend_metal_guid(),
|
||||
/* .interface = */ ggml_backend_metal_i,
|
||||
/* .device = */ dev,
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
ggml_backend_metal_set_n_cb(backend, 1);
|
||||
|
||||
return backend;
|
||||
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
|
||||
|
||||
return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private();
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size);
|
||||
|
||||
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
|
||||
|
||||
return ggml_metal_device_supports_op(ctx_dev, op);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
return
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->ne[1];
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return op->ne[2];
|
||||
default:
|
||||
return ggml_nrows(op);
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
const int min_batch_size = 32;
|
||||
|
||||
return (op->op == GGML_OP_MUL_MAT ||
|
||||
op->op == GGML_OP_MUL_MAT_ID) &&
|
||||
get_op_batch_size(op) >= min_batch_size;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
GGML_UNUSED(op);
|
||||
}
|
||||
|
||||
static ggml_backend_device_i ggml_backend_metal_device_i = {
|
||||
/* .get_name = */ ggml_backend_metal_device_get_name,
|
||||
/* .get_description = */ ggml_backend_metal_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_metal_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_metal_device_get_type,
|
||||
/* .get_props = */ ggml_backend_metal_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_metal_device_init,
|
||||
/* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,
|
||||
/* .supports_op = */ ggml_backend_metal_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_metal_device_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_metal_device_offload_op,
|
||||
/* .event_new = */ NULL,
|
||||
/* .event_free = */ NULL,
|
||||
/* .event_synchronize = */ NULL,
|
||||
};
|
||||
|
||||
// backend registry
|
||||
|
||||
static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
|
||||
return "Metal";
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
|
||||
return 1;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(index == 0);
|
||||
|
||||
return &g_ggml_metal_device;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(index);
|
||||
}
|
||||
|
||||
static ggml_backend_feature g_ggml_backend_metal_features[] = {
|
||||
#if defined(GGML_METAL_EMBED_LIBRARY)
|
||||
{ "EMBED_LIBRARY", "1" },
|
||||
#endif
|
||||
{ NULL, NULL },
|
||||
};
|
||||
|
||||
static ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) {
|
||||
return g_ggml_backend_metal_features;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
if (strcmp(name, "ggml_backend_get_features") == 0) {
|
||||
return (void *)ggml_backend_metal_get_features;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
||||
/* .get_name = */ ggml_backend_metal_reg_get_name,
|
||||
/* .device_count = */ ggml_backend_metal_reg_device_count,
|
||||
/* .device_get = */ ggml_backend_metal_reg_device_get,
|
||||
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
|
||||
};
|
||||
|
||||
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
||||
{
|
||||
g_ggml_metal_reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_metal_reg_i,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
g_ggml_metal_device = {
|
||||
/* .iface = */ ggml_backend_metal_device_i,
|
||||
/* .reg = */ &g_ggml_metal_reg,
|
||||
/* .context = */ ggml_metal_device_get(),
|
||||
};
|
||||
}
|
||||
|
||||
return &g_ggml_metal_reg;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -116,6 +116,10 @@ struct webgpu_context_struct {
|
|||
wgpu::Queue queue;
|
||||
wgpu::Limits limits;
|
||||
|
||||
// Separate this out from limits since on some Metal systems, the limit returned by
|
||||
// querying the limits is higher than the actual allowed maximum.
|
||||
uint32_t max_wg_size_x;
|
||||
|
||||
std::recursive_mutex mutex;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
|
|
@ -124,7 +128,15 @@ struct webgpu_context_struct {
|
|||
wgpu::ComputePipeline memset_pipeline;
|
||||
wgpu::ComputePipeline mul_mat_pipeline[30][2];
|
||||
wgpu::ComputePipeline set_rows_pipeline;
|
||||
wgpu::ComputePipeline get_rows_pipeline[30];
|
||||
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
|
||||
wgpu::ComputePipeline cpy_pipeline;
|
||||
wgpu::ComputePipeline add_pipeline[2];
|
||||
wgpu::ComputePipeline add_ip_pipeline[2];
|
||||
wgpu::ComputePipeline mul_pipeline[2];
|
||||
wgpu::ComputePipeline mul_ip_pipeline[2];
|
||||
wgpu::ComputePipeline rms_norm_pipeline;
|
||||
wgpu::ComputePipeline rms_norm_ip_pipeline;
|
||||
|
||||
size_t memset_bytes_per_thread;
|
||||
|
||||
|
|
@ -232,14 +244,15 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
|
|||
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
||||
if (ctx->callback_futures.empty()) {
|
||||
// no existing callbacks, wait on queue submission
|
||||
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
||||
}
|
||||
}),
|
||||
UINT64_MAX);
|
||||
ctx->instance.WaitAny(
|
||||
ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
|
||||
std::string(message).c_str());
|
||||
}
|
||||
}),
|
||||
UINT64_MAX);
|
||||
} else {
|
||||
// existing callbacks, wait on them
|
||||
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
|
||||
|
|
@ -286,10 +299,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
|
|||
// Check for errrors in SET_ROWS operations
|
||||
for (auto & error_bufs : staged_set_row_error_bufs) {
|
||||
wgpu::Future f = error_bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read,
|
||||
0,
|
||||
error_bufs.host_buf.GetSize(),
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
|
||||
|
|
@ -311,10 +321,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
|
|||
wgpu::MapMode mode,
|
||||
size_t offset,
|
||||
size_t size) {
|
||||
ctx->instance.WaitAny(buffer.MapAsync(mode,
|
||||
offset,
|
||||
size,
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
|
||||
|
|
@ -351,7 +358,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
|
|||
std::vector<uint32_t> params,
|
||||
std::vector<wgpu::BindGroupEntry> bind_group_entries,
|
||||
uint32_t wg_x,
|
||||
bool submit_and_wait = false) {
|
||||
const char * bind_group_label = nullptr,
|
||||
bool submit_and_wait = false) {
|
||||
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
|
||||
|
||||
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
|
||||
|
|
@ -372,6 +380,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
|
|||
bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
|
||||
bind_group_desc.entryCount = bind_group_entries.size();
|
||||
bind_group_desc.entries = bind_group_entries.data();
|
||||
if (bind_group_label) {
|
||||
bind_group_desc.label = bind_group_label;
|
||||
}
|
||||
wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
|
||||
|
||||
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
||||
|
|
@ -415,9 +426,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
|
|||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
|
||||
};
|
||||
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
|
||||
size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread;
|
||||
uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
|
|
@ -461,26 +472,26 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
|
|||
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
||||
}
|
||||
|
||||
// Used to determine if two tensors are the same for in-place operations
|
||||
static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
|
||||
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
|
||||
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = { ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Logical shape — same for both tensors even if permuted
|
||||
(uint32_t) src->ne[0],
|
||||
(uint32_t) src->ne[1],
|
||||
(uint32_t) src->ne[2],
|
||||
(uint32_t) src->ne[3] };
|
||||
std::vector<uint32_t> params = {
|
||||
ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Logical shape — same for both tensors even if permuted
|
||||
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
|
|
@ -493,9 +504,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
|
||||
|
|
@ -509,27 +520,21 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|||
error_bufs.host_buf.Unmap();
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
|
||||
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
|
||||
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Shape of src
|
||||
(uint32_t) src->ne[0],
|
||||
(uint32_t) src->ne[1],
|
||||
(uint32_t) src->ne[2],
|
||||
(uint32_t) src->ne[3],
|
||||
// Shape of idx
|
||||
(uint32_t) (idx->ne[1]),
|
||||
(uint32_t) (idx->ne[2]) };
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
|
||||
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Shape of src
|
||||
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
|
||||
// Shape of idx
|
||||
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
|
|
@ -547,13 +552,55 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|||
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
||||
|
||||
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
||||
ctx->staged_set_row_error_bufs.push_back(error_bufs);
|
||||
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
|
||||
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Shape of dst
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
|
||||
// Shape of idx
|
||||
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(idx),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, idx),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, idx) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size;
|
||||
|
||||
wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type];
|
||||
if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) {
|
||||
pipeline = ctx->get_rows_f32_no_vec_pipeline;
|
||||
}
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
|
@ -593,7 +640,104 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
|
|||
|
||||
uint32_t wg_x =
|
||||
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
|
||||
ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
wgpu::ComputePipeline & pipeline,
|
||||
bool in_place) {
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) ggml_nelements(dst),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src0->ne[0],
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) src0->ne[2],
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
|
||||
};
|
||||
if (!in_place) {
|
||||
entries.push_back({ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
}
|
||||
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool in_place = ggml_webgpu_tensor_equal(src, dst);
|
||||
|
||||
uint32_t eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
};
|
||||
if (!in_place) {
|
||||
params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
|
||||
}
|
||||
params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
|
||||
params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
|
||||
params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
|
||||
if (!in_place) {
|
||||
params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
|
||||
params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
|
||||
params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
|
||||
}
|
||||
params.push_back((uint32_t) src->ne[0]);
|
||||
params.push_back((uint32_t) src->ne[1]);
|
||||
params.push_back((uint32_t) src->ne[2]);
|
||||
params.push_back((uint32_t) src->ne[3]);
|
||||
params.push_back(eps); // epsilon, will be bitcast to float in shader
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
|
||||
};
|
||||
if (!in_place) {
|
||||
entries.push_back({ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
}
|
||||
|
||||
wgpu::ComputePipeline pipeline;
|
||||
if (in_place) {
|
||||
pipeline = ctx->rms_norm_ip_pipeline;
|
||||
} else {
|
||||
pipeline = ctx->rms_norm_pipeline;
|
||||
}
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
|
|
@ -615,20 +759,34 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|||
case GGML_OP_RESHAPE:
|
||||
return false;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_webgpu_cpy(ctx, src0, node);
|
||||
break;
|
||||
}
|
||||
ggml_webgpu_cpy(ctx, src0, node);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
ggml_webgpu_set_rows(ctx, src0, src1, node);
|
||||
break;
|
||||
}
|
||||
ggml_webgpu_set_rows(ctx, src0, src1, node);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_webgpu_get_rows(ctx, src0, src1, node);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
||||
break;
|
||||
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
|
||||
} else {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
|
||||
} else {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -731,8 +889,8 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
|
||||
}
|
||||
// memset the remaining bytes
|
||||
ggml_backend_webgpu_buffer_memset(
|
||||
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
|
||||
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
|
||||
remaining_size);
|
||||
} else {
|
||||
// wait for WriteBuffer to complete
|
||||
ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
|
||||
|
|
@ -766,11 +924,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|||
if (webgpu_ctx->get_tensor_staging_buf) {
|
||||
webgpu_ctx->get_tensor_staging_buf.Destroy();
|
||||
}
|
||||
ggml_webgpu_create_buffer(device,
|
||||
webgpu_ctx->get_tensor_staging_buf,
|
||||
final_size,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
||||
"get_tensor_staging_buf");
|
||||
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
|
||||
}
|
||||
|
||||
// Copy the data from the buffer to the staging buffer
|
||||
|
|
@ -824,8 +979,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
|
|||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
||||
|
||||
wgpu::Buffer buf;
|
||||
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
|
||||
buf,
|
||||
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
|
||||
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
||||
"allocated_buffer");
|
||||
|
|
@ -890,9 +1044,17 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
|
|||
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
|
||||
}
|
||||
|
||||
// The max workgroup size is a common constant
|
||||
static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants(1);
|
||||
constants[0].key = "wg_size";
|
||||
constants[0].value = webgpu_ctx->max_wg_size_x;
|
||||
return constants;
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
|
||||
// we use the maximum workgroup size for the memset pipeline
|
||||
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
||||
size_t max_wg_size = webgpu_ctx->max_wg_size_x;
|
||||
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
|
||||
// Size the bytes_per_thread so that the largest buffer size can be handled
|
||||
webgpu_ctx->memset_bytes_per_thread =
|
||||
|
|
@ -906,109 +1068,142 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
|
|||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f32_f32,
|
||||
"mul_mat_f32_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
||||
wgsl_mul_mat_f16_f16,
|
||||
"mul_mat_f16_f16");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f16_f32,
|
||||
"mul_mat_f16_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_0_f32,
|
||||
"mul_mat_q4_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_1_f32,
|
||||
"mul_mat_q4_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_0_f32,
|
||||
"mul_mat_q5_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_1_f32,
|
||||
"mul_mat_q5_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q8_0_f32,
|
||||
"mul_mat_q8_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q2_k_f32,
|
||||
"mul_mat_q2_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q3_k_f32,
|
||||
"mul_mat_q3_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_k_f32,
|
||||
"mul_mat_q4_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_k_f32,
|
||||
"mul_mat_q5_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q6_k_f32,
|
||||
"mul_mat_q6_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xxs_f32,
|
||||
"mul_mat_iq2_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xs_f32,
|
||||
"mul_mat_iq2_xs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_s_f32,
|
||||
"mul_mat_iq2_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_xxs_f32,
|
||||
"mul_mat_iq3_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_s_f32,
|
||||
"mul_mat_iq3_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_s_f32,
|
||||
"mul_mat_iq1_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_m_f32,
|
||||
"mul_mat_iq1_m_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_nl_f32,
|
||||
"mul_mat_iq4_nl_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device,
|
||||
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_xs_f32,
|
||||
"mul_mat_iq4_xs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
||||
wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants(1);
|
||||
constants[0].key = "wg_size";
|
||||
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
||||
ggml_webgpu_create_pipeline(
|
||||
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
|
||||
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
|
||||
"get_rows_f32_vec", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
|
||||
"get_rows_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
|
||||
"get_rows_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32,
|
||||
"get_rows_i32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0,
|
||||
"get_rows_q4_0", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1,
|
||||
"get_rows_q4_1", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0,
|
||||
"get_rows_q5_0", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1,
|
||||
"get_rows_q5_1", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0,
|
||||
"get_rows_q8_0", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k,
|
||||
"get_rows_q2_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k,
|
||||
"get_rows_q3_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k,
|
||||
"get_rows_q4_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k,
|
||||
"get_rows_q5_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k,
|
||||
"get_rows_q6_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS],
|
||||
wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS],
|
||||
wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s,
|
||||
"get_rows_iq2_s", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS],
|
||||
wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s,
|
||||
"get_rows_iq3_s", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s,
|
||||
"get_rows_iq1_s", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m,
|
||||
"get_rows_iq1_m", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL],
|
||||
wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS],
|
||||
wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants(1);
|
||||
constants[0].key = "wg_size";
|
||||
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
|
||||
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
|
||||
constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
|
||||
constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
|
||||
"add_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
|
||||
"add_in_place_f16", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
|
||||
constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
|
||||
constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
|
||||
"mul_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
|
||||
"mul_in_place_f16", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
|
||||
constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
|
||||
"rms_norm_in_place", constants);
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
|
|
@ -1058,24 +1253,77 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
|
|||
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
|
||||
}
|
||||
|
||||
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
GGML_UNUSED(dev);
|
||||
static bool ggml_webgpu_supported_qtype(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
|
||||
webgpu_context webgpu_ctx = ctx->webgpu_ctx;
|
||||
|
||||
ggml_tensor * src0 = op->src[0];
|
||||
ggml_tensor * src1 = op->src[1];
|
||||
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
|
||||
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
|
||||
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
|
||||
(src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supports_op = false;
|
||||
switch (op->op) {
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_RESHAPE:
|
||||
return true;
|
||||
supports_op = true;
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_MUL:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
|
||||
(op->src[1]->type == op->type);
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
|
||||
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
|
||||
supports_op = (op->type == GGML_TYPE_F32);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
supports_op = (op->src[0]->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
|
|
@ -1099,17 +1347,30 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return true;
|
||||
supports_op = true;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
if (!supports_op) {
|
||||
WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
|
||||
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
|
||||
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
|
||||
}
|
||||
#endif
|
||||
return supports_op;
|
||||
}
|
||||
|
||||
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
|
||||
|
|
@ -1155,18 +1416,20 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
||||
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
ctx->instance.WaitAny(
|
||||
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
ctx->adapter = std::move(adapter);
|
||||
}), UINT64_MAX);
|
||||
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
ctx->adapter = std::move(adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
GGML_ASSERT(ctx->adapter != nullptr);
|
||||
|
||||
ctx->adapter.GetLimits(&ctx->limits);
|
||||
ctx->max_wg_size_x = 288; // default value
|
||||
|
||||
wgpu::AdapterInfo info{};
|
||||
ctx->adapter.GetInfo(&info);
|
||||
|
|
@ -1182,21 +1445,21 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
||||
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
std::string(message).c_str());
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR(
|
||||
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
|
||||
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
std::string(message).c_str());
|
||||
});
|
||||
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
||||
&dev_desc,
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
|
||||
std::string(message).c_str());
|
||||
return;
|
||||
}
|
||||
ctx->device = std::move(device);
|
||||
|
|
@ -1208,34 +1471,28 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
ctx->queue = ctx->device.GetQueue();
|
||||
|
||||
// Create buffer pool for shader parameters
|
||||
ctx->param_buf_pool.init(ctx->device,
|
||||
WEBGPU_NUM_PARAM_BUFS,
|
||||
WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
||||
ctx->set_rows_error_buf_pool.init(ctx->device,
|
||||
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
|
||||
ggml_webgpu_init_memset_pipeline(ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
||||
ggml_webgpu_init_set_rows_pipeline(ctx);
|
||||
ggml_webgpu_init_get_rows_pipeline(ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(ctx);
|
||||
ggml_webgpu_init_add_pipeline(ctx);
|
||||
ggml_webgpu_init_mul_pipeline(ctx);
|
||||
ggml_webgpu_init_rms_norm_pipeline(ctx);
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
// Initialize debug buffers
|
||||
ggml_webgpu_create_buffer(ctx->device,
|
||||
ctx->debug_host_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
|
||||
"debug_host_buf");
|
||||
ggml_webgpu_create_buffer(ctx->device,
|
||||
ctx->debug_dev_buf,
|
||||
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
|
||||
"debug_dev_buf");
|
||||
ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
|
||||
ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
|
||||
#endif
|
||||
|
||||
static ggml_backend_webgpu_device_context device_ctx;
|
||||
|
|
@ -1246,12 +1503,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
GGML_LOG_INFO(
|
||||
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
|
||||
"device_desc: %s\n",
|
||||
info.vendorID,
|
||||
std::string(info.vendor).c_str(),
|
||||
std::string(info.architecture).c_str(),
|
||||
info.deviceID,
|
||||
std::string(info.device).c_str(),
|
||||
std::string(info.description).c_str());
|
||||
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
|
||||
std::string(info.device).c_str(), std::string(info.description).c_str());
|
||||
|
||||
// See GGML Backend Device Interface section
|
||||
static ggml_backend_device device = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
}
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "binary_head.tmpl"
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
}
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "binary_head.tmpl"
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
struct Params {
|
||||
ne: u32,
|
||||
|
||||
// offsets in elements
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_src1_0: u32,
|
||||
stride_src1_1: u32,
|
||||
stride_src1_2: u32,
|
||||
stride_src1_3: u32,
|
||||
|
||||
a_ne0: u32,
|
||||
a_ne1: u32,
|
||||
a_ne2: u32,
|
||||
|
||||
b_ne0: u32,
|
||||
b_ne1: u32,
|
||||
b_ne2: u32,
|
||||
b_ne3: u32,
|
||||
};
|
||||
|
||||
fn src1_index(_i: u32) -> u32 {
|
||||
var i = _i;
|
||||
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
|
||||
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
|
||||
let a_i2 = i / (params.a_ne1 * params.a_ne0);
|
||||
i = i % (params.a_ne1 * params.a_ne0);
|
||||
let a_i1 = i / params.a_ne0;
|
||||
let a_i0 = i % params.a_ne0;
|
||||
|
||||
// handle repetition of b
|
||||
// index loops back to the beginning and repeats after elements are exhausted = modulo
|
||||
let b_i0 = a_i0 % params.b_ne0;
|
||||
let b_i1 = a_i1 % params.b_ne1;
|
||||
let b_i2 = a_i2 % params.b_ne2;
|
||||
let b_i3 = a_i3 % params.b_ne3;
|
||||
|
||||
// compute index for position in b's flat array
|
||||
return b_i0 * params.stride_src1_0 +
|
||||
b_i1 * params.stride_src1_1 +
|
||||
b_i2 * params.stride_src1_2 +
|
||||
b_i3 * params.stride_src1_3;
|
||||
}
|
||||
|
|
@ -0,0 +1,930 @@
|
|||
#decl(BYTE_HELPERS)
|
||||
|
||||
fn get_byte(value: u32, index: u32) -> u32 {
|
||||
return (value >> (index * 8)) & 0xFF;
|
||||
}
|
||||
|
||||
fn get_byte_i32(value: u32, index: u32) -> i32 {
|
||||
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
|
||||
}
|
||||
|
||||
#enddecl(BYTE_HELPERS)
|
||||
|
||||
#decl(Q4_0_T)
|
||||
struct q4_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#enddecl(Q4_0_T)
|
||||
|
||||
#decl(Q4_1_T)
|
||||
struct q4_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qs: array<u32, 4>
|
||||
};
|
||||
#enddecl(Q4_1_T)
|
||||
|
||||
#decl(Q5_0_T)
|
||||
struct q5_0 {
|
||||
d: f16,
|
||||
qh: array<f16, 2>,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#enddecl(Q5_0_T)
|
||||
|
||||
#decl(Q5_1_T)
|
||||
struct q5_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qh: u32,
|
||||
qs: array<u32, 4>
|
||||
};
|
||||
#enddecl(Q5_1_T)
|
||||
|
||||
#decl(Q8_0_T)
|
||||
struct q8_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 16>
|
||||
};
|
||||
#enddecl(Q8_0_T)
|
||||
|
||||
#decl(Q8_1_T)
|
||||
struct q8_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qs: array<u32, 8>
|
||||
};
|
||||
#enddecl(Q8_1_T)
|
||||
|
||||
#decl(Q2_K_T)
|
||||
struct q2_k {
|
||||
scales: array<u32, 4>,
|
||||
qs: array<u32, 16>,
|
||||
d: f16,
|
||||
dmin: f16
|
||||
};
|
||||
#enddecl(Q2_K_T)
|
||||
|
||||
#decl(Q3_K_T)
|
||||
struct q3_k {
|
||||
hmask: array<f16, 16>,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 6>,
|
||||
d: f16
|
||||
};
|
||||
#enddecl(Q3_K_T)
|
||||
|
||||
#decl(Q45_K_SCALE_MIN)
|
||||
|
||||
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
|
||||
if (is < 4) {
|
||||
let sc_byte = get_byte(scales[is / 4], is % 4);
|
||||
let min_byte = get_byte(scales[(is + 4) / 4], is % 4);
|
||||
return vec2(f32(sc_byte & 63), f32(min_byte & 63));
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scales[(is + 4) / 4], (is + 4) % 4);
|
||||
let sc_hi = get_byte(scales[(is - 4) / 4], (is - 4) % 4);
|
||||
let min_hi = get_byte(scales[is / 4], is % 4);
|
||||
let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4);
|
||||
let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4);
|
||||
return vec2(f32(sc), f32(m));
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(Q45_K_SCALE_MIN)
|
||||
|
||||
#decl(Q4_K_T)
|
||||
struct q4_k {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: array<u32, 3>,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
#enddecl(Q4_K_T)
|
||||
|
||||
#decl(Q5_K_T)
|
||||
struct q5_k {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: array<u32, 3>,
|
||||
qh: array<u32, 8>,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
#enddecl(Q5_K_T)
|
||||
|
||||
#decl(Q6_K_T)
|
||||
struct q6_k {
|
||||
ql: array<f16, 64>,
|
||||
qh: array<f16, 32>,
|
||||
scales: array<f16, 8>,
|
||||
d: f16
|
||||
};
|
||||
#enddecl(Q6_K_T)
|
||||
|
||||
#decl(IQ2_XXS_T)
|
||||
struct iq2_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>
|
||||
};
|
||||
#enddecl(IQ2_XXS_T)
|
||||
|
||||
#decl(IQ2_XS_T)
|
||||
struct iq2_xs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#enddecl(IQ2_XS_T)
|
||||
|
||||
#decl(IQ2_S_T)
|
||||
struct iq2_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#enddecl(IQ2_S_T)
|
||||
|
||||
#decl(IQ3_XSS_T)
|
||||
struct iq3_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 48>
|
||||
};
|
||||
#enddecl(IQ3_XSS_T)
|
||||
|
||||
#decl(IQ3_S_T)
|
||||
struct iq3_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
signs: array<f16, 16>,
|
||||
scales: array<f16, 2>
|
||||
};
|
||||
#enddecl(IQ3_S_T)
|
||||
|
||||
#decl(IQ1_S_T)
|
||||
struct iq1_s {
|
||||
d: f16,
|
||||
qs: array<f16, 16>,
|
||||
qh: array<f16, 8>
|
||||
};
|
||||
#enddecl(IQ1_S_T)
|
||||
|
||||
#decl(IQ1_M_T)
|
||||
struct iq1_m {
|
||||
qs: array<u32, 8>,
|
||||
qh: array<u32, 4>,
|
||||
scales: array<u32, 2>
|
||||
};
|
||||
#enddecl(IQ1_M_T)
|
||||
|
||||
#decl(IQ4_NL_T)
|
||||
struct iq4_nl {
|
||||
d: f16,
|
||||
qs: array<f16, 8>,
|
||||
};
|
||||
#enddecl(IQ4_NL_T)
|
||||
|
||||
#decl(IQ4_XS_T)
|
||||
struct iq4_xs {
|
||||
d: f16,
|
||||
scales_h: f16,
|
||||
scales_l: u32,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
#enddecl(IQ4_XS_T)
|
||||
|
||||
#decl(IQ23_TABLES)
|
||||
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
|
||||
0x08040201u, // 1, 2, 4, 8
|
||||
0x80402010u // 16, 32, 64, 128
|
||||
);
|
||||
|
||||
const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
|
||||
0x03828100,0x87060584,0x8b0a0988,0x0f8e8d0c,
|
||||
0x93121190,0x17969514,0x1b9a9918,0x9f1e1d9c,
|
||||
0xa32221a0,0x27a6a524,0x2baaa928,0xaf2e2dac,
|
||||
0x33b2b130,0xb73635b4,0xbb3a39b8,0x3fbebd3c,
|
||||
0xc34241c0,0x47c6c544,0x4bcac948,0xcf4e4dcc,
|
||||
0x53d2d150,0xd75655d4,0xdb5a59d8,0x5fdedd5c,
|
||||
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
|
||||
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
|
||||
);
|
||||
#enddecl(IQ23_TABLES)
|
||||
|
||||
#decl(IQ2_XXS_GRID)
|
||||
const iq2xxs_grid = array<u32, 512>(
|
||||
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
|
||||
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
|
||||
0x082b082b, 0x08080808, 0x082b2b08, 0x08080808, 0x082b2b2b, 0x08080808, 0x19080819, 0x08080808,
|
||||
0x19081908, 0x08080808, 0x19190808, 0x08080808, 0x19192b08, 0x08080808, 0x192b0819, 0x08080808,
|
||||
0x192b1908, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b082b2b, 0x08080808,
|
||||
0x2b2b082b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819, 0x08190808, 0x08080819,
|
||||
0x08191919, 0x08080819, 0x19080808, 0x08080819, 0x2b081908, 0x08080819, 0x2b192b08, 0x08080819,
|
||||
0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x082b082b, 0x0808082b, 0x2b08082b, 0x0808082b,
|
||||
0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x08190808, 0x08081908, 0x082b0819, 0x08081908,
|
||||
0x082b1908, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19082b08, 0x08081908,
|
||||
0x192b0808, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908,
|
||||
0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919, 0x08082b08, 0x08081919,
|
||||
0x082b0808, 0x08081919, 0x1908192b, 0x08081919, 0x192b2b19, 0x08081919, 0x2b080808, 0x08081919,
|
||||
0x2b190819, 0x08081919, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b, 0x19080808, 0x0808192b,
|
||||
0x2b081908, 0x0808192b, 0x2b2b1908, 0x0808192b, 0x08080808, 0x08082b08, 0x08081919, 0x08082b08,
|
||||
0x08082b08, 0x08082b08, 0x08191908, 0x08082b08, 0x082b2b08, 0x08082b08, 0x19080819, 0x08082b08,
|
||||
0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x2b082b08, 0x08082b08,
|
||||
0x08081908, 0x08082b19, 0x19080808, 0x08082b19, 0x0808082b, 0x08082b2b, 0x08191908, 0x08082b2b,
|
||||
0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x08190808, 0x08190808, 0x082b0819, 0x08190808,
|
||||
0x19080808, 0x08190808, 0x192b0808, 0x08190808, 0x2b081908, 0x08190808, 0x2b190808, 0x08190808,
|
||||
0x2b191919, 0x08190808, 0x08080808, 0x08190819, 0x08082b08, 0x08190819, 0x082b0808, 0x08190819,
|
||||
0x19190808, 0x08190819, 0x19192b2b, 0x08190819, 0x2b080808, 0x08190819, 0x082b1908, 0x0819082b,
|
||||
0x19081919, 0x0819082b, 0x08080808, 0x08191908, 0x08082b08, 0x08191908, 0x082b0808, 0x08191908,
|
||||
0x082b1919, 0x08191908, 0x19082b19, 0x08191908, 0x2b080808, 0x08191908, 0x08192b08, 0x08191919,
|
||||
0x192b082b, 0x08191919, 0x08080808, 0x0819192b, 0x0819192b, 0x0819192b, 0x08080819, 0x08192b08,
|
||||
0x08081908, 0x08192b08, 0x08190808, 0x08192b08, 0x19080808, 0x08192b08, 0x2b080819, 0x08192b08,
|
||||
0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x2b2b0808, 0x08192b19, 0x19190819, 0x08192b2b,
|
||||
0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08082b2b, 0x082b0808, 0x19081908, 0x082b0808,
|
||||
0x192b0819, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b08082b, 0x082b0808, 0x082b2b19, 0x082b0819,
|
||||
0x19082b08, 0x082b0819, 0x08080808, 0x082b082b, 0x0808082b, 0x082b082b, 0x08080819, 0x082b1908,
|
||||
0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x19080808, 0x082b1908, 0x1919192b, 0x082b1908,
|
||||
0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x192b1908, 0x082b1919, 0x2b190808, 0x082b192b,
|
||||
0x08082b08, 0x082b2b08, 0x082b0808, 0x082b2b08, 0x2b191908, 0x082b2b08, 0x19081908, 0x082b2b2b,
|
||||
0x08080819, 0x19080808, 0x08081908, 0x19080808, 0x08190808, 0x19080808, 0x08192b08, 0x19080808,
|
||||
0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x19080808, 0x19080808, 0x19082b08, 0x19080808,
|
||||
0x1919192b, 0x19080808, 0x192b0808, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808,
|
||||
0x2b190808, 0x19080808, 0x08080808, 0x19080819, 0x082b0808, 0x19080819, 0x192b0819, 0x19080819,
|
||||
0x2b080808, 0x19080819, 0x2b081919, 0x19080819, 0x08080819, 0x1908082b, 0x08190808, 0x1908082b,
|
||||
0x19082b08, 0x1908082b, 0x1919192b, 0x1908082b, 0x192b2b08, 0x1908082b, 0x08080808, 0x19081908,
|
||||
0x08082b08, 0x19081908, 0x082b0808, 0x19081908, 0x2b080808, 0x19081908, 0x2b192b19, 0x19081908,
|
||||
0x0819082b, 0x19081919, 0x082b1908, 0x19081919, 0x08080808, 0x1908192b, 0x08080819, 0x19082b08,
|
||||
0x08081908, 0x19082b08, 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08,
|
||||
0x08080808, 0x19082b19, 0x19192b08, 0x19082b19, 0x192b0819, 0x19082b19, 0x2b08082b, 0x19082b19,
|
||||
0x19081919, 0x19082b2b, 0x2b190808, 0x19082b2b, 0x08080808, 0x19190808, 0x08082b08, 0x19190808,
|
||||
0x08190819, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x2b080808, 0x19190808,
|
||||
0x2b082b08, 0x19190808, 0x08081908, 0x19190819, 0x1908082b, 0x19190819, 0x2b2b1908, 0x19190819,
|
||||
0x2b190819, 0x1919082b, 0x2b190808, 0x19191908, 0x2b19082b, 0x19191908, 0x08082b2b, 0x19191919,
|
||||
0x08080819, 0x1919192b, 0x19191908, 0x1919192b, 0x08080808, 0x19192b08, 0x08190819, 0x19192b08,
|
||||
0x08192b19, 0x19192b08, 0x192b1908, 0x19192b08, 0x19080808, 0x19192b19, 0x08082b08, 0x19192b2b,
|
||||
0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808, 0x192b2b08, 0x192b0808,
|
||||
0x08080808, 0x192b0819, 0x19191919, 0x192b0819, 0x08192b08, 0x192b082b, 0x192b0808, 0x192b082b,
|
||||
0x08080808, 0x192b1908, 0x08081919, 0x192b1908, 0x08190808, 0x192b1919, 0x0819082b, 0x192b1919,
|
||||
0x2b081908, 0x192b1919, 0x1908082b, 0x192b2b08, 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808,
|
||||
0x08082b2b, 0x2b080808, 0x19080819, 0x2b080808, 0x2b08082b, 0x2b080808, 0x08081908, 0x2b080819,
|
||||
0x08192b08, 0x2b080819, 0x19080808, 0x2b080819, 0x08190819, 0x2b08082b, 0x08080819, 0x2b081908,
|
||||
0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x08191919, 0x2b081908, 0x19080808, 0x2b081908,
|
||||
0x192b0808, 0x2b081908, 0x08080808, 0x2b081919, 0x1908192b, 0x2b081919, 0x2b191908, 0x2b081919,
|
||||
0x08082b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x192b0808, 0x2b08192b, 0x0808082b, 0x2b082b08,
|
||||
0x08081908, 0x2b082b19, 0x08190819, 0x2b082b2b, 0x08081908, 0x2b190808, 0x08190808, 0x2b190808,
|
||||
0x082b1908, 0x2b190808, 0x19080808, 0x2b190808, 0x2b2b0819, 0x2b190808, 0x0819192b, 0x2b190819,
|
||||
0x2b080808, 0x2b190819, 0x19081919, 0x2b19082b, 0x08080808, 0x2b191908, 0x082b082b, 0x2b191908,
|
||||
0x19081908, 0x2b191908, 0x19190819, 0x2b191919, 0x2b080819, 0x2b192b08, 0x082b0808, 0x2b192b19,
|
||||
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
|
||||
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
|
||||
);
|
||||
#enddecl(IQ2_XXS_GRID)
|
||||
|
||||
#decl(IQ2_XS_GRID)
|
||||
const iq2xs_grid = array<u32, 1024>(
|
||||
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
|
||||
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
|
||||
0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808,
|
||||
0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808,
|
||||
0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808,
|
||||
0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x2b080808, 0x08080808,
|
||||
0x2b08082b, 0x08080808, 0x2b081919, 0x08080808, 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808,
|
||||
0x2b191908, 0x08080808, 0x2b192b19, 0x08080808, 0x2b2b0808, 0x08080808, 0x08080819, 0x08080819,
|
||||
0x08081908, 0x08080819, 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819,
|
||||
0x0819082b, 0x08080819, 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x08192b2b, 0x08080819,
|
||||
0x082b0819, 0x08080819, 0x082b1908, 0x08080819, 0x19080808, 0x08080819, 0x1908082b, 0x08080819,
|
||||
0x19081919, 0x08080819, 0x19082b08, 0x08080819, 0x19190819, 0x08080819, 0x19191908, 0x08080819,
|
||||
0x192b0808, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819, 0x2b081908, 0x08080819,
|
||||
0x2b190808, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x08081919, 0x0808082b,
|
||||
0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b, 0x082b0808, 0x0808082b,
|
||||
0x19080819, 0x0808082b, 0x19081908, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b,
|
||||
0x2b080808, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908,
|
||||
0x0808192b, 0x08081908, 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908,
|
||||
0x08191919, 0x08081908, 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908,
|
||||
0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908, 0x19082b08, 0x08081908,
|
||||
0x19190819, 0x08081908, 0x19191908, 0x08081908, 0x1919192b, 0x08081908, 0x192b0808, 0x08081908,
|
||||
0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908, 0x08080808, 0x08081919,
|
||||
0x0808082b, 0x08081919, 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08190819, 0x08081919,
|
||||
0x08191908, 0x08081919, 0x082b0808, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919,
|
||||
0x19190808, 0x08081919, 0x192b0819, 0x08081919, 0x2b080808, 0x08081919, 0x08080819, 0x0808192b,
|
||||
0x08081908, 0x0808192b, 0x08190808, 0x0808192b, 0x082b192b, 0x0808192b, 0x19080808, 0x0808192b,
|
||||
0x1908082b, 0x0808192b, 0x2b081908, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08,
|
||||
0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08082b2b, 0x08082b08, 0x08190819, 0x08082b08,
|
||||
0x08191908, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08, 0x19080819, 0x08082b08,
|
||||
0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x19192b08, 0x08082b08, 0x2b080808, 0x08082b08,
|
||||
0x2b2b0808, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19, 0x08081908, 0x08082b19,
|
||||
0x08190808, 0x08082b19, 0x19080808, 0x08082b19, 0x2b080819, 0x08082b19, 0x2b082b19, 0x08082b19,
|
||||
0x08080808, 0x08082b2b, 0x082b0808, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x2b19192b, 0x08082b2b,
|
||||
0x2b2b0808, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x0808192b, 0x08190808,
|
||||
0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808, 0x08191919, 0x08190808,
|
||||
0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808, 0x19080808, 0x08190808,
|
||||
0x1908082b, 0x08190808, 0x19081919, 0x08190808, 0x19082b08, 0x08190808, 0x19190819, 0x08190808,
|
||||
0x19191908, 0x08190808, 0x192b0808, 0x08190808, 0x192b2b2b, 0x08190808, 0x2b080819, 0x08190808,
|
||||
0x2b081908, 0x08190808, 0x2b190808, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819,
|
||||
0x08081919, 0x08190819, 0x08082b08, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819,
|
||||
0x082b0808, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819, 0x19190808, 0x08190819,
|
||||
0x2b080808, 0x08190819, 0x2b191908, 0x08190819, 0x2b19192b, 0x08190819, 0x08080819, 0x0819082b,
|
||||
0x08081908, 0x0819082b, 0x0808192b, 0x0819082b, 0x08190808, 0x0819082b, 0x19080808, 0x0819082b,
|
||||
0x192b0808, 0x0819082b, 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908,
|
||||
0x08082b08, 0x08191908, 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x082b0808, 0x08191908,
|
||||
0x19080819, 0x08191908, 0x19081908, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908,
|
||||
0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x08080819, 0x08191919, 0x08081908, 0x08191919,
|
||||
0x08190808, 0x08191919, 0x19080808, 0x08191919, 0x08080808, 0x0819192b, 0x08191908, 0x0819192b,
|
||||
0x19082b19, 0x0819192b, 0x08080819, 0x08192b08, 0x08081908, 0x08192b08, 0x08190808, 0x08192b08,
|
||||
0x0819082b, 0x08192b08, 0x19080808, 0x08192b08, 0x19191908, 0x08192b08, 0x2b08192b, 0x08192b08,
|
||||
0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x192b192b, 0x08192b19, 0x19190819, 0x08192b2b,
|
||||
0x2b2b2b19, 0x08192b2b, 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808,
|
||||
0x08082b08, 0x082b0808, 0x08082b2b, 0x082b0808, 0x08190819, 0x082b0808, 0x08191908, 0x082b0808,
|
||||
0x082b0808, 0x082b0808, 0x19080819, 0x082b0808, 0x19081908, 0x082b0808, 0x19190808, 0x082b0808,
|
||||
0x2b080808, 0x082b0808, 0x2b2b0808, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819,
|
||||
0x08190808, 0x082b0819, 0x19080808, 0x082b0819, 0x19082b08, 0x082b0819, 0x192b1919, 0x082b0819,
|
||||
0x08080808, 0x082b082b, 0x082b082b, 0x082b082b, 0x2b080808, 0x082b082b, 0x2b2b2b08, 0x082b082b,
|
||||
0x08080819, 0x082b1908, 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x082b2b19, 0x082b1908,
|
||||
0x19080808, 0x082b1908, 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x1919082b, 0x082b1919,
|
||||
0x2b192b19, 0x082b1919, 0x08080819, 0x082b192b, 0x08192b2b, 0x082b192b, 0x2b2b192b, 0x082b192b,
|
||||
0x08080808, 0x082b2b08, 0x08082b08, 0x082b2b08, 0x08082b2b, 0x082b2b08, 0x082b0808, 0x082b2b08,
|
||||
0x19191919, 0x082b2b08, 0x2b082b08, 0x082b2b08, 0x2b2b082b, 0x082b2b08, 0x192b2b08, 0x082b2b19,
|
||||
0x2b190808, 0x082b2b19, 0x08082b08, 0x082b2b2b, 0x082b0808, 0x082b2b2b, 0x2b08082b, 0x082b2b2b,
|
||||
0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808, 0x08081908, 0x19080808,
|
||||
0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808, 0x0819082b, 0x19080808,
|
||||
0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x082b0819, 0x19080808, 0x082b1908, 0x19080808,
|
||||
0x19080808, 0x19080808, 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808,
|
||||
0x19082b2b, 0x19080808, 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x192b0808, 0x19080808,
|
||||
0x192b1919, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808, 0x2b190808, 0x19080808,
|
||||
0x08080808, 0x19080819, 0x0808082b, 0x19080819, 0x08081919, 0x19080819, 0x08082b08, 0x19080819,
|
||||
0x08190819, 0x19080819, 0x08191908, 0x19080819, 0x082b0808, 0x19080819, 0x19080819, 0x19080819,
|
||||
0x19081908, 0x19080819, 0x19190808, 0x19080819, 0x2b080808, 0x19080819, 0x2b081919, 0x19080819,
|
||||
0x2b2b082b, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b, 0x08190808, 0x1908082b,
|
||||
0x0819082b, 0x1908082b, 0x082b2b19, 0x1908082b, 0x19080808, 0x1908082b, 0x08080808, 0x19081908,
|
||||
0x0808082b, 0x19081908, 0x08081919, 0x19081908, 0x08082b08, 0x19081908, 0x08190819, 0x19081908,
|
||||
0x08191908, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x19080819, 0x19081908,
|
||||
0x19081908, 0x19081908, 0x19190808, 0x19081908, 0x2b080808, 0x19081908, 0x2b191908, 0x19081908,
|
||||
0x08080819, 0x19081919, 0x08081908, 0x19081919, 0x08190808, 0x19081919, 0x082b1908, 0x19081919,
|
||||
0x19080808, 0x19081919, 0x2b192b2b, 0x19081919, 0x08080808, 0x1908192b, 0x08082b2b, 0x1908192b,
|
||||
0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x08080819, 0x19082b08, 0x08081908, 0x19082b08,
|
||||
0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08, 0x19191908, 0x19082b08,
|
||||
0x192b082b, 0x19082b08, 0x08080808, 0x19082b19, 0x08190819, 0x19082b19, 0x19081908, 0x19082b19,
|
||||
0x19190808, 0x19082b19, 0x192b2b19, 0x19082b19, 0x08081908, 0x19082b2b, 0x08080808, 0x19190808,
|
||||
0x0808082b, 0x19190808, 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808,
|
||||
0x08191908, 0x19190808, 0x082b0808, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808,
|
||||
0x19081908, 0x19190808, 0x19190808, 0x19190808, 0x2b080808, 0x19190808, 0x08080819, 0x19190819,
|
||||
0x08081908, 0x19190819, 0x08190808, 0x19190819, 0x08191919, 0x19190819, 0x19080808, 0x19190819,
|
||||
0x1908082b, 0x19190819, 0x08080808, 0x1919082b, 0x19081908, 0x1919082b, 0x2b2b2b2b, 0x1919082b,
|
||||
0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x08190808, 0x19191908, 0x082b0819, 0x19191908,
|
||||
0x19080808, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b2b0819, 0x19191908,
|
||||
0x08080808, 0x19191919, 0x08082b08, 0x19191919, 0x2b080808, 0x19191919, 0x2b082b08, 0x19191919,
|
||||
0x082b0819, 0x1919192b, 0x192b2b08, 0x1919192b, 0x2b2b0819, 0x1919192b, 0x08080808, 0x19192b08,
|
||||
0x08191908, 0x19192b08, 0x19080819, 0x19192b08, 0x19190808, 0x19192b08, 0x2b192b19, 0x19192b08,
|
||||
0x08192b2b, 0x19192b19, 0x19080808, 0x19192b19, 0x1908082b, 0x19192b19, 0x2b081919, 0x19192b2b,
|
||||
0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808,
|
||||
0x19191908, 0x192b0808, 0x192b082b, 0x192b0808, 0x2b08192b, 0x192b0808, 0x2b2b2b19, 0x192b0808,
|
||||
0x08080808, 0x192b0819, 0x082b1908, 0x192b082b, 0x19082b2b, 0x192b082b, 0x2b19082b, 0x192b082b,
|
||||
0x08080808, 0x192b1908, 0x0819192b, 0x192b1908, 0x08190808, 0x192b1919, 0x19080808, 0x192b1919,
|
||||
0x19081919, 0x192b1919, 0x2b2b1908, 0x192b1919, 0x08080819, 0x192b2b08, 0x192b2b2b, 0x192b2b08,
|
||||
0x082b1919, 0x192b2b19, 0x0808192b, 0x192b2b2b, 0x19191908, 0x192b2b2b, 0x192b082b, 0x192b2b2b,
|
||||
0x08080808, 0x2b080808, 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808,
|
||||
0x08190819, 0x2b080808, 0x08191908, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b2b2b, 0x2b080808,
|
||||
0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x2b080808, 0x2b080808,
|
||||
0x2b08082b, 0x2b080808, 0x2b2b2b08, 0x2b080808, 0x2b2b2b2b, 0x2b080808, 0x08080819, 0x2b080819,
|
||||
0x08081908, 0x2b080819, 0x0808192b, 0x2b080819, 0x08190808, 0x2b080819, 0x19080808, 0x2b080819,
|
||||
0x19190819, 0x2b080819, 0x19192b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x082b0808, 0x2b08082b,
|
||||
0x2b080808, 0x2b08082b, 0x2b08082b, 0x2b08082b, 0x2b2b0808, 0x2b08082b, 0x2b2b2b08, 0x2b08082b,
|
||||
0x08080819, 0x2b081908, 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908,
|
||||
0x08191919, 0x2b081908, 0x19080808, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b082b19, 0x2b081908,
|
||||
0x08080808, 0x2b081919, 0x19081908, 0x2b081919, 0x2b2b1919, 0x2b081919, 0x08192b08, 0x2b08192b,
|
||||
0x192b2b2b, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08082b08, 0x2b082b08, 0x082b1919, 0x2b082b08,
|
||||
0x19192b2b, 0x2b082b08, 0x2b080808, 0x2b082b08, 0x2b08082b, 0x2b082b08, 0x2b2b2b08, 0x2b082b08,
|
||||
0x0808192b, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x2b080808, 0x2b082b2b, 0x2b082b08, 0x2b082b2b,
|
||||
0x2b19192b, 0x2b082b2b, 0x2b2b2b08, 0x2b082b2b, 0x08080819, 0x2b190808, 0x08081908, 0x2b190808,
|
||||
0x08190808, 0x2b190808, 0x19080808, 0x2b190808, 0x1919192b, 0x2b190808, 0x2b081908, 0x2b190808,
|
||||
0x08080808, 0x2b190819, 0x082b082b, 0x2b190819, 0x192b1908, 0x2b190819, 0x1919192b, 0x2b19082b,
|
||||
0x2b082b19, 0x2b19082b, 0x08080808, 0x2b191908, 0x08081919, 0x2b191908, 0x19081908, 0x2b191908,
|
||||
0x19190808, 0x2b191908, 0x19192b08, 0x2b191908, 0x082b2b19, 0x2b191919, 0x2b190808, 0x2b191919,
|
||||
0x2b19082b, 0x2b191919, 0x19080819, 0x2b19192b, 0x19190819, 0x2b192b08, 0x2b2b192b, 0x2b192b08,
|
||||
0x19082b19, 0x2b192b19, 0x08191919, 0x2b192b2b, 0x192b0808, 0x2b192b2b, 0x08080808, 0x2b2b0808,
|
||||
0x0808082b, 0x2b2b0808, 0x08082b08, 0x2b2b0808, 0x08082b2b, 0x2b2b0808, 0x082b0808, 0x2b2b0808,
|
||||
0x082b2b2b, 0x2b2b0808, 0x2b2b0808, 0x2b2b0808, 0x19190819, 0x2b2b0819, 0x19192b19, 0x2b2b0819,
|
||||
0x2b2b192b, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x0808082b, 0x2b2b082b, 0x08082b08, 0x2b2b082b,
|
||||
0x082b2b2b, 0x2b2b082b, 0x2b080808, 0x2b2b082b, 0x2b2b0808, 0x2b2b082b, 0x19080808, 0x2b2b1908,
|
||||
0x2b191919, 0x2b2b1908, 0x192b1919, 0x2b2b192b, 0x2b192b08, 0x2b2b192b, 0x08082b2b, 0x2b2b2b08,
|
||||
0x082b0808, 0x2b2b2b08, 0x082b082b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b0808, 0x2b2b2b08,
|
||||
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
|
||||
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
|
||||
);
|
||||
#enddecl(IQ2_XS_GRID)
|
||||
|
||||
#decl(IQ2_S_GRID)
|
||||
const iq2s_grid = array<u32, 2048>(
|
||||
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
|
||||
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
|
||||
0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808,
|
||||
0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808,
|
||||
0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808,
|
||||
0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x192b192b, 0x08080808,
|
||||
0x192b2b19, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808,
|
||||
0x2b082b08, 0x08080808, 0x2b190819, 0x08080808, 0x2b191908, 0x08080808, 0x2b2b0808, 0x08080808,
|
||||
0x2b2b1919, 0x08080808, 0x2b2b2b2b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819,
|
||||
0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819, 0x0819082b, 0x08080819,
|
||||
0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x082b0819, 0x08080819, 0x082b1908, 0x08080819,
|
||||
0x19080808, 0x08080819, 0x1908082b, 0x08080819, 0x19081919, 0x08080819, 0x19082b08, 0x08080819,
|
||||
0x19190819, 0x08080819, 0x19191908, 0x08080819, 0x1919192b, 0x08080819, 0x19192b19, 0x08080819,
|
||||
0x192b0808, 0x08080819, 0x192b1919, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819,
|
||||
0x2b081908, 0x08080819, 0x2b190808, 0x08080819, 0x2b19082b, 0x08080819, 0x2b191919, 0x08080819,
|
||||
0x2b2b0819, 0x08080819, 0x2b2b1908, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b,
|
||||
0x08081919, 0x0808082b, 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b,
|
||||
0x082b0808, 0x0808082b, 0x082b2b2b, 0x0808082b, 0x19080819, 0x0808082b, 0x19081908, 0x0808082b,
|
||||
0x1908192b, 0x0808082b, 0x19082b19, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b,
|
||||
0x2b080808, 0x0808082b, 0x2b081919, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x2b191908, 0x0808082b,
|
||||
0x2b2b082b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x0808192b, 0x08081908,
|
||||
0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908, 0x08191919, 0x08081908,
|
||||
0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908, 0x082b192b, 0x08081908,
|
||||
0x082b2b19, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908,
|
||||
0x19082b08, 0x08081908, 0x19082b2b, 0x08081908, 0x19190819, 0x08081908, 0x19191908, 0x08081908,
|
||||
0x1919192b, 0x08081908, 0x19192b19, 0x08081908, 0x192b0808, 0x08081908, 0x192b082b, 0x08081908,
|
||||
0x192b1919, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b08192b, 0x08081908,
|
||||
0x2b082b19, 0x08081908, 0x2b190808, 0x08081908, 0x2b191919, 0x08081908, 0x2b192b08, 0x08081908,
|
||||
0x2b2b0819, 0x08081908, 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919,
|
||||
0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08082b2b, 0x08081919, 0x08190819, 0x08081919,
|
||||
0x08191908, 0x08081919, 0x0819192b, 0x08081919, 0x08192b19, 0x08081919, 0x082b0808, 0x08081919,
|
||||
0x082b1919, 0x08081919, 0x082b2b08, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919,
|
||||
0x1908192b, 0x08081919, 0x19082b19, 0x08081919, 0x19190808, 0x08081919, 0x1919082b, 0x08081919,
|
||||
0x19191919, 0x08081919, 0x19192b08, 0x08081919, 0x192b0819, 0x08081919, 0x192b1908, 0x08081919,
|
||||
0x2b080808, 0x08081919, 0x2b08082b, 0x08081919, 0x2b081919, 0x08081919, 0x2b082b08, 0x08081919,
|
||||
0x2b190819, 0x08081919, 0x2b191908, 0x08081919, 0x2b2b0808, 0x08081919, 0x08080819, 0x0808192b,
|
||||
0x08081908, 0x0808192b, 0x0808192b, 0x0808192b, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b,
|
||||
0x08191919, 0x0808192b, 0x19080808, 0x0808192b, 0x19081919, 0x0808192b, 0x19082b08, 0x0808192b,
|
||||
0x19190819, 0x0808192b, 0x19191908, 0x0808192b, 0x192b0808, 0x0808192b, 0x2b080819, 0x0808192b,
|
||||
0x2b081908, 0x0808192b, 0x2b190808, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08,
|
||||
0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08190819, 0x08082b08, 0x08191908, 0x08082b08,
|
||||
0x0819192b, 0x08082b08, 0x08192b19, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08,
|
||||
0x082b2b2b, 0x08082b08, 0x19080819, 0x08082b08, 0x19081908, 0x08082b08, 0x1908192b, 0x08082b08,
|
||||
0x19082b19, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x19191919, 0x08082b08,
|
||||
0x19192b08, 0x08082b08, 0x192b0819, 0x08082b08, 0x192b1908, 0x08082b08, 0x2b080808, 0x08082b08,
|
||||
0x2b081919, 0x08082b08, 0x2b191908, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19,
|
||||
0x08081908, 0x08082b19, 0x08190808, 0x08082b19, 0x0819082b, 0x08082b19, 0x08191919, 0x08082b19,
|
||||
0x08192b08, 0x08082b19, 0x082b0819, 0x08082b19, 0x19080808, 0x08082b19, 0x19081919, 0x08082b19,
|
||||
0x19082b08, 0x08082b19, 0x19190819, 0x08082b19, 0x19191908, 0x08082b19, 0x192b0808, 0x08082b19,
|
||||
0x2b080819, 0x08082b19, 0x2b190808, 0x08082b19, 0x08080808, 0x08082b2b, 0x08190819, 0x08082b2b,
|
||||
0x08191908, 0x08082b2b, 0x082b082b, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x082b2b2b, 0x08082b2b,
|
||||
0x19190808, 0x08082b2b, 0x2b192b19, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808,
|
||||
0x0808192b, 0x08190808, 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808,
|
||||
0x08191919, 0x08190808, 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808,
|
||||
0x082b192b, 0x08190808, 0x19080808, 0x08190808, 0x1908082b, 0x08190808, 0x19081919, 0x08190808,
|
||||
0x19082b08, 0x08190808, 0x19190819, 0x08190808, 0x19191908, 0x08190808, 0x1919192b, 0x08190808,
|
||||
0x19192b19, 0x08190808, 0x192b0808, 0x08190808, 0x192b082b, 0x08190808, 0x192b1919, 0x08190808,
|
||||
0x192b2b08, 0x08190808, 0x2b080819, 0x08190808, 0x2b081908, 0x08190808, 0x2b08192b, 0x08190808,
|
||||
0x2b190808, 0x08190808, 0x2b191919, 0x08190808, 0x2b192b08, 0x08190808, 0x2b2b0819, 0x08190808,
|
||||
0x2b2b1908, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819, 0x08081919, 0x08190819,
|
||||
0x08082b08, 0x08190819, 0x08082b2b, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819,
|
||||
0x0819192b, 0x08190819, 0x08192b19, 0x08190819, 0x082b0808, 0x08190819, 0x082b082b, 0x08190819,
|
||||
0x082b1919, 0x08190819, 0x082b2b08, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819,
|
||||
0x1908192b, 0x08190819, 0x19082b19, 0x08190819, 0x19190808, 0x08190819, 0x1919082b, 0x08190819,
|
||||
0x19191919, 0x08190819, 0x19192b08, 0x08190819, 0x192b0819, 0x08190819, 0x192b1908, 0x08190819,
|
||||
0x2b080808, 0x08190819, 0x2b08082b, 0x08190819, 0x2b081919, 0x08190819, 0x2b082b08, 0x08190819,
|
||||
0x2b190819, 0x08190819, 0x2b191908, 0x08190819, 0x08080819, 0x0819082b, 0x08081908, 0x0819082b,
|
||||
0x08082b19, 0x0819082b, 0x08190808, 0x0819082b, 0x08191919, 0x0819082b, 0x082b0819, 0x0819082b,
|
||||
0x082b1908, 0x0819082b, 0x19080808, 0x0819082b, 0x19081919, 0x0819082b, 0x19190819, 0x0819082b,
|
||||
0x19191908, 0x0819082b, 0x2b080819, 0x0819082b, 0x2b081908, 0x0819082b, 0x2b190808, 0x0819082b,
|
||||
0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908, 0x08082b08, 0x08191908,
|
||||
0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x0819192b, 0x08191908, 0x08192b19, 0x08191908,
|
||||
0x082b0808, 0x08191908, 0x082b1919, 0x08191908, 0x082b2b08, 0x08191908, 0x19080819, 0x08191908,
|
||||
0x19081908, 0x08191908, 0x1908192b, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908,
|
||||
0x1919082b, 0x08191908, 0x19191919, 0x08191908, 0x19192b08, 0x08191908, 0x192b0819, 0x08191908,
|
||||
0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x2b08082b, 0x08191908, 0x2b081919, 0x08191908,
|
||||
0x2b082b08, 0x08191908, 0x2b190819, 0x08191908, 0x2b191908, 0x08191908, 0x2b2b0808, 0x08191908,
|
||||
0x08080819, 0x08191919, 0x08081908, 0x08191919, 0x0808192b, 0x08191919, 0x08082b19, 0x08191919,
|
||||
0x08190808, 0x08191919, 0x0819082b, 0x08191919, 0x08191919, 0x08191919, 0x08192b08, 0x08191919,
|
||||
0x082b0819, 0x08191919, 0x082b1908, 0x08191919, 0x19080808, 0x08191919, 0x1908082b, 0x08191919,
|
||||
0x19081919, 0x08191919, 0x19082b08, 0x08191919, 0x19190819, 0x08191919, 0x19191908, 0x08191919,
|
||||
0x192b0808, 0x08191919, 0x2b080819, 0x08191919, 0x2b081908, 0x08191919, 0x2b190808, 0x08191919,
|
||||
0x08080808, 0x0819192b, 0x08081919, 0x0819192b, 0x08082b08, 0x0819192b, 0x08190819, 0x0819192b,
|
||||
0x08191908, 0x0819192b, 0x082b0808, 0x0819192b, 0x19080819, 0x0819192b, 0x19081908, 0x0819192b,
|
||||
0x19190808, 0x0819192b, 0x2b080808, 0x0819192b, 0x2b2b2b2b, 0x0819192b, 0x08080819, 0x08192b08,
|
||||
0x08081908, 0x08192b08, 0x0808192b, 0x08192b08, 0x08082b19, 0x08192b08, 0x08190808, 0x08192b08,
|
||||
0x08191919, 0x08192b08, 0x08192b08, 0x08192b08, 0x082b0819, 0x08192b08, 0x19080808, 0x08192b08,
|
||||
0x1908082b, 0x08192b08, 0x19081919, 0x08192b08, 0x19082b08, 0x08192b08, 0x19190819, 0x08192b08,
|
||||
0x19191908, 0x08192b08, 0x192b0808, 0x08192b08, 0x2b080819, 0x08192b08, 0x2b081908, 0x08192b08,
|
||||
0x08080808, 0x08192b19, 0x0808082b, 0x08192b19, 0x08081919, 0x08192b19, 0x08082b08, 0x08192b19,
|
||||
0x08190819, 0x08192b19, 0x08191908, 0x08192b19, 0x082b0808, 0x08192b19, 0x19080819, 0x08192b19,
|
||||
0x19081908, 0x08192b19, 0x19190808, 0x08192b19, 0x192b2b19, 0x08192b19, 0x2b2b082b, 0x08192b19,
|
||||
0x08081908, 0x08192b2b, 0x08190808, 0x08192b2b, 0x19080808, 0x08192b2b, 0x1919192b, 0x08192b2b,
|
||||
0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808, 0x08082b08, 0x082b0808,
|
||||
0x08190819, 0x082b0808, 0x08191908, 0x082b0808, 0x0819192b, 0x082b0808, 0x08192b19, 0x082b0808,
|
||||
0x082b0808, 0x082b0808, 0x082b1919, 0x082b0808, 0x082b2b2b, 0x082b0808, 0x19080819, 0x082b0808,
|
||||
0x19081908, 0x082b0808, 0x19190808, 0x082b0808, 0x1919082b, 0x082b0808, 0x19191919, 0x082b0808,
|
||||
0x192b1908, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b082b2b, 0x082b0808, 0x2b191908, 0x082b0808,
|
||||
0x2b2b2b2b, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819, 0x08190808, 0x082b0819,
|
||||
0x0819082b, 0x082b0819, 0x08191919, 0x082b0819, 0x082b0819, 0x082b0819, 0x19080808, 0x082b0819,
|
||||
0x1908082b, 0x082b0819, 0x19081919, 0x082b0819, 0x19190819, 0x082b0819, 0x19191908, 0x082b0819,
|
||||
0x192b0808, 0x082b0819, 0x2b080819, 0x082b0819, 0x2b081908, 0x082b0819, 0x2b190808, 0x082b0819,
|
||||
0x08080808, 0x082b082b, 0x08082b2b, 0x082b082b, 0x082b082b, 0x082b082b, 0x082b2b08, 0x082b082b,
|
||||
0x082b2b2b, 0x082b082b, 0x19081908, 0x082b082b, 0x19190808, 0x082b082b, 0x2b082b08, 0x082b082b,
|
||||
0x2b082b2b, 0x082b082b, 0x2b2b2b08, 0x082b082b, 0x08080819, 0x082b1908, 0x08081908, 0x082b1908,
|
||||
0x0808192b, 0x082b1908, 0x08082b19, 0x082b1908, 0x08190808, 0x082b1908, 0x08191919, 0x082b1908,
|
||||
0x08192b08, 0x082b1908, 0x082b0819, 0x082b1908, 0x082b1908, 0x082b1908, 0x19080808, 0x082b1908,
|
||||
0x1908082b, 0x082b1908, 0x19081919, 0x082b1908, 0x19082b08, 0x082b1908, 0x19190819, 0x082b1908,
|
||||
0x19191908, 0x082b1908, 0x192b0808, 0x082b1908, 0x2b080819, 0x082b1908, 0x2b081908, 0x082b1908,
|
||||
0x2b190808, 0x082b1908, 0x08080808, 0x082b1919, 0x08081919, 0x082b1919, 0x08082b08, 0x082b1919,
|
||||
0x08190819, 0x082b1919, 0x08191908, 0x082b1919, 0x082b0808, 0x082b1919, 0x19080819, 0x082b1919,
|
||||
0x19081908, 0x082b1919, 0x19190808, 0x082b1919, 0x192b192b, 0x082b1919, 0x2b080808, 0x082b1919,
|
||||
0x08080819, 0x082b192b, 0x08081908, 0x082b192b, 0x08190808, 0x082b192b, 0x19080808, 0x082b192b,
|
||||
0x19192b19, 0x082b192b, 0x08080808, 0x082b2b08, 0x08081919, 0x082b2b08, 0x08190819, 0x082b2b08,
|
||||
0x08191908, 0x082b2b08, 0x19080819, 0x082b2b08, 0x19081908, 0x082b2b08, 0x19190808, 0x082b2b08,
|
||||
0x2b082b2b, 0x082b2b08, 0x2b2b2b2b, 0x082b2b08, 0x08080819, 0x082b2b19, 0x08081908, 0x082b2b19,
|
||||
0x08190808, 0x082b2b19, 0x2b191919, 0x082b2b19, 0x08082b2b, 0x082b2b2b, 0x082b082b, 0x082b2b2b,
|
||||
0x192b1908, 0x082b2b2b, 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808,
|
||||
0x08081908, 0x19080808, 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808,
|
||||
0x0819082b, 0x19080808, 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x08192b2b, 0x19080808,
|
||||
0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x082b192b, 0x19080808, 0x19080808, 0x19080808,
|
||||
0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808, 0x19082b2b, 0x19080808,
|
||||
0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x1919192b, 0x19080808, 0x19192b19, 0x19080808,
|
||||
0x192b0808, 0x19080808, 0x192b082b, 0x19080808, 0x192b1919, 0x19080808, 0x2b080819, 0x19080808,
|
||||
0x2b081908, 0x19080808, 0x2b190808, 0x19080808, 0x2b191919, 0x19080808, 0x2b192b08, 0x19080808,
|
||||
0x2b2b0819, 0x19080808, 0x2b2b1908, 0x19080808, 0x08080808, 0x19080819, 0x0808082b, 0x19080819,
|
||||
0x08081919, 0x19080819, 0x08082b08, 0x19080819, 0x08190819, 0x19080819, 0x08191908, 0x19080819,
|
||||
0x0819192b, 0x19080819, 0x08192b19, 0x19080819, 0x082b0808, 0x19080819, 0x082b082b, 0x19080819,
|
||||
0x082b1919, 0x19080819, 0x19080819, 0x19080819, 0x19081908, 0x19080819, 0x1908192b, 0x19080819,
|
||||
0x19082b19, 0x19080819, 0x19190808, 0x19080819, 0x1919082b, 0x19080819, 0x19191919, 0x19080819,
|
||||
0x19192b08, 0x19080819, 0x192b0819, 0x19080819, 0x192b1908, 0x19080819, 0x2b080808, 0x19080819,
|
||||
0x2b08082b, 0x19080819, 0x2b081919, 0x19080819, 0x2b082b08, 0x19080819, 0x2b190819, 0x19080819,
|
||||
0x2b191908, 0x19080819, 0x2b2b0808, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b,
|
||||
0x08190808, 0x1908082b, 0x0819082b, 0x1908082b, 0x08191919, 0x1908082b, 0x08192b08, 0x1908082b,
|
||||
0x082b1908, 0x1908082b, 0x19080808, 0x1908082b, 0x19081919, 0x1908082b, 0x19082b08, 0x1908082b,
|
||||
0x19190819, 0x1908082b, 0x19191908, 0x1908082b, 0x192b0808, 0x1908082b, 0x2b080819, 0x1908082b,
|
||||
0x2b081908, 0x1908082b, 0x08080808, 0x19081908, 0x0808082b, 0x19081908, 0x08081919, 0x19081908,
|
||||
0x08082b08, 0x19081908, 0x08082b2b, 0x19081908, 0x08190819, 0x19081908, 0x08191908, 0x19081908,
|
||||
0x0819192b, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x082b082b, 0x19081908,
|
||||
0x082b1919, 0x19081908, 0x082b2b08, 0x19081908, 0x19080819, 0x19081908, 0x19081908, 0x19081908,
|
||||
0x1908192b, 0x19081908, 0x19082b19, 0x19081908, 0x19190808, 0x19081908, 0x1919082b, 0x19081908,
|
||||
0x19191919, 0x19081908, 0x19192b08, 0x19081908, 0x192b0819, 0x19081908, 0x192b1908, 0x19081908,
|
||||
0x2b080808, 0x19081908, 0x2b08082b, 0x19081908, 0x2b081919, 0x19081908, 0x2b082b08, 0x19081908,
|
||||
0x2b190819, 0x19081908, 0x2b191908, 0x19081908, 0x2b2b0808, 0x19081908, 0x08080819, 0x19081919,
|
||||
0x08081908, 0x19081919, 0x0808192b, 0x19081919, 0x08082b19, 0x19081919, 0x08190808, 0x19081919,
|
||||
0x0819082b, 0x19081919, 0x08191919, 0x19081919, 0x08192b08, 0x19081919, 0x082b0819, 0x19081919,
|
||||
0x082b1908, 0x19081919, 0x19080808, 0x19081919, 0x1908082b, 0x19081919, 0x19081919, 0x19081919,
|
||||
0x19082b08, 0x19081919, 0x19190819, 0x19081919, 0x19191908, 0x19081919, 0x192b0808, 0x19081919,
|
||||
0x192b2b2b, 0x19081919, 0x2b080819, 0x19081919, 0x2b081908, 0x19081919, 0x2b190808, 0x19081919,
|
||||
0x08080808, 0x1908192b, 0x0808082b, 0x1908192b, 0x08081919, 0x1908192b, 0x08082b08, 0x1908192b,
|
||||
0x08190819, 0x1908192b, 0x08191908, 0x1908192b, 0x082b0808, 0x1908192b, 0x19080819, 0x1908192b,
|
||||
0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x2b080808, 0x1908192b, 0x2b2b1919, 0x1908192b,
|
||||
0x08080819, 0x19082b08, 0x08081908, 0x19082b08, 0x08082b19, 0x19082b08, 0x08190808, 0x19082b08,
|
||||
0x0819082b, 0x19082b08, 0x08191919, 0x19082b08, 0x08192b08, 0x19082b08, 0x082b0819, 0x19082b08,
|
||||
0x082b1908, 0x19082b08, 0x19080808, 0x19082b08, 0x1908082b, 0x19082b08, 0x19081919, 0x19082b08,
|
||||
0x19082b08, 0x19082b08, 0x19190819, 0x19082b08, 0x19191908, 0x19082b08, 0x192b0808, 0x19082b08,
|
||||
0x2b081908, 0x19082b08, 0x2b190808, 0x19082b08, 0x08080808, 0x19082b19, 0x0808082b, 0x19082b19,
|
||||
0x08081919, 0x19082b19, 0x08082b08, 0x19082b19, 0x08190819, 0x19082b19, 0x08191908, 0x19082b19,
|
||||
0x082b0808, 0x19082b19, 0x19080819, 0x19082b19, 0x19081908, 0x19082b19, 0x19190808, 0x19082b19,
|
||||
0x2b080808, 0x19082b19, 0x2b19192b, 0x19082b19, 0x08080819, 0x19082b2b, 0x08081908, 0x19082b2b,
|
||||
0x08190808, 0x19082b2b, 0x19080808, 0x19082b2b, 0x08080808, 0x19190808, 0x0808082b, 0x19190808,
|
||||
0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808, 0x08191908, 0x19190808,
|
||||
0x0819192b, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x082b082b, 0x19190808,
|
||||
0x082b1919, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808, 0x19081908, 0x19190808,
|
||||
0x1908192b, 0x19190808, 0x19082b19, 0x19190808, 0x19190808, 0x19190808, 0x1919082b, 0x19190808,
|
||||
0x19191919, 0x19190808, 0x19192b08, 0x19190808, 0x192b0819, 0x19190808, 0x192b1908, 0x19190808,
|
||||
0x2b080808, 0x19190808, 0x2b08082b, 0x19190808, 0x2b081919, 0x19190808, 0x2b082b08, 0x19190808,
|
||||
0x2b190819, 0x19190808, 0x2b191908, 0x19190808, 0x08080819, 0x19190819, 0x08081908, 0x19190819,
|
||||
0x0808192b, 0x19190819, 0x08082b19, 0x19190819, 0x08190808, 0x19190819, 0x0819082b, 0x19190819,
|
||||
0x08191919, 0x19190819, 0x08192b08, 0x19190819, 0x082b0819, 0x19190819, 0x082b1908, 0x19190819,
|
||||
0x19080808, 0x19190819, 0x1908082b, 0x19190819, 0x19081919, 0x19190819, 0x19082b08, 0x19190819,
|
||||
0x19190819, 0x19190819, 0x19191908, 0x19190819, 0x192b0808, 0x19190819, 0x2b080819, 0x19190819,
|
||||
0x2b081908, 0x19190819, 0x2b190808, 0x19190819, 0x08080808, 0x1919082b, 0x08081919, 0x1919082b,
|
||||
0x08082b08, 0x1919082b, 0x08190819, 0x1919082b, 0x08191908, 0x1919082b, 0x082b0808, 0x1919082b,
|
||||
0x19080819, 0x1919082b, 0x19081908, 0x1919082b, 0x19190808, 0x1919082b, 0x192b2b19, 0x1919082b,
|
||||
0x2b080808, 0x1919082b, 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x0808192b, 0x19191908,
|
||||
0x08082b19, 0x19191908, 0x08190808, 0x19191908, 0x0819082b, 0x19191908, 0x08191919, 0x19191908,
|
||||
0x08192b08, 0x19191908, 0x082b0819, 0x19191908, 0x082b1908, 0x19191908, 0x19080808, 0x19191908,
|
||||
0x1908082b, 0x19191908, 0x19081919, 0x19191908, 0x19082b08, 0x19191908, 0x19190819, 0x19191908,
|
||||
0x19191908, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b081908, 0x19191908,
|
||||
0x2b190808, 0x19191908, 0x08080808, 0x19191919, 0x0808082b, 0x19191919, 0x08081919, 0x19191919,
|
||||
0x08082b08, 0x19191919, 0x08190819, 0x19191919, 0x08191908, 0x19191919, 0x082b0808, 0x19191919,
|
||||
0x19080819, 0x19191919, 0x19081908, 0x19191919, 0x19190808, 0x19191919, 0x2b080808, 0x19191919,
|
||||
0x08080819, 0x1919192b, 0x08081908, 0x1919192b, 0x08190808, 0x1919192b, 0x082b192b, 0x1919192b,
|
||||
0x19080808, 0x1919192b, 0x08080808, 0x19192b08, 0x0808082b, 0x19192b08, 0x08081919, 0x19192b08,
|
||||
0x08082b08, 0x19192b08, 0x08190819, 0x19192b08, 0x08191908, 0x19192b08, 0x082b0808, 0x19192b08,
|
||||
0x19080819, 0x19192b08, 0x19081908, 0x19192b08, 0x19190808, 0x19192b08, 0x19192b2b, 0x19192b08,
|
||||
0x2b080808, 0x19192b08, 0x08080819, 0x19192b19, 0x08081908, 0x19192b19, 0x08190808, 0x19192b19,
|
||||
0x19080808, 0x19192b19, 0x08080808, 0x19192b2b, 0x08192b19, 0x19192b2b, 0x2b081919, 0x19192b2b,
|
||||
0x2b2b2b08, 0x19192b2b, 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x0808192b, 0x192b0808,
|
||||
0x08190808, 0x192b0808, 0x0819082b, 0x192b0808, 0x08191919, 0x192b0808, 0x08192b08, 0x192b0808,
|
||||
0x082b0819, 0x192b0808, 0x082b1908, 0x192b0808, 0x19080808, 0x192b0808, 0x19081919, 0x192b0808,
|
||||
0x19082b08, 0x192b0808, 0x19190819, 0x192b0808, 0x19191908, 0x192b0808, 0x192b0808, 0x192b0808,
|
||||
0x2b081908, 0x192b0808, 0x2b190808, 0x192b0808, 0x08080808, 0x192b0819, 0x0808082b, 0x192b0819,
|
||||
0x08081919, 0x192b0819, 0x08082b08, 0x192b0819, 0x08190819, 0x192b0819, 0x08191908, 0x192b0819,
|
||||
0x082b0808, 0x192b0819, 0x19080819, 0x192b0819, 0x19081908, 0x192b0819, 0x19190808, 0x192b0819,
|
||||
0x2b080808, 0x192b0819, 0x2b192b19, 0x192b0819, 0x08081908, 0x192b082b, 0x08190808, 0x192b082b,
|
||||
0x19080808, 0x192b082b, 0x1919192b, 0x192b082b, 0x2b2b0819, 0x192b082b, 0x08080808, 0x192b1908,
|
||||
0x08081919, 0x192b1908, 0x08082b08, 0x192b1908, 0x08190819, 0x192b1908, 0x08191908, 0x192b1908,
|
||||
0x082b0808, 0x192b1908, 0x19080819, 0x192b1908, 0x19081908, 0x192b1908, 0x19190808, 0x192b1908,
|
||||
0x2b080808, 0x192b1908, 0x08080819, 0x192b1919, 0x08081908, 0x192b1919, 0x08190808, 0x192b1919,
|
||||
0x19080808, 0x192b1919, 0x19082b2b, 0x192b1919, 0x192b2b08, 0x192b1919, 0x2b19082b, 0x192b1919,
|
||||
0x08080808, 0x192b192b, 0x2b191908, 0x192b192b, 0x08080819, 0x192b2b08, 0x08081908, 0x192b2b08,
|
||||
0x08190808, 0x192b2b08, 0x192b1919, 0x192b2b08, 0x2b192b08, 0x192b2b08, 0x08080808, 0x192b2b19,
|
||||
0x082b2b2b, 0x192b2b19, 0x1908082b, 0x192b2b2b, 0x2b2b0819, 0x192b2b2b, 0x08080808, 0x2b080808,
|
||||
0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808, 0x08190819, 0x2b080808,
|
||||
0x08191908, 0x2b080808, 0x08192b19, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b1919, 0x2b080808,
|
||||
0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x1919082b, 0x2b080808,
|
||||
0x19191919, 0x2b080808, 0x19192b08, 0x2b080808, 0x192b0819, 0x2b080808, 0x2b080808, 0x2b080808,
|
||||
0x2b081919, 0x2b080808, 0x2b190819, 0x2b080808, 0x2b191908, 0x2b080808, 0x08080819, 0x2b080819,
|
||||
0x08081908, 0x2b080819, 0x08082b19, 0x2b080819, 0x08190808, 0x2b080819, 0x0819082b, 0x2b080819,
|
||||
0x08191919, 0x2b080819, 0x08192b08, 0x2b080819, 0x082b0819, 0x2b080819, 0x082b1908, 0x2b080819,
|
||||
0x19080808, 0x2b080819, 0x1908082b, 0x2b080819, 0x19081919, 0x2b080819, 0x19082b08, 0x2b080819,
|
||||
0x19190819, 0x2b080819, 0x19191908, 0x2b080819, 0x2b080819, 0x2b080819, 0x2b081908, 0x2b080819,
|
||||
0x2b190808, 0x2b080819, 0x2b2b2b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x08081919, 0x2b08082b,
|
||||
0x08082b2b, 0x2b08082b, 0x08190819, 0x2b08082b, 0x08191908, 0x2b08082b, 0x19080819, 0x2b08082b,
|
||||
0x19081908, 0x2b08082b, 0x19190808, 0x2b08082b, 0x08080819, 0x2b081908, 0x08081908, 0x2b081908,
|
||||
0x0808192b, 0x2b081908, 0x08082b19, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908,
|
||||
0x08191919, 0x2b081908, 0x08192b08, 0x2b081908, 0x082b0819, 0x2b081908, 0x19080808, 0x2b081908,
|
||||
0x1908082b, 0x2b081908, 0x19081919, 0x2b081908, 0x19082b08, 0x2b081908, 0x19190819, 0x2b081908,
|
||||
0x19191908, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b080819, 0x2b081908, 0x2b081908, 0x2b081908,
|
||||
0x2b190808, 0x2b081908, 0x08080808, 0x2b081919, 0x0808082b, 0x2b081919, 0x08081919, 0x2b081919,
|
||||
0x08082b08, 0x2b081919, 0x08190819, 0x2b081919, 0x08191908, 0x2b081919, 0x082b0808, 0x2b081919,
|
||||
0x19080819, 0x2b081919, 0x19081908, 0x2b081919, 0x19190808, 0x2b081919, 0x2b080808, 0x2b081919,
|
||||
0x2b082b2b, 0x2b081919, 0x08080819, 0x2b08192b, 0x08081908, 0x2b08192b, 0x08190808, 0x2b08192b,
|
||||
0x082b2b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08081919, 0x2b082b08,
|
||||
0x08190819, 0x2b082b08, 0x08191908, 0x2b082b08, 0x19080819, 0x2b082b08, 0x19081908, 0x2b082b08,
|
||||
0x19190808, 0x2b082b08, 0x2b2b082b, 0x2b082b08, 0x08080819, 0x2b082b19, 0x08081908, 0x2b082b19,
|
||||
0x19080808, 0x2b082b19, 0x192b1919, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x19192b08, 0x2b082b2b,
|
||||
0x19192b2b, 0x2b082b2b, 0x2b08082b, 0x2b082b2b, 0x2b2b082b, 0x2b082b2b, 0x08080819, 0x2b190808,
|
||||
0x08081908, 0x2b190808, 0x08082b19, 0x2b190808, 0x08190808, 0x2b190808, 0x0819082b, 0x2b190808,
|
||||
0x08191919, 0x2b190808, 0x08192b08, 0x2b190808, 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808,
|
||||
0x1908082b, 0x2b190808, 0x19081919, 0x2b190808, 0x19082b08, 0x2b190808, 0x19190819, 0x2b190808,
|
||||
0x19191908, 0x2b190808, 0x192b0808, 0x2b190808, 0x2b080819, 0x2b190808, 0x2b081908, 0x2b190808,
|
||||
0x2b190808, 0x2b190808, 0x08080808, 0x2b190819, 0x08081919, 0x2b190819, 0x08190819, 0x2b190819,
|
||||
0x08191908, 0x2b190819, 0x19080819, 0x2b190819, 0x19081908, 0x2b190819, 0x19190808, 0x2b190819,
|
||||
0x19192b2b, 0x2b190819, 0x08080819, 0x2b19082b, 0x08081908, 0x2b19082b, 0x08190808, 0x2b19082b,
|
||||
0x19080808, 0x2b19082b, 0x2b2b192b, 0x2b19082b, 0x08080808, 0x2b191908, 0x0808082b, 0x2b191908,
|
||||
0x08081919, 0x2b191908, 0x08082b08, 0x2b191908, 0x08190819, 0x2b191908, 0x08191908, 0x2b191908,
|
||||
0x082b0808, 0x2b191908, 0x19080819, 0x2b191908, 0x19081908, 0x2b191908, 0x19190808, 0x2b191908,
|
||||
0x2b080808, 0x2b191908, 0x2b19192b, 0x2b191908, 0x08080819, 0x2b191919, 0x08081908, 0x2b191919,
|
||||
0x08190808, 0x2b191919, 0x19080808, 0x2b191919, 0x2b192b08, 0x2b191919, 0x2b2b0819, 0x2b191919,
|
||||
0x08080808, 0x2b19192b, 0x1908192b, 0x2b19192b, 0x192b1908, 0x2b19192b, 0x08080819, 0x2b192b08,
|
||||
0x08081908, 0x2b192b08, 0x08190808, 0x2b192b08, 0x082b192b, 0x2b192b08, 0x19080808, 0x2b192b08,
|
||||
0x2b2b2b19, 0x2b192b08, 0x08080808, 0x2b192b19, 0x19082b19, 0x2b192b19, 0x1919082b, 0x2b192b19,
|
||||
0x2b190808, 0x2b192b2b, 0x08080808, 0x2b2b0808, 0x08081919, 0x2b2b0808, 0x08082b2b, 0x2b2b0808,
|
||||
0x08191908, 0x2b2b0808, 0x082b082b, 0x2b2b0808, 0x082b2b2b, 0x2b2b0808, 0x19080819, 0x2b2b0808,
|
||||
0x19081908, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b2b082b, 0x2b2b0808, 0x2b2b2b2b, 0x2b2b0808,
|
||||
0x19080808, 0x2b2b0819, 0x192b1919, 0x2b2b0819, 0x0808082b, 0x2b2b082b, 0x08082b2b, 0x2b2b082b,
|
||||
0x082b082b, 0x2b2b082b, 0x082b2b08, 0x2b2b082b, 0x082b2b2b, 0x2b2b082b, 0x2b08082b, 0x2b2b082b,
|
||||
0x2b082b08, 0x2b2b082b, 0x2b082b2b, 0x2b2b082b, 0x2b2b2b08, 0x2b2b082b, 0x08080819, 0x2b2b1908,
|
||||
0x08081908, 0x2b2b1908, 0x08190808, 0x2b2b1908, 0x19080808, 0x2b2b1908, 0x2b082b19, 0x2b2b1908,
|
||||
0x2b2b1908, 0x2b2b1908, 0x08080808, 0x2b2b1919, 0x08192b19, 0x2b2b1919, 0x19190819, 0x2b2b192b,
|
||||
0x08082b2b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b082b, 0x2b2b2b08, 0x19191908, 0x2b2b2b19,
|
||||
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
|
||||
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
|
||||
);
|
||||
#enddecl(IQ2_S_GRID)
|
||||
|
||||
#decl(IQ3_XSS_GRID)
|
||||
|
||||
const iq3xxs_grid = array<u32, 256>(
|
||||
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
||||
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
||||
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
|
||||
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
|
||||
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
||||
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
|
||||
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
|
||||
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
||||
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
||||
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
||||
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
||||
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
|
||||
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
||||
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
||||
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
|
||||
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
||||
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
|
||||
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
||||
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
||||
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
|
||||
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
||||
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
|
||||
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
|
||||
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
||||
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
|
||||
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
|
||||
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
|
||||
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
||||
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
||||
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
|
||||
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
||||
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
|
||||
);
|
||||
#enddecl(IQ3_XSS_GRID)
|
||||
|
||||
#decl(IQ3_S_GRID)
|
||||
|
||||
const iq3s_grid = array<u32, 512>(
|
||||
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
||||
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
||||
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
||||
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
||||
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
||||
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
||||
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
||||
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
||||
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
||||
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
||||
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
||||
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
||||
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
||||
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
||||
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
||||
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
||||
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
||||
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
||||
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
||||
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
||||
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
||||
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
||||
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
||||
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
||||
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
||||
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
||||
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
||||
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
||||
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
||||
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
||||
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
||||
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
||||
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
||||
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
||||
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
||||
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
||||
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
||||
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
||||
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
||||
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
||||
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
||||
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
||||
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
||||
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
||||
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
||||
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
||||
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
||||
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
||||
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
||||
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
||||
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
||||
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
||||
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
||||
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
||||
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
||||
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
||||
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
||||
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
||||
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
||||
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
||||
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
||||
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
||||
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
||||
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
|
||||
);
|
||||
#enddecl(IQ3_S_GRID)
|
||||
|
||||
#decl(IQ1_GRID)
|
||||
|
||||
const IQ1_DELTA: f32 = 0.125;
|
||||
|
||||
const iq1_grid = array<u32, 1024>(
|
||||
0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01,
|
||||
0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4,
|
||||
0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41,
|
||||
0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f,
|
||||
0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334,
|
||||
0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f,
|
||||
0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040,
|
||||
0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f,
|
||||
0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5,
|
||||
0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3,
|
||||
0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff,
|
||||
0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570,
|
||||
0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f,
|
||||
0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf,
|
||||
0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f,
|
||||
0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07,
|
||||
0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc,
|
||||
0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374,
|
||||
0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0,
|
||||
0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001,
|
||||
0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043,
|
||||
0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc,
|
||||
0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117,
|
||||
0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f,
|
||||
0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5,
|
||||
0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474,
|
||||
0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d,
|
||||
0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd,
|
||||
0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50,
|
||||
0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10,
|
||||
0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30,
|
||||
0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1,
|
||||
0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c,
|
||||
0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074,
|
||||
0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134,
|
||||
0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7,
|
||||
0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3,
|
||||
0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450,
|
||||
0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577,
|
||||
0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c,
|
||||
0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5,
|
||||
0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c,
|
||||
0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00,
|
||||
0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300,
|
||||
0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc,
|
||||
0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034,
|
||||
0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077,
|
||||
0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5,
|
||||
0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117,
|
||||
0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f,
|
||||
0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5,
|
||||
0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404,
|
||||
0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1,
|
||||
0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd,
|
||||
0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71,
|
||||
0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7,
|
||||
0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00,
|
||||
0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44,
|
||||
0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00,
|
||||
0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0,
|
||||
0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303,
|
||||
0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343,
|
||||
0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd,
|
||||
0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031,
|
||||
0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011,
|
||||
0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c,
|
||||
0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4,
|
||||
0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c,
|
||||
0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174,
|
||||
0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7,
|
||||
0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d,
|
||||
0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4,
|
||||
0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c,
|
||||
0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7,
|
||||
0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510,
|
||||
0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33,
|
||||
0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4,
|
||||
0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73,
|
||||
0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f,
|
||||
0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337,
|
||||
0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343,
|
||||
0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030,
|
||||
0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075,
|
||||
0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4,
|
||||
0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170,
|
||||
0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705,
|
||||
0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c,
|
||||
0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c,
|
||||
0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514,
|
||||
0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c,
|
||||
0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3,
|
||||
0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70,
|
||||
0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03,
|
||||
0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c,
|
||||
0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c,
|
||||
0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074,
|
||||
0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104,
|
||||
0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7,
|
||||
0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757,
|
||||
0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c,
|
||||
0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c,
|
||||
0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4,
|
||||
0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc,
|
||||
0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03,
|
||||
0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc,
|
||||
0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54,
|
||||
0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f,
|
||||
0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf,
|
||||
0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c,
|
||||
0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c,
|
||||
0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4,
|
||||
0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174,
|
||||
0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700,
|
||||
0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7,
|
||||
0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d,
|
||||
0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531,
|
||||
0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf,
|
||||
0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57,
|
||||
0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13,
|
||||
0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01,
|
||||
0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f,
|
||||
0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7,
|
||||
0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074,
|
||||
0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107,
|
||||
0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd,
|
||||
0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0,
|
||||
0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7,
|
||||
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
|
||||
);
|
||||
|
||||
#enddecl(IQ1_GRID)
|
||||
|
||||
#decl(IQ4_GRID)
|
||||
|
||||
const kvalues_iq4nl = array<i32, 16>(
|
||||
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
|
||||
);
|
||||
|
||||
#enddecl(IQ4_GRID)
|
||||
|
|
@ -27,6 +27,26 @@ def replace_placeholders(shader_text, replacements):
|
|||
return shader_text
|
||||
|
||||
|
||||
def expand_includes(shader, input_dir):
|
||||
"""
|
||||
Replace #include "file" lines in the text with the contents of that file.
|
||||
Searches for files relative to input_dir.
|
||||
"""
|
||||
include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)
|
||||
|
||||
def replacer(match):
|
||||
fname = match.group(1)
|
||||
file_path = os.path.join(input_dir, fname)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Included file not found: {file_path}")
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
included_code = f.read()
|
||||
# Recursively expand includes inside the included file
|
||||
return expand_includes(included_code, input_dir)
|
||||
|
||||
return include_pattern.sub(replacer, shader)
|
||||
|
||||
|
||||
def write_shader(shader_name, shader_code, output_dir, outfile):
|
||||
if output_dir:
|
||||
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
|
||||
|
|
@ -35,8 +55,9 @@ def write_shader(shader_name, shader_code, output_dir, outfile):
|
|||
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
|
||||
|
||||
|
||||
def generate_variants(shader_path, output_dir, outfile):
|
||||
shader_base_name = shader_path.split("/")[-1].split(".")[0]
|
||||
def generate_variants(fname, input_dir, output_dir, outfile):
|
||||
shader_path = os.path.join(input_dir, fname)
|
||||
shader_base_name = fname.split(".")[0]
|
||||
|
||||
with open(shader_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
|
@ -46,11 +67,21 @@ def generate_variants(shader_path, output_dir, outfile):
|
|||
except ValueError:
|
||||
write_shader(shader_base_name, text, output_dir, outfile)
|
||||
else:
|
||||
decls_map = parse_decls(extract_block(text, "DECLS"))
|
||||
shader_template = extract_block(text, "SHADER")
|
||||
try:
|
||||
decls_map = parse_decls(extract_block(text, "DECLS"))
|
||||
except ValueError:
|
||||
decls_map = {}
|
||||
|
||||
with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
|
||||
common_decls = f.read()
|
||||
decls_map.update(parse_decls(common_decls))
|
||||
|
||||
shader_template = extract_block(text, "SHADER")
|
||||
for variant in variants:
|
||||
decls = variant["DECLS"]
|
||||
if "DECLS" in variant:
|
||||
decls = variant["DECLS"]
|
||||
else:
|
||||
decls = []
|
||||
decls_code = ""
|
||||
for key in decls:
|
||||
if key not in decls_map:
|
||||
|
|
@ -59,8 +90,16 @@ def generate_variants(shader_path, output_dir, outfile):
|
|||
|
||||
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
|
||||
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
|
||||
final_shader = expand_includes(final_shader, input_dir)
|
||||
|
||||
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
|
||||
if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
|
||||
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
|
||||
elif "TYPE_SUFFIX" in variant["REPLS"]:
|
||||
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"]
|
||||
elif "TYPE" in variant["REPLS"]:
|
||||
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
|
||||
else:
|
||||
output_name = shader_base_name
|
||||
write_shader(output_name, final_shader, output_dir, outfile)
|
||||
|
||||
|
||||
|
|
@ -78,7 +117,7 @@ def main():
|
|||
out.write("// Auto-generated shader embedding\n\n")
|
||||
for fname in sorted(os.listdir(args.input_dir)):
|
||||
if fname.endswith(".wgsl"):
|
||||
generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
|
||||
generate_variants(fname, args.input_dir, args.output_dir, out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -0,0 +1,874 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "vec4<f32>",
|
||||
"TYPE_SUFFIX": "f32_vec",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"BLOCK_SIZE": 4
|
||||
},
|
||||
"DECLS": ["F32_VEC"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 1
|
||||
},
|
||||
"DECLS": ["F32"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 1
|
||||
},
|
||||
"DECLS": ["F16"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "i32",
|
||||
"DST_TYPE": "i32",
|
||||
"BLOCK_SIZE": 1
|
||||
},
|
||||
"DECLS": ["I32"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q4_0",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q4_1",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q5_0",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q5_1",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q8_0",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q2_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q3_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q4_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q5_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q6_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "iq2_xxs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "iq2_xs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq2_s",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq3_xxs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq3_s",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq1_s",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq1_m",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq4_nl",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq4_xs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(F32_VEC)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
|
||||
}
|
||||
#enddecl(F32_VEC)
|
||||
|
||||
#decl(F32)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[dst_base + offset] = src[src_base + offset];
|
||||
}
|
||||
#enddecl(F32)
|
||||
|
||||
#decl(F16)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[dst_base + offset] = f32(src[src_base + offset]);
|
||||
}
|
||||
#enddecl(F16)
|
||||
|
||||
#decl(I32)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[dst_base + offset] = src[src_base + offset];
|
||||
}
|
||||
#enddecl(I32)
|
||||
|
||||
#decl(Q4_0)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q4_0 = src[src_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q4_0)
|
||||
|
||||
#decl(Q4_1)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q4_1 = src[src_base + offset];
|
||||
let d = f32(block_q4_1.d);
|
||||
let m = f32(block_q4_1.m);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = block_q4_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
let q_lo = f32(q_byte & 0xF) * d + m;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q4_1)
|
||||
|
||||
#decl(Q5_0)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q5_0 = src[src_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(Q5_0)
|
||||
|
||||
#decl(Q5_1)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q5_1 = src[src_base + offset];
|
||||
let d = f32(block_q5_1.d);
|
||||
let m = f32(block_q5_1.m);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = block_q5_1.qs[j];
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
|
||||
let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q5_1)
|
||||
|
||||
#decl(Q8_0)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q8_0 = src[src_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q8_0)
|
||||
|
||||
#decl(Q2_K)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
for (var shift: u32 = 0; shift < 8; shift += 2) {
|
||||
// 2 blocks
|
||||
for (var k: u32 = 0; k < 32; k += 16) {
|
||||
let sc = get_byte(block.scales[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * f32(sc & 0xF);
|
||||
let ml = m * f32(sc >> 4);
|
||||
for (var l: u32 = 0u; l < 16; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
dst[dst_i] = (f32(qs_val) * dl - ml);
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q2_K)
|
||||
|
||||
#decl(Q3_K)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
|
||||
}
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var m: u32 = 1;
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
for (var shift: u32 = 0; shift < 8; shift += 2) {
|
||||
// 2 blocks
|
||||
for (var k: u32 = 0; k < 32; k += 16) {
|
||||
let sc = get_byte(scale_vals[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * (f32(sc) - 32.0);
|
||||
for (var l: u32 = 0u; l < 16u; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
dst[dst_i] = (f32(qs_val) - hm) * dl;
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
m <<= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q3_K)
|
||||
|
||||
#decl(Q4_K)
|
||||
// 8 blocks of 32 elements each
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
// 2 blocks each iteration
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
|
||||
for (var shift: u32 = 0; shift < 8; shift += 4) {
|
||||
let scale_min = get_scale_min(is, block.scales);
|
||||
is++;
|
||||
let dl = d * scale_min.x;
|
||||
let ml = m * scale_min.y;
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qs_val = (q_byte >> shift) & 0xF;
|
||||
dst[dst_i] = (f32(qs_val) * dl - ml);
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q4_K)
|
||||
|
||||
#decl(Q5_K)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let m = f32(block.dmin);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var u: u32 = 1;
|
||||
// 2 blocks each iteration
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
|
||||
for (var shift: u32 = 0; shift < 8; shift += 4) {
|
||||
let scale_min = get_scale_min(is, block.scales);
|
||||
is++;
|
||||
let dl = d * scale_min.x;
|
||||
let ml = m * scale_min.y;
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
|
||||
let qh_byte = get_byte(block.qh[l / 4], l % 4);
|
||||
let qs_val = (q_byte >> shift) & 0xF;
|
||||
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
|
||||
dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml;
|
||||
dst_i++;
|
||||
}
|
||||
u <<= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q5_K)
|
||||
|
||||
#decl(Q6_K)
|
||||
// 16 blocks of 16 elements each
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
|
||||
}
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
|
||||
}
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var qh_b_idx: u32 = 0;
|
||||
var sc_b_idx: u32 = 0;
|
||||
for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
|
||||
for (var l: u32 = 0; l < 32; l++) {
|
||||
let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
|
||||
let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
|
||||
let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
|
||||
|
||||
let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
|
||||
let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
|
||||
let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
|
||||
let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
|
||||
|
||||
let is = l/16;
|
||||
let is1 = sc_b_idx + is;
|
||||
let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
|
||||
let is2 = sc_b_idx + is + 2;
|
||||
let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
|
||||
let is3 = sc_b_idx + is + 4;
|
||||
let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
|
||||
let is4 = sc_b_idx + is + 6;
|
||||
let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
|
||||
|
||||
dst[dst_i + l] = (q1 * f32(sc1)) * d;
|
||||
dst[dst_i + l + 32] = (q2 * f32(sc2)) * d;
|
||||
dst[dst_i + l + 64] = (q3 * f32(sc3)) * d;
|
||||
dst[dst_i + l + 96] = (q4 * f32(sc4)) * d;
|
||||
}
|
||||
dst_i += 128;
|
||||
qh_b_idx += 32;
|
||||
sc_b_idx += 8;
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(Q6_K)
|
||||
|
||||
#decl(IQ2_XXS)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
|
||||
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
let is = (aux1 >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
dst[dst_i] = db * f32(g) * m;
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ2_XXS)
|
||||
|
||||
#decl(IQ2_XS)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
let db = array<f32, 2>(
|
||||
d * (0.5 + f32(s & 0xF)) * 0.25,
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let dl = db[l/2];
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
dst[dst_i] = dl * f32(g) * m;
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ2_XS)
|
||||
|
||||
#decl(IQ2_S)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
}
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
let db = array<f32, 2>(
|
||||
d * (0.5 + f32(s & 0xF)) * 0.25,
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
let qs_w = qs_vals[ib];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
|
||||
let ig = (get_byte(qs_w, l) | qh_b) * 8;
|
||||
let signs = get_byte(qs_vals[ib + 8], l);
|
||||
let dl = db[l/2];
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
|
||||
let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
|
||||
dst[dst_i] = dl * f32(g) * m;
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(IQ2_S)
|
||||
|
||||
#decl(IQ3_XSS)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let g1 = get_byte(iq3xxs_grid[ig1], j);
|
||||
let g2 = get_byte(iq3xxs_grid[ig2], j);
|
||||
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
|
||||
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
|
||||
dst[dst_i] = db * f32(g1) * m1;
|
||||
dst[dst_i + 4] = db * f32(g2) * m2;
|
||||
dst_i++;
|
||||
}
|
||||
dst_i += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ3_XSS)
|
||||
|
||||
#decl(IQ3_S)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
|
||||
}
|
||||
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
let db = array<f32, 2>(
|
||||
d * (1.0 + 2.0 * f32(s & 0xF)),
|
||||
d * (1.0 + 2.0 * f32(s >> 4))
|
||||
);
|
||||
for (var k: u32 = 0; k < 2; k++) {
|
||||
let dl = db[k];
|
||||
let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
|
||||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let g1 = get_byte(iq3s_grid[ig1], j);
|
||||
let g2 = get_byte(iq3s_grid[ig2], j);
|
||||
let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
|
||||
let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
|
||||
dst[dst_i] = dl * f32(g1) * m1;
|
||||
dst[dst_i + 4] = dl * f32(g2) * m2;
|
||||
dst_i++;
|
||||
}
|
||||
dst_i += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ3_S)
|
||||
|
||||
#decl(IQ1_S)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
|
||||
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let gw = iq1_grid[(ig + j) / 16];
|
||||
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
|
||||
let gs = bitcast<i32>(g << 30) >> 30;
|
||||
dst[dst_i] = dl * (f32(gs) + delta);
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(IQ1_S)
|
||||
|
||||
#decl(IQ1_M)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
|
||||
let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
|
||||
let d = f32(bitcast<vec2<f16>>(scale).x);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
|
||||
let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
|
||||
let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
|
||||
var dl = array<f32, 2>(
|
||||
d * f32(2 * s1 + 1),
|
||||
d * f32(2 * s2 + 1)
|
||||
);
|
||||
|
||||
let qh = block.qh[ib / 2] >> (16 * (ib % 2));
|
||||
var idx = array<u32, 4>(
|
||||
get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
|
||||
get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
|
||||
get_byte(block.qs[ib], 2) | ((qh) & 0x700),
|
||||
get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
|
||||
);
|
||||
var delta = array<f32, 4>(
|
||||
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
|
||||
select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = idx[l] * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let gw = iq1_grid[(ig + j) / 16];
|
||||
let g = (gw >> (((ig + j) % 16) * 2)) & 3;
|
||||
let gs = bitcast<i32>(g << 30) >> 30;
|
||||
dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]);
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(IQ1_M)
|
||||
|
||||
#decl(IQ4_NL)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
var dst_i = dst_base + offset * 32;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]);
|
||||
dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]);
|
||||
dst_i++;
|
||||
}
|
||||
}
|
||||
#enddecl(IQ4_NL)
|
||||
|
||||
#decl(IQ4_XS)
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
|
||||
let dl = d * (f32(ls) - 32.0);
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let iqs = ib * 16 + j;
|
||||
let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
|
||||
dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]);
|
||||
dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]);
|
||||
dst_i++;
|
||||
}
|
||||
dst_i += 16;
|
||||
}
|
||||
}
|
||||
#enddecl(IQ4_XS)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
DECLS
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> idx: array<i32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<{{DST_TYPE}}>;
|
||||
|
||||
struct Params {
|
||||
offset_src: u32, // in elements
|
||||
offset_idx: u32, // in elements
|
||||
offset_dst: u32, // in elements
|
||||
|
||||
// Strides (in elements)
|
||||
stride_src1: u32,
|
||||
stride_src2: u32,
|
||||
stride_src3: u32,
|
||||
|
||||
stride_idx0: u32,
|
||||
stride_idx1: u32,
|
||||
stride_idx2: u32,
|
||||
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
// Shape of dst
|
||||
ne0: u32,
|
||||
n_rows: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
|
||||
// Shape of idx
|
||||
idx1: u32,
|
||||
idx2: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
||||
return;
|
||||
}
|
||||
var i = gid.x;
|
||||
let i_dst3 = i / (params.ne2 * params.n_rows);
|
||||
|
||||
i = i % (params.ne2 * params.n_rows);
|
||||
let i_dst2 = i / params.n_rows;
|
||||
let i_dst1 = i % params.n_rows;
|
||||
|
||||
let i_idx2 = i_dst3 % params.idx2;
|
||||
let i_idx1 = i_dst2 % params.idx1;
|
||||
let i_idx0 = i_dst1;
|
||||
|
||||
let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
|
||||
|
||||
let idx_val = u32(idx[i_idx]);
|
||||
|
||||
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
|
||||
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
|
||||
|
||||
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
|
||||
copy_elements(i_src_row, i_dst_row, i);
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
}
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "binary_head.tmpl"
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
}
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "binary_head.tmpl"
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,57 @@
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
struct Params {
|
||||
offset_src: u32, // in elements
|
||||
offset_dst: u32, // in elements
|
||||
|
||||
// Strides (in elements)
|
||||
stride_src1: u32,
|
||||
stride_src2: u32,
|
||||
stride_src3: u32,
|
||||
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
// Shape of src/dst
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
|
||||
eps: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// one thread per row
|
||||
var i = gid.x;
|
||||
let i3 = i / (params.ne2 * params.ne1);
|
||||
i = i % (params.ne2 * params.ne1);
|
||||
let i2 = i / params.ne1;
|
||||
let i1 = i % params.ne1;
|
||||
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
|
||||
let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
|
||||
|
||||
var sum = 0.0f;
|
||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
||||
sum += src[i_src_row + j] * src[i_src_row + j];
|
||||
}
|
||||
let eps = bitcast<f32>(params.eps);
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
|
||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
||||
dst[i_dst_row + j] = scale * src[i_src_row + j];
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> a: array<f32>;
|
||||
|
||||
struct Params {
|
||||
offset: u32, // in elements
|
||||
|
||||
// Strides (in elements)
|
||||
stride1: u32,
|
||||
stride2: u32,
|
||||
stride3: u32,
|
||||
|
||||
// Shape
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
|
||||
eps: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// one thread per row
|
||||
var i = gid.x;
|
||||
let i3 = i / (params.ne2 * params.ne1);
|
||||
i = i % (params.ne2 * params.ne1);
|
||||
let i2 = i / params.ne1;
|
||||
let i1 = i % params.ne1;
|
||||
let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1;
|
||||
|
||||
var sum = 0.0f;
|
||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
||||
sum += a[i_row + j] * a[i_row + j];
|
||||
}
|
||||
let eps = bitcast<f32>(params.eps);
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
|
||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
||||
a[i_row + j] = scale * a[i_row + j];
|
||||
}
|
||||
}
|
||||
|
|
@ -52,7 +52,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
}
|
||||
var i = gid.x;
|
||||
let i_src3 = i / (params.ne2 * params.n_rows);
|
||||
let i_dst3 = i / (params.ne2 * 3);
|
||||
|
||||
i = i % (params.ne2 * params.n_rows);
|
||||
let i_src2 = i / params.n_rows;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
#include <cstdint>
|
||||
|
||||
#define LLAMA_MAX_SEQ 64
|
||||
#define LLAMA_MAX_SEQ 256
|
||||
|
||||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ struct llama_hparams {
|
|||
bool causal_attn = true;
|
||||
bool use_alibi = false;
|
||||
bool attn_soft_cap = false;
|
||||
bool use_kq_norm = true;
|
||||
bool use_kq_norm = false;
|
||||
|
||||
// for Classifiers
|
||||
uint32_t n_cls_out = 1;
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_80M: return "80M";
|
||||
case LLM_TYPE_109M: return "109M";
|
||||
case LLM_TYPE_137M: return "137M";
|
||||
case LLM_TYPE_140M: return "140M";
|
||||
case LLM_TYPE_160M: return "160M";
|
||||
case LLM_TYPE_190M: return "190M";
|
||||
case LLM_TYPE_220M: return "220M";
|
||||
|
|
@ -44,6 +45,7 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_270M: return "270M";
|
||||
case LLM_TYPE_335M: return "335M";
|
||||
case LLM_TYPE_350M: return "350M";
|
||||
case LLM_TYPE_360M: return "360M";
|
||||
case LLM_TYPE_410M: return "410M";
|
||||
case LLM_TYPE_450M: return "450M";
|
||||
case LLM_TYPE_475M: return "475M";
|
||||
|
|
@ -51,6 +53,7 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_700M: return "700M";
|
||||
case LLM_TYPE_770M: return "770M";
|
||||
case LLM_TYPE_780M: return "780M";
|
||||
case LLM_TYPE_950M: return "950M";
|
||||
case LLM_TYPE_0_3B: return "0.3B";
|
||||
case LLM_TYPE_0_5B: return "0.5B";
|
||||
case LLM_TYPE_0_6B: return "0.6B";
|
||||
|
|
@ -623,19 +626,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
if (found_swa && hparams.n_swa == 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192;
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
}
|
||||
|
||||
switch (hparams.n_expert) {
|
||||
case 0: {
|
||||
// MobileLLM (no MoE)
|
||||
switch (hparams.n_embd) {
|
||||
case 2048: type = LLM_TYPE_140M; break;
|
||||
case 4096: type = LLM_TYPE_360M; break;
|
||||
case 6144: type = LLM_TYPE_950M; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case 16: type = LLM_TYPE_17B_16E; break;
|
||||
case 128: type = LLM_TYPE_17B_128E; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
if (type == LLM_TYPE_17B_128E) {
|
||||
hparams.use_kq_norm = false;
|
||||
}
|
||||
hparams.use_kq_norm = type != LLM_TYPE_17B_128E;
|
||||
} break;
|
||||
case LLM_ARCH_ARCEE:
|
||||
{
|
||||
|
|
@ -2548,9 +2564,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
|
||||
bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0;
|
||||
|
||||
auto & layer = layers[i];
|
||||
|
||||
|
|
@ -6423,6 +6438,14 @@ struct llm_build_llama : public llm_graph_context {
|
|||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
if (hparams.use_kq_norm) {
|
||||
// Llama4TextL2Norm
|
||||
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
|
||||
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
}
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
|
|
@ -6530,7 +6553,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
|
|||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
|
||||
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
|
||||
(il + 1) % hparams.n_no_rope_layer_step != 0;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
|
|
@ -19399,7 +19423,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
} break;
|
||||
case LLM_ARCH_LLAMA4:
|
||||
{
|
||||
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
|
||||
if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
|
||||
llm = std::make_unique<llm_build_llama>(*this, params);
|
||||
} else {
|
||||
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DECI:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ enum llm_type {
|
|||
LLM_TYPE_80M,
|
||||
LLM_TYPE_109M,
|
||||
LLM_TYPE_137M,
|
||||
LLM_TYPE_140M,
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_190M,
|
||||
LLM_TYPE_220M,
|
||||
|
|
@ -36,6 +37,7 @@ enum llm_type {
|
|||
LLM_TYPE_270M,
|
||||
LLM_TYPE_335M,
|
||||
LLM_TYPE_350M,
|
||||
LLM_TYPE_360M,
|
||||
LLM_TYPE_410M,
|
||||
LLM_TYPE_450M,
|
||||
LLM_TYPE_475M,
|
||||
|
|
@ -43,6 +45,7 @@ enum llm_type {
|
|||
LLM_TYPE_700M,
|
||||
LLM_TYPE_770M,
|
||||
LLM_TYPE_780M,
|
||||
LLM_TYPE_950M,
|
||||
LLM_TYPE_0_3B,
|
||||
LLM_TYPE_0_5B,
|
||||
LLM_TYPE_0_6B,
|
||||
|
|
|
|||
|
|
@ -6071,6 +6071,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
||||
}
|
||||
|
||||
// single in-place tests, especially important for WebGPU backend since kernels for in-place vs. not are different
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
||||
|
||||
// fusion
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 2, 1, 1}, 3));
|
||||
|
|
@ -6325,12 +6329,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_sqr(type));
|
||||
test_cases.emplace_back(new test_sqrt(type));
|
||||
test_cases.emplace_back(new test_log(type));
|
||||
test_cases.emplace_back(new test_sin(type));
|
||||
test_cases.emplace_back(new test_cos(type));
|
||||
test_cases.emplace_back(new test_clamp(type));
|
||||
test_cases.emplace_back(new test_sqr (type));
|
||||
test_cases.emplace_back(new test_sqrt (type));
|
||||
test_cases.emplace_back(new test_log (type));
|
||||
test_cases.emplace_back(new test_sin (type));
|
||||
test_cases.emplace_back(new test_cos (type));
|
||||
test_cases.emplace_back(new test_clamp (type));
|
||||
test_cases.emplace_back(new test_leaky_relu(type));
|
||||
test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_sin (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -111,6 +111,7 @@ static bool server_task_type_need_logits(server_task_type task_type) {
|
|||
|
||||
struct slot_params {
|
||||
bool stream = true;
|
||||
bool include_usage = false;
|
||||
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
||||
bool return_tokens = false;
|
||||
bool return_progress = false;
|
||||
|
|
@ -310,17 +311,19 @@ struct server_task {
|
|||
params.verbose = params_base.verbosity > 9;
|
||||
params.timings_per_token = json_value(data, "timings_per_token", false);
|
||||
|
||||
params.stream = json_value(data, "stream", false);
|
||||
params.cache_prompt = json_value(data, "cache_prompt", true);
|
||||
params.return_tokens = json_value(data, "return_tokens", false);
|
||||
params.return_progress = json_value(data, "return_progress", false);
|
||||
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
||||
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
||||
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
params.stream = json_value(data, "stream", false);
|
||||
auto stream_opt = json_value(data, "stream_options", json::object());
|
||||
params.include_usage = json_value(stream_opt, "include_usage", false);
|
||||
params.cache_prompt = json_value(data, "cache_prompt", true);
|
||||
params.return_tokens = json_value(data, "return_tokens", false);
|
||||
params.return_progress = json_value(data, "return_progress", false);
|
||||
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
||||
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
||||
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
|
||||
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
||||
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
||||
|
|
@ -775,6 +778,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
llama_tokens tokens;
|
||||
|
||||
bool stream;
|
||||
bool include_usage;
|
||||
result_timings timings;
|
||||
std::string prompt;
|
||||
|
||||
|
|
@ -982,21 +986,23 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
|
||||
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
|
||||
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
|
||||
deltas.push_back({
|
||||
{"choices", json::array()},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
});
|
||||
if (include_usage) {
|
||||
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
|
||||
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
|
||||
deltas.push_back({
|
||||
{"choices", json::array()},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
});
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
deltas.back().push_back({"timings", timings.to_json()});
|
||||
|
|
@ -2815,6 +2821,7 @@ struct server_context {
|
|||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->include_usage = slot.params.include_usage;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
|
|
@ -5261,6 +5268,42 @@ int main(int argc, char ** argv) {
|
|||
svr->Get (params.api_prefix + "/slots", handle_slots);
|
||||
svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);
|
||||
|
||||
// SPA fallback route - serve index.html for any route that doesn't match API endpoints
|
||||
// This enables client-side routing for dynamic routes like /chat/[id]
|
||||
if (params.webui && params.public_path.empty()) {
|
||||
// Only add fallback when using embedded static files
|
||||
svr->Get(".*", [](const httplib::Request & req, httplib::Response & res) {
|
||||
// Skip API routes - they should have been handled above
|
||||
if (req.path.find("/v1/") != std::string::npos ||
|
||||
req.path.find("/health") != std::string::npos ||
|
||||
req.path.find("/metrics") != std::string::npos ||
|
||||
req.path.find("/props") != std::string::npos ||
|
||||
req.path.find("/models") != std::string::npos ||
|
||||
req.path.find("/api/tags") != std::string::npos ||
|
||||
req.path.find("/completions") != std::string::npos ||
|
||||
req.path.find("/chat/completions") != std::string::npos ||
|
||||
req.path.find("/embeddings") != std::string::npos ||
|
||||
req.path.find("/tokenize") != std::string::npos ||
|
||||
req.path.find("/detokenize") != std::string::npos ||
|
||||
req.path.find("/lora-adapters") != std::string::npos ||
|
||||
req.path.find("/slots") != std::string::npos) {
|
||||
return false; // Let other handlers process API routes
|
||||
}
|
||||
|
||||
// Serve index.html for all other routes (SPA fallback)
|
||||
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
||||
res.set_content("Error: gzip is not supported by this browser", "text/plain");
|
||||
} else {
|
||||
res.set_header("Content-Encoding", "gzip");
|
||||
// COEP and COOP headers, required by pyodide (python interpreter)
|
||||
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
||||
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
|
||||
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
//
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ def test_no_webui():
|
|||
url = f"http://{server.server_host}:{server.server_port}"
|
||||
res = requests.get(url)
|
||||
assert res.status_code == 200
|
||||
assert "<html>" in res.text
|
||||
assert "<!doctype html>" in res.text
|
||||
server.stop()
|
||||
|
||||
# with --no-webui
|
||||
|
|
|
|||
|
|
@ -271,8 +271,10 @@ def test_chat_completion_with_timings_per_token():
|
|||
"max_tokens": 10,
|
||||
"messages": [{"role": "user", "content": "test"}],
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
"timings_per_token": True,
|
||||
})
|
||||
stats_received = False
|
||||
for i, data in enumerate(res):
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
|
|
@ -288,6 +290,8 @@ def test_chat_completion_with_timings_per_token():
|
|||
assert "predicted_per_second" in data["timings"]
|
||||
assert "predicted_n" in data["timings"]
|
||||
assert data["timings"]["predicted_n"] <= 10
|
||||
stats_received = True
|
||||
assert stats_received
|
||||
|
||||
|
||||
def test_logprobs():
|
||||
|
|
|
|||
|
|
@ -1,24 +1,27 @@
|
|||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
test-results
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
# Output
|
||||
.output
|
||||
.vercel
|
||||
.netlify
|
||||
.wrangler
|
||||
/.svelte-kit
|
||||
/build
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
Thumbs.db
|
||||
|
||||
# Env
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
!.env.test
|
||||
|
||||
# Vite
|
||||
vite.config.js.timestamp-*
|
||||
vite.config.ts.timestamp-*
|
||||
|
||||
*storybook.log
|
||||
storybook-static
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
engine-strict=true
|
||||
|
|
@ -1,10 +1,9 @@
|
|||
**/.vscode
|
||||
**/.github
|
||||
**/.git
|
||||
**/.svn
|
||||
**/.hg
|
||||
**/node_modules
|
||||
**/dist
|
||||
**/build
|
||||
# Package Managers
|
||||
package-lock.json
|
||||
pnpm-lock.yaml
|
||||
yarn.lock
|
||||
bun.lock
|
||||
bun.lockb
|
||||
|
||||
*.config.js
|
||||
# Miscellaneous
|
||||
/static/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"useTabs": true,
|
||||
"singleQuote": true,
|
||||
"trailingComma": "none",
|
||||
"printWidth": 100,
|
||||
"plugins": ["prettier-plugin-svelte", "prettier-plugin-tailwindcss"],
|
||||
"overrides": [
|
||||
{
|
||||
"files": "*.svelte",
|
||||
"options": {
|
||||
"parser": "svelte"
|
||||
}
|
||||
}
|
||||
],
|
||||
"tailwindStylesheet": "./src/app.css"
|
||||
}
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
<script lang="ts">
|
||||
import { ModeWatcher } from 'mode-watcher';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
children?: any;
|
||||
}
|
||||
|
||||
let { children }: Props = $props();
|
||||
|
||||
onMount(() => {
|
||||
const root = document.documentElement;
|
||||
const theme = localStorage.getItem('mode-watcher-mode') || 'system';
|
||||
|
||||
if (theme === 'dark') {
|
||||
root.classList.add('dark');
|
||||
} else if (theme === 'light') {
|
||||
root.classList.remove('dark');
|
||||
} else {
|
||||
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
|
||||
if (prefersDark) {
|
||||
root.classList.add('dark');
|
||||
} else {
|
||||
root.classList.remove('dark');
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<ModeWatcher />
|
||||
|
||||
{#if children}
|
||||
{@const Component = children}
|
||||
|
||||
<Component />
|
||||
{/if}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
<script lang="ts">
|
||||
import * as Tooltip from '../src/lib/components/ui/tooltip';
|
||||
|
||||
interface Props {
|
||||
children: any;
|
||||
}
|
||||
|
||||
let { children }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Tooltip.Provider>
|
||||
{@render children()}
|
||||
</Tooltip.Provider>
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
import type { StorybookConfig } from '@storybook/sveltekit';
|
||||
|
||||
const config: StorybookConfig = {
|
||||
stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|ts|svelte)'],
|
||||
addons: [
|
||||
'@storybook/addon-svelte-csf',
|
||||
'@chromatic-com/storybook',
|
||||
'@storybook/addon-docs',
|
||||
'@storybook/addon-a11y',
|
||||
'@storybook/addon-vitest'
|
||||
],
|
||||
framework: {
|
||||
name: '@storybook/sveltekit',
|
||||
options: {}
|
||||
}
|
||||
};
|
||||
export default config;
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
import type { Preview } from '@storybook/sveltekit';
|
||||
import '../src/app.css';
|
||||
import ModeWatcherDecorator from './ModeWatcherDecorator.svelte';
|
||||
import TooltipProviderDecorator from './TooltipProviderDecorator.svelte';
|
||||
|
||||
const preview: Preview = {
|
||||
parameters: {
|
||||
controls: {
|
||||
matchers: {
|
||||
color: /(background|color)$/i,
|
||||
date: /Date$/i
|
||||
}
|
||||
},
|
||||
backgrounds: {
|
||||
disable: true
|
||||
}
|
||||
},
|
||||
decorators: [
|
||||
(story) => ({
|
||||
Component: ModeWatcherDecorator,
|
||||
props: {
|
||||
children: story
|
||||
}
|
||||
}),
|
||||
(story) => ({
|
||||
Component: TooltipProviderDecorator,
|
||||
props: {
|
||||
children: story
|
||||
}
|
||||
})
|
||||
]
|
||||
};
|
||||
|
||||
export default preview;
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
import { setProjectAnnotations } from '@storybook/sveltekit';
|
||||
import * as previewAnnotations from './preview';
|
||||
import { beforeAll } from 'vitest';
|
||||
|
||||
const project = setProjectAnnotations([previewAnnotations]);
|
||||
|
||||
beforeAll(async () => {
|
||||
if (project.beforeAll) {
|
||||
await project.beforeAll();
|
||||
}
|
||||
});
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
# llama.cpp Web UI
|
||||
|
||||
A modern, feature-rich web interface for llama.cpp built with SvelteKit. This UI provides an intuitive chat interface with advanced file handling, conversation management, and comprehensive model interaction capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- **Modern Chat Interface** - Clean, responsive design with dark/light mode
|
||||
- **File Attachments** - Support for images, text files, PDFs, and audio with rich previews and drag-and-drop support
|
||||
- **Conversation Management** - Create, edit, branch, and search conversations
|
||||
- **Advanced Markdown** - Code highlighting, math formulas (KaTeX), and content blocks
|
||||
- **Reasoning Content** - Support for models with thinking blocks
|
||||
- **Keyboard Shortcuts** - Keyboard navigation (Shift+Ctrl/Cmd+O for new chat, Shift+Ctrl/Cmdt+E for edit conversation, Shift+Ctrl/Cmdt+D for delete conversation, Ctrl/Cmd+K for search, Ctrl/Cmd+V for paste, Ctrl/Cmd+B for opening/collapsing sidebar)
|
||||
- **Request Tracking** - Monitor processing with slots endpoint integration
|
||||
- **UI Testing** - Storybook component library with automated tests
|
||||
|
||||
## Development
|
||||
|
||||
Install dependencies:
|
||||
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
Start the development server + Storybook:
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
This will start both the SvelteKit dev server and Storybook on port 6006.
|
||||
|
||||
## Building
|
||||
|
||||
Create a production build:
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
The build outputs static files to `../public` directory for deployment with llama.cpp server.
|
||||
|
||||
## Testing
|
||||
|
||||
Run the test suite:
|
||||
|
||||
```bash
|
||||
# E2E tests
|
||||
npm run test:e2e
|
||||
|
||||
# Unit tests
|
||||
npm run test:unit
|
||||
|
||||
# UI tests
|
||||
npm run test:ui
|
||||
|
||||
# All tests
|
||||
npm run test
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Framework**: SvelteKit with Svelte 5 runes
|
||||
- **Components**: ShadCN UI + bits-ui design system
|
||||
- **Database**: IndexedDB with Dexie for local storage
|
||||
- **Build**: Static adapter for deployment with llama.cpp server
|
||||
- **Testing**: Playwright (E2E) + Vitest (unit) + Storybook (components)
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"$schema": "https://shadcn-svelte.com/schema.json",
|
||||
"tailwind": {
|
||||
"css": "src/app.css",
|
||||
"baseColor": "neutral"
|
||||
},
|
||||
"aliases": {
|
||||
"components": "$lib/components",
|
||||
"utils": "$lib/components/ui/utils",
|
||||
"ui": "$lib/components/ui",
|
||||
"hooks": "$lib/hooks",
|
||||
"lib": "$lib"
|
||||
},
|
||||
"typescript": true,
|
||||
"registry": "https://shadcn-svelte.com/registry"
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
import { expect, test } from '@playwright/test';
|
||||
|
||||
test('home page has expected h1', async ({ page }) => {
|
||||
await page.goto('/');
|
||||
await expect(page.locator('h1')).toBeVisible();
|
||||
});
|
||||
|
|
@ -1,26 +1,49 @@
|
|||
import js from '@eslint/js'
|
||||
import globals from 'globals'
|
||||
import reactHooks from 'eslint-plugin-react-hooks'
|
||||
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||
import tseslint from 'typescript-eslint'
|
||||
// For more info, see https://github.com/storybookjs/eslint-plugin-storybook#configuration-flat-config-format
|
||||
import storybook from 'eslint-plugin-storybook';
|
||||
|
||||
export default tseslint.config(
|
||||
{ ignores: ['dist'] },
|
||||
{
|
||||
extends: [js.configs.recommended, ...tseslint.configs.recommended],
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
languageOptions: {
|
||||
ecmaVersion: 2020,
|
||||
globals: globals.browser,
|
||||
},
|
||||
plugins: {
|
||||
'react-hooks': reactHooks,
|
||||
'react-refresh': reactRefresh,
|
||||
},
|
||||
rules: {
|
||||
...reactHooks.configs.recommended.rules,
|
||||
'react-refresh/only-export-components': 'off',
|
||||
'@typescript-eslint/no-unused-vars': 'off',
|
||||
},
|
||||
},
|
||||
)
|
||||
import prettier from 'eslint-config-prettier';
|
||||
import { includeIgnoreFile } from '@eslint/compat';
|
||||
import js from '@eslint/js';
|
||||
import svelte from 'eslint-plugin-svelte';
|
||||
import globals from 'globals';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import ts from 'typescript-eslint';
|
||||
import svelteConfig from './svelte.config.js';
|
||||
|
||||
const gitignorePath = fileURLToPath(new URL('./.gitignore', import.meta.url));
|
||||
|
||||
export default ts.config(
|
||||
includeIgnoreFile(gitignorePath),
|
||||
js.configs.recommended,
|
||||
...ts.configs.recommended,
|
||||
...svelte.configs.recommended,
|
||||
prettier,
|
||||
...svelte.configs.prettier,
|
||||
{
|
||||
languageOptions: {
|
||||
globals: { ...globals.browser, ...globals.node }
|
||||
},
|
||||
rules: {
|
||||
// typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects.
|
||||
// see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors
|
||||
'no-undef': 'off',
|
||||
'svelte/no-at-html-tags': 'off'
|
||||
}
|
||||
},
|
||||
{
|
||||
files: ['**/*.svelte', '**/*.svelte.ts', '**/*.svelte.js'],
|
||||
languageOptions: {
|
||||
parserOptions: {
|
||||
projectService: true,
|
||||
extraFileExtensions: ['.svelte'],
|
||||
parser: ts.parser,
|
||||
svelteConfig
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
// Exclude Storybook files from main ESLint rules
|
||||
ignores: ['.storybook/**/*']
|
||||
},
|
||||
storybook.configs['flat/recommended']
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, maximum-scale=1"
|
||||
/>
|
||||
<meta name="color-scheme" content="light dark" />
|
||||
<title>🦙 llama.cpp - chat</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,66 +1,90 @@
|
|||
{
|
||||
"name": "webui",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "npm run format && tsc -b && vite build",
|
||||
"format": "eslint . && prettier --write .",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^2.2.0",
|
||||
"@sec-ant/readable-stream": "^0.6.0",
|
||||
"@tailwindcss/postcss": "^4.1.1",
|
||||
"@tailwindcss/vite": "^4.1.1",
|
||||
"@vscode/markdown-it-katex": "^1.1.1",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"daisyui": "^5.0.12",
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"pdfjs-dist": "^5.2.133",
|
||||
"postcss": "^8.4.49",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-hot-toast": "^2.5.2",
|
||||
"react-markdown": "^9.0.3",
|
||||
"react-router": "^7.1.5",
|
||||
"rehype-highlight": "^7.0.2",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"tailwindcss": "^4.1.1",
|
||||
"textlinestream": "^1.1.1",
|
||||
"vite-plugin-singlefile": "^2.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.17.0",
|
||||
"@types/markdown-it": "^14.1.2",
|
||||
"@types/node": "^22.13.1",
|
||||
"@types/react": "^18.3.18",
|
||||
"@types/react-dom": "^18.3.5",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"eslint": "^9.17.0",
|
||||
"eslint-plugin-react-hooks": "^5.0.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.16",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^15.14.0",
|
||||
"prettier": "^3.4.2",
|
||||
"sass-embedded": "^1.83.4",
|
||||
"typescript": "~5.6.2",
|
||||
"typescript-eslint": "^8.18.2",
|
||||
"vite": "^6.0.5"
|
||||
},
|
||||
"prettier": {
|
||||
"trailingComma": "es5",
|
||||
"tabWidth": 2,
|
||||
"semi": true,
|
||||
"singleQuote": true,
|
||||
"bracketSameLine": false
|
||||
}
|
||||
"name": "webui",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite dev --host 0.0.0.0 & storybook dev -p 6006 --ci",
|
||||
"build": "vite build && ./scripts/post-build.sh",
|
||||
"preview": "vite preview",
|
||||
"prepare": "svelte-kit sync || echo ''",
|
||||
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
|
||||
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
|
||||
"reset": "rm -rf .svelte-kit node_modules",
|
||||
"format": "prettier --write .",
|
||||
"lint": "prettier --check . && eslint .",
|
||||
"test": "npm run test:ui -- --run && npm run test:client -- --run && npm run test:server -- --run && npm run test:e2e",
|
||||
"test:e2e": "playwright test",
|
||||
"test:client": "vitest --project=client",
|
||||
"test:server": "vitest --project=server",
|
||||
"test:ui": "vitest --project=ui",
|
||||
"test:unit": "vitest",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "^4.0.1",
|
||||
"@eslint/compat": "^1.2.5",
|
||||
"@eslint/js": "^9.18.0",
|
||||
"@internationalized/date": "^3.8.2",
|
||||
"@lucide/svelte": "^0.515.0",
|
||||
"@playwright/test": "^1.49.1",
|
||||
"@storybook/addon-a11y": "^9.0.17",
|
||||
"@storybook/addon-docs": "^9.0.17",
|
||||
"@storybook/addon-svelte-csf": "^5.0.7",
|
||||
"@storybook/addon-vitest": "^9.0.17",
|
||||
"@storybook/sveltekit": "^9.0.17",
|
||||
"@sveltejs/adapter-static": "^3.0.8",
|
||||
"@sveltejs/kit": "^2.22.0",
|
||||
"@sveltejs/vite-plugin-svelte": "^6.0.0",
|
||||
"@tailwindcss/forms": "^0.5.9",
|
||||
"@tailwindcss/typography": "^0.5.15",
|
||||
"@tailwindcss/vite": "^4.0.0",
|
||||
"@types/node": "^22",
|
||||
"@vitest/browser": "^3.2.3",
|
||||
"bits-ui": "^2.8.11",
|
||||
"clsx": "^2.1.1",
|
||||
"dexie": "^4.0.11",
|
||||
"eslint": "^9.18.0",
|
||||
"eslint-config-prettier": "^10.0.1",
|
||||
"eslint-plugin-storybook": "^9.0.17",
|
||||
"eslint-plugin-svelte": "^3.0.0",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^16.0.0",
|
||||
"mdsvex": "^0.12.3",
|
||||
"playwright": "^1.53.0",
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
"prettier-plugin-tailwindcss": "^0.6.11",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"storybook": "^9.0.17",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte-check": "^4.0.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwind-variants": "^1.0.0",
|
||||
"tailwindcss": "^4.0.0",
|
||||
"tw-animate-css": "^1.3.5",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.20.0",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "^7.0.4",
|
||||
"vite-plugin-devtools-json": "^0.2.0",
|
||||
"vitest": "^3.2.3",
|
||||
"vitest-browser-svelte": "^0.1.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"mode-watcher": "^1.1.0",
|
||||
"pdfjs-dist": "^5.4.54",
|
||||
"rehype-highlight": "^7.0.2",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark": "^15.0.1",
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-html": "^16.0.1",
|
||||
"remark-rehype": "^11.1.2",
|
||||
"svelte-sonner": "^1.0.5",
|
||||
"unist-util-visit": "^5.0.0"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
import { defineConfig } from '@playwright/test';
|
||||
|
||||
export default defineConfig({
|
||||
webServer: {
|
||||
command: 'npm run build && npx http-server ../public -p 8181',
|
||||
port: 8181
|
||||
},
|
||||
testDir: 'e2e'
|
||||
});
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
export default {
|
||||
plugins: {
|
||||
"@tailwindcss/postcss": {},
|
||||
},
|
||||
}
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
{
|
||||
"demo": true,
|
||||
"id": "conv-1734086746930",
|
||||
"lastModified": 1734087548943,
|
||||
"messages": [
|
||||
{
|
||||
"id": 1734086764521,
|
||||
"role": "user",
|
||||
"content": "this is a demo conversation, used in dev mode"
|
||||
},
|
||||
{
|
||||
"id": 1734087548327,
|
||||
"role": "assistant",
|
||||
"content": "This is the formula:\n\n$\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}$\n\nGiven an input vector \\(\\mathbf{x} = [x_1, x_2, \\ldots, x_n]\\)\n\n\\[\ny_i = \\frac{e^{x_i}}{\\sum_{j=1}^n e^{x_j}}\n\\]\n\n$2x + y = z$\n\nCode block latex:\n```latex\n\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}\n```\n\nTest dollar sign: $1234 $4567\n\nInvalid latex syntax: $E = mc^$ and $$E = mc^$$",
|
||||
"timings": {
|
||||
"prompt_n": 1,
|
||||
"prompt_ms": 28.923,
|
||||
"predicted_n": 25,
|
||||
"predicted_ms": 573.016
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1734087548328,
|
||||
"role": "user",
|
||||
"content": "this is a demo conversation, used in dev mode"
|
||||
},
|
||||
{
|
||||
"id": 1734087548329,
|
||||
"role": "assistant",
|
||||
"content": "Code block:\n```js\nconsole.log('hello world')\n```\n```sh\nls -la /dev\n```"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Script to install pre-commit and post-commit hooks for webui
|
||||
# Pre-commit: formats code and builds, stashes unstaged changes
|
||||
# Post-commit: automatically unstashes changes
|
||||
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
PRE_COMMIT_HOOK="$REPO_ROOT/.git/hooks/pre-commit"
|
||||
POST_COMMIT_HOOK="$REPO_ROOT/.git/hooks/post-commit"
|
||||
|
||||
echo "Installing pre-commit and post-commit hooks for webui..."
|
||||
|
||||
# Create the pre-commit hook
|
||||
cat > "$PRE_COMMIT_HOOK" << 'EOF'
|
||||
#!/bin/bash
|
||||
|
||||
# Check if there are any changes in the webui directory
|
||||
if git diff --cached --name-only | grep -q "^tools/server/webui/"; then
|
||||
echo "Formatting webui code..."
|
||||
|
||||
# Change to webui directory and run format
|
||||
cd tools/server/webui
|
||||
|
||||
# Check if npm is available and package.json exists
|
||||
if [ ! -f "package.json" ]; then
|
||||
echo "Error: package.json not found in tools/server/webui"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Stash any unstaged changes to avoid conflicts during format/build
|
||||
echo "Stashing unstaged changes..."
|
||||
git stash push --keep-index --include-untracked -m "Pre-commit hook: stashed unstaged changes"
|
||||
STASH_CREATED=$?
|
||||
|
||||
# Run the format command
|
||||
npm run format
|
||||
|
||||
# Check if format command succeeded
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run format failed"
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "You can restore your unstaged changes with: git stash pop"
|
||||
fi
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the check command
|
||||
npm run check
|
||||
|
||||
# Check if check command succeeded
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run check failed"
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "You can restore your unstaged changes with: git stash pop"
|
||||
fi
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the build command
|
||||
npm run build
|
||||
|
||||
# Check if build command succeeded
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run build failed"
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "You can restore your unstaged changes with: git stash pop"
|
||||
fi
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Go back to repo root to add build output
|
||||
cd ../../..
|
||||
|
||||
# Add the build output to staging area
|
||||
git add tools/server/public/index.html.gz
|
||||
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "✅ Build completed. Your unstaged changes have been stashed."
|
||||
echo "They will be automatically restored after the commit."
|
||||
# Create a marker file to indicate stash was created by pre-commit hook
|
||||
touch .git/WEBUI_STASH_MARKER
|
||||
fi
|
||||
|
||||
echo "Webui code formatted successfully"
|
||||
fi
|
||||
|
||||
exit 0
|
||||
EOF
|
||||
|
||||
# Create the post-commit hook
|
||||
cat > "$POST_COMMIT_HOOK" << 'EOF'
|
||||
#!/bin/bash
|
||||
|
||||
# Check if we have a stash marker from the pre-commit hook
|
||||
if [ -f .git/WEBUI_STASH_MARKER ]; then
|
||||
echo "Restoring your unstaged changes..."
|
||||
git stash pop
|
||||
rm -f .git/WEBUI_STASH_MARKER
|
||||
echo "✅ Your unstaged changes have been restored."
|
||||
fi
|
||||
|
||||
exit 0
|
||||
EOF
|
||||
|
||||
# Make both hooks executable
|
||||
chmod +x "$PRE_COMMIT_HOOK"
|
||||
chmod +x "$POST_COMMIT_HOOK"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✅ Pre-commit and post-commit hooks installed successfully!"
|
||||
echo " Pre-commit: $PRE_COMMIT_HOOK"
|
||||
echo " Post-commit: $POST_COMMIT_HOOK"
|
||||
echo ""
|
||||
echo "The hooks will automatically:"
|
||||
echo " • Format and build webui code before commits"
|
||||
echo " • Stash unstaged changes during the process"
|
||||
echo " • Restore your unstaged changes after the commit"
|
||||
echo ""
|
||||
echo "To test the hooks, make a change to a file in the webui directory and commit it."
|
||||
else
|
||||
echo "❌ Failed to make hooks executable"
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
rm -rf ../public/_app;
|
||||
rm ../public/favicon.svg;
|
||||
rm ../public/index.html;
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
import { HashRouter, Outlet, Route, Routes } from 'react-router';
|
||||
import Header from './components/Header';
|
||||
import Sidebar from './components/Sidebar';
|
||||
import { AppContextProvider, useAppContext } from './utils/app.context';
|
||||
import ChatScreen from './components/ChatScreen';
|
||||
import SettingDialog from './components/SettingDialog';
|
||||
import { Toaster } from 'react-hot-toast';
|
||||
import { ModalProvider } from './components/ModalProvider';
|
||||
|
||||
function App() {
|
||||
return (
|
||||
<ModalProvider>
|
||||
<HashRouter>
|
||||
<div className="flex flex-row drawer lg:drawer-open">
|
||||
<AppContextProvider>
|
||||
<Routes>
|
||||
<Route element={<AppLayout />}>
|
||||
<Route path="/chat/:convId" element={<ChatScreen />} />
|
||||
<Route path="*" element={<ChatScreen />} />
|
||||
</Route>
|
||||
</Routes>
|
||||
</AppContextProvider>
|
||||
</div>
|
||||
</HashRouter>
|
||||
</ModalProvider>
|
||||
);
|
||||
}
|
||||
|
||||
function AppLayout() {
|
||||
const { showSettings, setShowSettings } = useAppContext();
|
||||
return (
|
||||
<>
|
||||
<Sidebar />
|
||||
<main
|
||||
className="drawer-content grow flex flex-col h-screen mx-auto px-4 overflow-auto bg-base-100"
|
||||
id="main-scroll"
|
||||
>
|
||||
<Header />
|
||||
<Outlet />
|
||||
</main>
|
||||
{
|
||||
<SettingDialog
|
||||
show={showSettings}
|
||||
onClose={() => setShowSettings(false)}
|
||||
/>
|
||||
}
|
||||
<Toaster />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
import daisyuiThemes from 'daisyui/theme/object';
|
||||
import { isNumeric } from './utils/misc';
|
||||
|
||||
export const isDev = import.meta.env.MODE === 'development';
|
||||
|
||||
// constants
|
||||
export const BASE_URL = new URL('.', document.baseURI).href
|
||||
.toString()
|
||||
.replace(/\/$/, '');
|
||||
|
||||
export const CONFIG_DEFAULT = {
|
||||
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
|
||||
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
|
||||
apiKey: '',
|
||||
systemMessage: '',
|
||||
showTokensPerSecond: false,
|
||||
showThoughtInProgress: false,
|
||||
excludeThoughtOnReq: true,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
pdfAsImage: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'edkypmxt',
|
||||
temperature: 0.8,
|
||||
dynatemp_range: 0.0,
|
||||
dynatemp_exponent: 1.0,
|
||||
top_k: 40,
|
||||
top_p: 0.95,
|
||||
min_p: 0.05,
|
||||
xtc_probability: 0.0,
|
||||
xtc_threshold: 0.1,
|
||||
typical_p: 1.0,
|
||||
repeat_last_n: 64,
|
||||
repeat_penalty: 1.0,
|
||||
presence_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
dry_multiplier: 0.0,
|
||||
dry_base: 1.75,
|
||||
dry_allowed_length: 2,
|
||||
dry_penalty_last_n: -1,
|
||||
max_tokens: -1,
|
||||
custom: '', // custom json-stringified object
|
||||
// experimental features
|
||||
pyIntepreterEnabled: false,
|
||||
};
|
||||
export const CONFIG_INFO: Record<string, string> = {
|
||||
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
||||
systemMessage: 'The starting message that defines how model should behave.',
|
||||
pasteLongTextToFileLen:
|
||||
'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.',
|
||||
samplers:
|
||||
'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
|
||||
temperature:
|
||||
'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
|
||||
dynatemp_range:
|
||||
'Addon for the temperature sampler. The added value to the range of dynamic temperature, which adjusts probabilities by entropy of tokens.',
|
||||
dynatemp_exponent:
|
||||
'Addon for the temperature sampler. Smoothes out the probability redistribution based on the most probable token.',
|
||||
top_k: 'Keeps only k top tokens.',
|
||||
top_p:
|
||||
'Limits tokens to those that together have a cumulative probability of at least p',
|
||||
min_p:
|
||||
'Limits tokens based on the minimum probability for a token to be considered, relative to the probability of the most likely token.',
|
||||
xtc_probability:
|
||||
'XTC sampler cuts out top tokens; this parameter controls the chance of cutting tokens at all. 0 disables XTC.',
|
||||
xtc_threshold:
|
||||
'XTC sampler cuts out top tokens; this parameter controls the token probability that is required to cut that token.',
|
||||
typical_p:
|
||||
'Sorts and limits tokens based on the difference between log-probability and entropy.',
|
||||
repeat_last_n: 'Last n tokens to consider for penalizing repetition',
|
||||
repeat_penalty:
|
||||
'Controls the repetition of token sequences in the generated text',
|
||||
presence_penalty:
|
||||
'Limits tokens based on whether they appear in the output or not.',
|
||||
frequency_penalty:
|
||||
'Limits tokens based on how often they appear in the output.',
|
||||
dry_multiplier:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling multiplier.',
|
||||
dry_base:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling base value.',
|
||||
dry_allowed_length:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the allowed length for DRY sampling.',
|
||||
dry_penalty_last_n:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.',
|
||||
max_tokens: 'The maximum number of token per output.',
|
||||
custom: '', // custom json-stringified object
|
||||
};
|
||||
// config keys having numeric value (i.e. temperature, top_k, top_p, etc)
|
||||
export const CONFIG_NUMERIC_KEYS = Object.entries(CONFIG_DEFAULT)
|
||||
.filter((e) => isNumeric(e[1]))
|
||||
.map((e) => e[0]);
|
||||
// list of themes supported by daisyui
|
||||
export const THEMES = ['light', 'dark']
|
||||
// make sure light & dark are always at the beginning
|
||||
.concat(
|
||||
Object.keys(daisyuiThemes).filter((t) => t !== 'light' && t !== 'dark')
|
||||
);
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
@import 'tailwindcss';
|
||||
|
||||
@import 'tw-animate-css';
|
||||
|
||||
@custom-variant dark (&:is(.dark *));
|
||||
|
||||
:root {
|
||||
--radius: 0.625rem;
|
||||
--background: oklch(1 0 0);
|
||||
--foreground: oklch(0.145 0 0);
|
||||
--card: oklch(1 0 0);
|
||||
--card-foreground: oklch(0.145 0 0);
|
||||
--popover: oklch(1 0 0);
|
||||
--popover-foreground: oklch(0.145 0 0);
|
||||
--primary: oklch(0.205 0 0);
|
||||
--primary-foreground: oklch(0.985 0 0);
|
||||
--secondary: oklch(0.97 0 0);
|
||||
--secondary-foreground: oklch(0.205 0 0);
|
||||
--muted: oklch(0.97 0 0);
|
||||
--muted-foreground: oklch(0.556 0 0);
|
||||
--accent: oklch(0.97 0 0);
|
||||
--accent-foreground: oklch(0.205 0 0);
|
||||
--destructive: oklch(0.577 0.245 27.325);
|
||||
--border: oklch(0.875 0 0);
|
||||
--input: oklch(0.92 0 0);
|
||||
--ring: oklch(0.708 0 0);
|
||||
--chart-1: oklch(0.646 0.222 41.116);
|
||||
--chart-2: oklch(0.6 0.118 184.704);
|
||||
--chart-3: oklch(0.398 0.07 227.392);
|
||||
--chart-4: oklch(0.828 0.189 84.429);
|
||||
--chart-5: oklch(0.769 0.188 70.08);
|
||||
--sidebar: oklch(0.985 0 0);
|
||||
--sidebar-foreground: oklch(0.145 0 0);
|
||||
--sidebar-primary: oklch(0.205 0 0);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
--sidebar-accent: oklch(0.97 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.205 0 0);
|
||||
--sidebar-border: oklch(0.922 0 0);
|
||||
--sidebar-ring: oklch(0.708 0 0);
|
||||
--code-background: oklch(0.225 0 0);
|
||||
--code-foreground: oklch(0.875 0 0);
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background: oklch(0.16 0 0);
|
||||
--foreground: oklch(0.985 0 0);
|
||||
--card: oklch(0.205 0 0);
|
||||
--card-foreground: oklch(0.985 0 0);
|
||||
--popover: oklch(0.205 0 0);
|
||||
--popover-foreground: oklch(0.985 0 0);
|
||||
--primary: oklch(0.922 0 0);
|
||||
--primary-foreground: oklch(0.205 0 0);
|
||||
--secondary: oklch(0.269 0 0);
|
||||
--secondary-foreground: oklch(0.985 0 0);
|
||||
--muted: oklch(0.269 0 0);
|
||||
--muted-foreground: oklch(0.708 0 0);
|
||||
--accent: oklch(0.269 0 0);
|
||||
--accent-foreground: oklch(0.985 0 0);
|
||||
--destructive: oklch(0.704 0.191 22.216);
|
||||
--border: oklch(1 0 0 / 30%);
|
||||
--input: oklch(1 0 0 / 30%);
|
||||
--ring: oklch(0.556 0 0);
|
||||
--chart-1: oklch(0.488 0.243 264.376);
|
||||
--chart-2: oklch(0.696 0.17 162.48);
|
||||
--chart-3: oklch(0.769 0.188 70.08);
|
||||
--chart-4: oklch(0.627 0.265 303.9);
|
||||
--chart-5: oklch(0.645 0.246 16.439);
|
||||
--sidebar: oklch(0.205 0 0);
|
||||
--sidebar-foreground: oklch(0.985 0 0);
|
||||
--sidebar-primary: oklch(0.488 0.243 264.376);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
--sidebar-accent: oklch(0.269 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.985 0 0);
|
||||
--sidebar-border: oklch(1 0 0 / 10%);
|
||||
--sidebar-ring: oklch(0.556 0 0);
|
||||
}
|
||||
|
||||
@theme inline {
|
||||
--radius-sm: calc(var(--radius) - 4px);
|
||||
--radius-md: calc(var(--radius) - 2px);
|
||||
--radius-lg: var(--radius);
|
||||
--radius-xl: calc(var(--radius) + 4px);
|
||||
--color-background: var(--background);
|
||||
--color-foreground: var(--foreground);
|
||||
--color-card: var(--card);
|
||||
--color-card-foreground: var(--card-foreground);
|
||||
--color-popover: var(--popover);
|
||||
--color-popover-foreground: var(--popover-foreground);
|
||||
--color-primary: var(--primary);
|
||||
--color-primary-foreground: var(--primary-foreground);
|
||||
--color-secondary: var(--secondary);
|
||||
--color-secondary-foreground: var(--secondary-foreground);
|
||||
--color-muted: var(--muted);
|
||||
--color-muted-foreground: var(--muted-foreground);
|
||||
--color-accent: var(--accent);
|
||||
--color-accent-foreground: var(--accent-foreground);
|
||||
--color-destructive: var(--destructive);
|
||||
--color-border: var(--border);
|
||||
--color-input: var(--input);
|
||||
--color-ring: var(--ring);
|
||||
--color-chart-1: var(--chart-1);
|
||||
--color-chart-2: var(--chart-2);
|
||||
--color-chart-3: var(--chart-3);
|
||||
--color-chart-4: var(--chart-4);
|
||||
--color-chart-5: var(--chart-5);
|
||||
--color-sidebar: var(--sidebar);
|
||||
--color-sidebar-foreground: var(--sidebar-foreground);
|
||||
--color-sidebar-primary: var(--sidebar-primary);
|
||||
--color-sidebar-primary-foreground: var(--sidebar-primary-foreground);
|
||||
--color-sidebar-accent: var(--sidebar-accent);
|
||||
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground);
|
||||
--color-sidebar-border: var(--sidebar-border);
|
||||
--color-sidebar-ring: var(--sidebar-ring);
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border outline-ring/50;
|
||||
}
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
// See https://svelte.dev/docs/kit/types#app.d.ts
|
||||
// for information about these interfaces
|
||||
|
||||
// Import chat types from dedicated module
|
||||
|
||||
import type {
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionResponse,
|
||||
ApiChatCompletionStreamChunk,
|
||||
ApiChatMessageData,
|
||||
ApiChatMessageContentPart,
|
||||
ApiContextSizeError,
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiProcessingState
|
||||
} from '$lib/types/api';
|
||||
|
||||
import type {
|
||||
ChatMessageType,
|
||||
ChatRole,
|
||||
ChatUploadedFile,
|
||||
ChatMessageSiblingInfo,
|
||||
ChatMessagePromptProgress,
|
||||
ChatMessageTimings
|
||||
} from '$lib/types/chat';
|
||||
|
||||
import type {
|
||||
DatabaseConversation,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile
|
||||
} from '$lib/types/database';
|
||||
|
||||
import type {
|
||||
SettingsConfigValue,
|
||||
SettingsFieldConfig,
|
||||
SettingsConfigType
|
||||
} from '$lib/types/settings';
|
||||
|
||||
declare global {
|
||||
// namespace App {
|
||||
// interface Error {}
|
||||
// interface Locals {}
|
||||
// interface PageData {}
|
||||
// interface PageState {}
|
||||
// interface Platform {}
|
||||
// }
|
||||
|
||||
export {
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionResponse,
|
||||
ApiChatCompletionStreamChunk,
|
||||
ApiChatMessageData,
|
||||
ApiChatMessageContentPart,
|
||||
ApiContextSizeError,
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiProcessingState,
|
||||
ChatMessageData,
|
||||
ChatMessagePromptProgress,
|
||||
ChatMessageSiblingInfo,
|
||||
ChatMessageTimings,
|
||||
ChatMessageType,
|
||||
ChatRole,
|
||||
ChatUploadedFile,
|
||||
DatabaseConversation,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile,
|
||||
SettingsConfigValue,
|
||||
SettingsFieldConfig,
|
||||
SettingsConfigType,
|
||||
SettingsChatServiceOptions
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="%sveltekit.assets%/favicon.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
%sveltekit.head%
|
||||
</head>
|
||||
<body data-sveltekit-preload-data="hover">
|
||||
<div style="display: contents">%sveltekit.body%</div>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
import { useEffect, useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { OpenInNewTab, XCloseButton } from '../utils/common';
|
||||
import { CanvasType } from '../utils/types';
|
||||
import { PlayIcon, StopIcon } from '@heroicons/react/24/outline';
|
||||
import StorageUtils from '../utils/storage';
|
||||
|
||||
const canInterrupt = typeof SharedArrayBuffer === 'function';
|
||||
|
||||
// adapted from https://pyodide.org/en/stable/usage/webworker.html
|
||||
const WORKER_CODE = `
|
||||
importScripts("https://cdn.jsdelivr.net/pyodide/v0.27.2/full/pyodide.js");
|
||||
|
||||
let stdOutAndErr = [];
|
||||
|
||||
let pyodideReadyPromise = loadPyodide({
|
||||
stdout: (data) => stdOutAndErr.push(data),
|
||||
stderr: (data) => stdOutAndErr.push(data),
|
||||
});
|
||||
|
||||
let alreadySetBuff = false;
|
||||
|
||||
self.onmessage = async (event) => {
|
||||
stdOutAndErr = [];
|
||||
|
||||
// make sure loading is done
|
||||
const pyodide = await pyodideReadyPromise;
|
||||
const { id, python, context, interruptBuffer } = event.data;
|
||||
|
||||
if (interruptBuffer && !alreadySetBuff) {
|
||||
pyodide.setInterruptBuffer(interruptBuffer);
|
||||
alreadySetBuff = true;
|
||||
}
|
||||
|
||||
// Now load any packages we need, run the code, and send the result back.
|
||||
await pyodide.loadPackagesFromImports(python);
|
||||
|
||||
// make a Python dictionary with the data from content
|
||||
const dict = pyodide.globals.get("dict");
|
||||
const globals = dict(Object.entries(context));
|
||||
try {
|
||||
self.postMessage({ id, running: true });
|
||||
// Execute the python code in this context
|
||||
const result = pyodide.runPython(python, { globals });
|
||||
self.postMessage({ result, id, stdOutAndErr });
|
||||
} catch (error) {
|
||||
self.postMessage({ error: error.message, id });
|
||||
}
|
||||
interruptBuffer[0] = 0;
|
||||
};
|
||||
`;
|
||||
|
||||
let worker: Worker;
|
||||
const interruptBuffer = canInterrupt
|
||||
? new Uint8Array(new SharedArrayBuffer(1))
|
||||
: null;
|
||||
|
||||
const startWorker = () => {
|
||||
if (!worker) {
|
||||
worker = new Worker(
|
||||
URL.createObjectURL(new Blob([WORKER_CODE], { type: 'text/javascript' }))
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
if (StorageUtils.getConfig().pyIntepreterEnabled) {
|
||||
startWorker();
|
||||
}
|
||||
|
||||
const runCodeInWorker = (
|
||||
pyCode: string,
|
||||
callbackRunning: () => void
|
||||
): {
|
||||
donePromise: Promise<string>;
|
||||
interrupt: () => void;
|
||||
} => {
|
||||
startWorker();
|
||||
const id = Math.random() * 1e8;
|
||||
const context = {};
|
||||
if (interruptBuffer) {
|
||||
interruptBuffer[0] = 0;
|
||||
}
|
||||
|
||||
const donePromise = new Promise<string>((resolve) => {
|
||||
worker.onmessage = (event) => {
|
||||
const { error, stdOutAndErr, running } = event.data;
|
||||
if (id !== event.data.id) return;
|
||||
if (running) {
|
||||
callbackRunning();
|
||||
return;
|
||||
} else if (error) {
|
||||
resolve(error.toString());
|
||||
} else {
|
||||
resolve(stdOutAndErr.join('\n'));
|
||||
}
|
||||
};
|
||||
worker.postMessage({ id, python: pyCode, context, interruptBuffer });
|
||||
});
|
||||
|
||||
const interrupt = () => {
|
||||
console.log('Interrupting...');
|
||||
console.trace();
|
||||
if (interruptBuffer) {
|
||||
interruptBuffer[0] = 2;
|
||||
}
|
||||
};
|
||||
|
||||
return { donePromise, interrupt };
|
||||
};
|
||||
|
||||
export default function CanvasPyInterpreter() {
|
||||
const { canvasData, setCanvasData } = useAppContext();
|
||||
|
||||
const [code, setCode] = useState(canvasData?.content ?? ''); // copy to avoid direct mutation
|
||||
const [running, setRunning] = useState(false);
|
||||
const [output, setOutput] = useState('');
|
||||
const [interruptFn, setInterruptFn] = useState<() => void>();
|
||||
const [showStopBtn, setShowStopBtn] = useState(false);
|
||||
|
||||
const runCode = async (pycode: string) => {
|
||||
interruptFn?.();
|
||||
setRunning(true);
|
||||
setOutput('Loading Pyodide...');
|
||||
const { donePromise, interrupt } = runCodeInWorker(pycode, () => {
|
||||
setOutput('Running...');
|
||||
setShowStopBtn(canInterrupt);
|
||||
});
|
||||
setInterruptFn(() => interrupt);
|
||||
const out = await donePromise;
|
||||
setOutput(out);
|
||||
setRunning(false);
|
||||
setShowStopBtn(false);
|
||||
};
|
||||
|
||||
// run code on mount
|
||||
useEffect(() => {
|
||||
setCode(canvasData?.content ?? '');
|
||||
runCode(canvasData?.content ?? '');
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [canvasData?.content]);
|
||||
|
||||
if (canvasData?.type !== CanvasType.PY_INTERPRETER) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="card bg-base-200 w-full h-full shadow-xl">
|
||||
<div className="card-body">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<span className="text-lg font-bold">Python Interpreter</span>
|
||||
<XCloseButton
|
||||
className="bg-base-100"
|
||||
onClick={() => setCanvasData(null)}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid grid-rows-3 gap-4 h-full">
|
||||
<textarea
|
||||
className="textarea textarea-bordered w-full h-full font-mono"
|
||||
value={code}
|
||||
onChange={(e) => setCode(e.target.value)}
|
||||
></textarea>
|
||||
<div className="font-mono flex flex-col row-span-2">
|
||||
<div className="flex items-center mb-2">
|
||||
<button
|
||||
className="btn btn-sm bg-base-100"
|
||||
onClick={() => runCode(code)}
|
||||
disabled={running}
|
||||
>
|
||||
<PlayIcon className="h-6 w-6" /> Run
|
||||
</button>
|
||||
{showStopBtn && (
|
||||
<button
|
||||
className="btn btn-sm bg-base-100 ml-2"
|
||||
onClick={() => interruptFn?.()}
|
||||
>
|
||||
<StopIcon className="h-6 w-6" /> Stop
|
||||
</button>
|
||||
)}
|
||||
<span className="grow text-right text-xs">
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/11762">
|
||||
Report a bug
|
||||
</OpenInNewTab>
|
||||
</span>
|
||||
</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-full dark-color"
|
||||
value={output}
|
||||
readOnly
|
||||
></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,135 +0,0 @@
|
|||
import {
|
||||
DocumentTextIcon,
|
||||
SpeakerWaveIcon,
|
||||
XMarkIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { MessageExtra } from '../utils/types';
|
||||
import { useState } from 'react';
|
||||
import { classNames } from '../utils/misc';
|
||||
|
||||
export default function ChatInputExtraContextItem({
|
||||
items,
|
||||
removeItem,
|
||||
clickToShow,
|
||||
}: {
|
||||
items?: MessageExtra[];
|
||||
removeItem?: (index: number) => void;
|
||||
clickToShow?: boolean;
|
||||
}) {
|
||||
const [show, setShow] = useState(-1);
|
||||
const showingItem = show >= 0 ? items?.[show] : undefined;
|
||||
|
||||
if (!items) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-row gap-4 overflow-x-auto py-2 px-1 mb-1"
|
||||
role="group"
|
||||
aria-description="Selected files"
|
||||
>
|
||||
{items.map((item, i) => (
|
||||
<div
|
||||
className="indicator"
|
||||
key={i}
|
||||
onClick={() => clickToShow && setShow(i)}
|
||||
tabIndex={0}
|
||||
aria-description={
|
||||
clickToShow ? `Click to show: ${item.name}` : undefined
|
||||
}
|
||||
role={clickToShow ? 'button' : 'menuitem'}
|
||||
>
|
||||
{removeItem && (
|
||||
<div className="indicator-item indicator-top">
|
||||
<button
|
||||
aria-label="Remove file"
|
||||
className="btn btn-neutral btn-sm w-4 h-4 p-0 rounded-full"
|
||||
onClick={() => removeItem(i)}
|
||||
>
|
||||
<XMarkIcon className="h-3 w-3" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className={classNames({
|
||||
'flex flex-row rounded-md shadow-sm items-center m-0 p-0': true,
|
||||
'cursor-pointer hover:shadow-md': !!clickToShow,
|
||||
})}
|
||||
>
|
||||
{item.type === 'imageFile' ? (
|
||||
<>
|
||||
<img
|
||||
src={item.base64Url}
|
||||
alt={`Preview image for ${item.name}`}
|
||||
className="w-14 h-14 object-cover rounded-md"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div
|
||||
className="w-14 h-14 flex items-center justify-center"
|
||||
aria-description="Document icon"
|
||||
>
|
||||
{item.type === 'audioFile' ? (
|
||||
<SpeakerWaveIcon className="h-8 w-8 text-gray-500" />
|
||||
) : (
|
||||
<DocumentTextIcon className="h-8 w-8 text-gray-500" />
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="text-xs pr-4">
|
||||
<b>{item.name ?? 'Extra content'}</b>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{showingItem && (
|
||||
<dialog
|
||||
className="modal modal-open"
|
||||
aria-description={`Preview ${showingItem.name}`}
|
||||
>
|
||||
<div className="modal-box">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<b>{showingItem.name ?? 'Extra content'}</b>
|
||||
<button
|
||||
className="btn btn-ghost btn-sm"
|
||||
aria-label="Close preview dialog"
|
||||
>
|
||||
<XMarkIcon className="h-5 w-5" onClick={() => setShow(-1)} />
|
||||
</button>
|
||||
</div>
|
||||
{showingItem.type === 'imageFile' ? (
|
||||
<img
|
||||
src={showingItem.base64Url}
|
||||
alt={`Preview image for ${showingItem.name}`}
|
||||
/>
|
||||
) : showingItem.type === 'audioFile' ? (
|
||||
<audio
|
||||
controls
|
||||
className="w-full"
|
||||
aria-description={`Audio file ${showingItem.name}`}
|
||||
>
|
||||
<source
|
||||
src={`data:${showingItem.mimeType};base64,${showingItem.base64Data}`}
|
||||
type={showingItem.mimeType}
|
||||
aria-description={`Audio file ${showingItem.name}`}
|
||||
/>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<pre className="whitespace-pre-wrap break-words text-sm">
|
||||
{showingItem.content}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="modal-backdrop" onClick={() => setShow(-1)}></div>
|
||||
</dialog>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,320 +0,0 @@
|
|||
import { useMemo, useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { Message, PendingMessage } from '../utils/types';
|
||||
import { classNames } from '../utils/misc';
|
||||
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
|
||||
import {
|
||||
ArrowPathIcon,
|
||||
ChevronLeftIcon,
|
||||
ChevronRightIcon,
|
||||
PencilSquareIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import ChatInputExtraContextItem from './ChatInputExtraContextItem';
|
||||
import { BtnWithTooltips } from '../utils/common';
|
||||
|
||||
interface SplitMessage {
|
||||
content: PendingMessage['content'];
|
||||
thought?: string;
|
||||
isThinking?: boolean;
|
||||
}
|
||||
|
||||
export default function ChatMessage({
|
||||
msg,
|
||||
siblingLeafNodeIds,
|
||||
siblingCurrIdx,
|
||||
id,
|
||||
onRegenerateMessage,
|
||||
onEditMessage,
|
||||
onChangeSibling,
|
||||
isPending,
|
||||
}: {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
id?: string;
|
||||
onRegenerateMessage(msg: Message): void;
|
||||
onEditMessage(msg: Message, content: string): void;
|
||||
onChangeSibling(sibling: Message['id']): void;
|
||||
isPending?: boolean;
|
||||
}) {
|
||||
const { viewingChat, config } = useAppContext();
|
||||
const [editingContent, setEditingContent] = useState<string | null>(null);
|
||||
const timings = useMemo(
|
||||
() =>
|
||||
msg.timings
|
||||
? {
|
||||
...msg.timings,
|
||||
prompt_per_second:
|
||||
(msg.timings.prompt_n / msg.timings.prompt_ms) * 1000,
|
||||
predicted_per_second:
|
||||
(msg.timings.predicted_n / msg.timings.predicted_ms) * 1000,
|
||||
}
|
||||
: null,
|
||||
[msg.timings]
|
||||
);
|
||||
const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
|
||||
const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
|
||||
|
||||
// for reasoning model, we split the message into content and thought
|
||||
// TODO: implement this as remark/rehype plugin in the future
|
||||
const { content, thought, isThinking }: SplitMessage = useMemo(() => {
|
||||
if (msg.content === null || msg.role !== 'assistant') {
|
||||
return { content: msg.content };
|
||||
}
|
||||
const REGEX_THINK_OPEN = /<think>|<\|channel\|>analysis<\|message\|>/;
|
||||
const REGEX_THINK_CLOSE = /<\/think>|<\|end\|>/;
|
||||
let actualContent = '';
|
||||
let thought = '';
|
||||
let isThinking = false;
|
||||
let thinkSplit = msg.content.split(REGEX_THINK_OPEN, 2);
|
||||
actualContent += thinkSplit[0];
|
||||
while (thinkSplit[1] !== undefined) {
|
||||
// <think> tag found
|
||||
thinkSplit = thinkSplit[1].split(REGEX_THINK_CLOSE, 2);
|
||||
thought += thinkSplit[0];
|
||||
isThinking = true;
|
||||
if (thinkSplit[1] !== undefined) {
|
||||
// </think> closing tag found
|
||||
isThinking = false;
|
||||
thinkSplit = thinkSplit[1].split(REGEX_THINK_OPEN, 2);
|
||||
actualContent += thinkSplit[0];
|
||||
}
|
||||
}
|
||||
return { content: actualContent, thought, isThinking };
|
||||
}, [msg]);
|
||||
|
||||
if (!viewingChat) return null;
|
||||
|
||||
const isUser = msg.role === 'user';
|
||||
|
||||
return (
|
||||
<div
|
||||
className="group"
|
||||
id={id}
|
||||
role="group"
|
||||
aria-description={`Message from ${msg.role}`}
|
||||
>
|
||||
<div
|
||||
className={classNames({
|
||||
chat: true,
|
||||
'chat-start': !isUser,
|
||||
'chat-end': isUser,
|
||||
})}
|
||||
>
|
||||
{msg.extra && msg.extra.length > 0 && (
|
||||
<ChatInputExtraContextItem items={msg.extra} clickToShow />
|
||||
)}
|
||||
|
||||
<div
|
||||
className={classNames({
|
||||
'chat-bubble markdown': true,
|
||||
'chat-bubble bg-transparent': !isUser,
|
||||
})}
|
||||
>
|
||||
{/* textarea for editing message */}
|
||||
{editingContent !== null && (
|
||||
<>
|
||||
<textarea
|
||||
dir="auto"
|
||||
className="textarea textarea-bordered bg-base-100 text-base-content max-w-2xl w-[calc(90vw-8em)] h-24"
|
||||
value={editingContent}
|
||||
onChange={(e) => setEditingContent(e.target.value)}
|
||||
></textarea>
|
||||
<br />
|
||||
<button
|
||||
className="btn btn-ghost mt-2 mr-2"
|
||||
onClick={() => setEditingContent(null)}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn mt-2"
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
setEditingContent(null);
|
||||
onEditMessage(msg as Message, editingContent);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
{/* not editing content, render message */}
|
||||
{editingContent === null && (
|
||||
<>
|
||||
{content === null ? (
|
||||
<>
|
||||
{/* show loading dots for pending message */}
|
||||
<span className="loading loading-dots loading-md"></span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
{/* render message as markdown */}
|
||||
<div dir="auto" tabIndex={0}>
|
||||
{thought && (
|
||||
<ThoughtProcess
|
||||
isThinking={!!isThinking && !!isPending}
|
||||
content={thought}
|
||||
open={config.showThoughtInProgress}
|
||||
/>
|
||||
)}
|
||||
|
||||
<MarkdownDisplay
|
||||
content={content}
|
||||
isGenerating={isPending}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{/* render timings if enabled */}
|
||||
{timings && config.showTokensPerSecond && (
|
||||
<div className="dropdown dropdown-hover dropdown-top mt-2">
|
||||
<div
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
className="cursor-pointer font-semibold text-sm opacity-60"
|
||||
>
|
||||
Speed: {timings.predicted_per_second.toFixed(1)} t/s
|
||||
</div>
|
||||
<div className="dropdown-content bg-base-100 z-10 w-64 p-2 shadow mt-4">
|
||||
<b>Prompt</b>
|
||||
<br />- Tokens: {timings.prompt_n}
|
||||
<br />- Time: {timings.prompt_ms} ms
|
||||
<br />- Speed: {timings.prompt_per_second.toFixed(1)} t/s
|
||||
<br />
|
||||
<b>Generation</b>
|
||||
<br />- Tokens: {timings.predicted_n}
|
||||
<br />- Time: {timings.predicted_ms} ms
|
||||
<br />- Speed: {timings.predicted_per_second.toFixed(1)} t/s
|
||||
<br />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* actions for each message */}
|
||||
{msg.content !== null && (
|
||||
<div
|
||||
className={classNames({
|
||||
'flex items-center gap-2 mx-4 mt-2 mb-2': true,
|
||||
'flex-row-reverse': msg.role === 'user',
|
||||
})}
|
||||
>
|
||||
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
|
||||
<div
|
||||
className="flex gap-1 items-center opacity-60 text-sm"
|
||||
role="navigation"
|
||||
aria-description={`Message version ${siblingCurrIdx + 1} of ${siblingLeafNodeIds.length}`}
|
||||
>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !prevSibling,
|
||||
})}
|
||||
onClick={() => prevSibling && onChangeSibling(prevSibling)}
|
||||
aria-label="Previous message version"
|
||||
>
|
||||
<ChevronLeftIcon className="h-4 w-4" />
|
||||
</button>
|
||||
<span>
|
||||
{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
|
||||
</span>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !nextSibling,
|
||||
})}
|
||||
onClick={() => nextSibling && onChangeSibling(nextSibling)}
|
||||
aria-label="Next message version"
|
||||
>
|
||||
<ChevronRightIcon className="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* user message */}
|
||||
{msg.role === 'user' && (
|
||||
<BtnWithTooltips
|
||||
className="btn-mini w-8 h-8"
|
||||
onClick={() => setEditingContent(msg.content)}
|
||||
disabled={msg.content === null}
|
||||
tooltipsContent="Edit message"
|
||||
>
|
||||
<PencilSquareIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
)}
|
||||
{/* assistant message */}
|
||||
{msg.role === 'assistant' && (
|
||||
<>
|
||||
{!isPending && (
|
||||
<BtnWithTooltips
|
||||
className="btn-mini w-8 h-8"
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
onRegenerateMessage(msg as Message);
|
||||
}
|
||||
}}
|
||||
disabled={msg.content === null}
|
||||
tooltipsContent="Regenerate response"
|
||||
>
|
||||
<ArrowPathIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<CopyButton className="btn-mini w-8 h-8" content={msg.content} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ThoughtProcess({
|
||||
isThinking,
|
||||
content,
|
||||
open,
|
||||
}: {
|
||||
isThinking: boolean;
|
||||
content: string;
|
||||
open: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="button"
|
||||
aria-label="Toggle thought process display"
|
||||
tabIndex={0}
|
||||
className={classNames({
|
||||
'collapse bg-none': true,
|
||||
})}
|
||||
>
|
||||
<input type="checkbox" defaultChecked={open} />
|
||||
<div className="collapse-title px-0">
|
||||
<div className="btn rounded-xl">
|
||||
{isThinking ? (
|
||||
<span>
|
||||
<span
|
||||
className="loading loading-spinner loading-md mr-2"
|
||||
style={{ verticalAlign: 'middle' }}
|
||||
></span>
|
||||
Thinking
|
||||
</span>
|
||||
) : (
|
||||
<>Thought Process</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className="collapse-content text-base-content/70 text-sm p-1"
|
||||
tabIndex={0}
|
||||
aria-description="Thought process content"
|
||||
>
|
||||
<div className="border-l-2 border-base-content/20 pl-4 mb-4">
|
||||
<MarkdownDisplay content={content} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,459 +0,0 @@
|
|||
import { ClipboardEvent, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||
import { classNames, cleanCurrentUrl } from '../utils/misc';
|
||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useVSCodeContext } from '../utils/llama-vscode';
|
||||
import { useChatTextarea, ChatTextareaApi } from './useChatTextarea.ts';
|
||||
import {
|
||||
ArrowUpIcon,
|
||||
StopIcon,
|
||||
PaperClipIcon,
|
||||
} from '@heroicons/react/24/solid';
|
||||
import {
|
||||
ChatExtraContextApi,
|
||||
useChatExtraContext,
|
||||
} from './useChatExtraContext.tsx';
|
||||
import Dropzone from 'react-dropzone';
|
||||
import toast from 'react-hot-toast';
|
||||
import ChatInputExtraContextItem from './ChatInputExtraContextItem.tsx';
|
||||
import { scrollToBottom, useChatScroll } from './useChatScroll.tsx';
|
||||
|
||||
/**
|
||||
* A message display is a message node with additional information for rendering.
|
||||
* For example, siblings of the message node are stored as their last node (aka leaf node).
|
||||
*/
|
||||
export interface MessageDisplay {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
isPending?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* If the current URL contains "?m=...", prefill the message input with the value.
|
||||
* If the current URL contains "?q=...", prefill and SEND the message.
|
||||
*/
|
||||
const prefilledMsg = {
|
||||
content() {
|
||||
const url = new URL(window.location.href);
|
||||
return url.searchParams.get('m') ?? url.searchParams.get('q') ?? '';
|
||||
},
|
||||
shouldSend() {
|
||||
const url = new URL(window.location.href);
|
||||
return url.searchParams.has('q');
|
||||
},
|
||||
clear() {
|
||||
cleanCurrentUrl(['m', 'q']);
|
||||
},
|
||||
};
|
||||
|
||||
function getListMessageDisplay(
|
||||
msgs: Readonly<Message[]>,
|
||||
leafNodeId: Message['id']
|
||||
): MessageDisplay[] {
|
||||
const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
|
||||
const res: MessageDisplay[] = [];
|
||||
const nodeMap = new Map<Message['id'], Message>();
|
||||
for (const msg of msgs) {
|
||||
nodeMap.set(msg.id, msg);
|
||||
}
|
||||
// find leaf node from a message node
|
||||
const findLeafNode = (msgId: Message['id']): Message['id'] => {
|
||||
let currNode: Message | undefined = nodeMap.get(msgId);
|
||||
while (currNode) {
|
||||
if (currNode.children.length === 0) break;
|
||||
currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
|
||||
}
|
||||
return currNode?.id ?? -1;
|
||||
};
|
||||
// traverse the current nodes
|
||||
for (const msg of currNodes) {
|
||||
const parentNode = nodeMap.get(msg.parent ?? -1);
|
||||
if (!parentNode) continue;
|
||||
const siblings = parentNode.children;
|
||||
if (msg.type !== 'root') {
|
||||
res.push({
|
||||
msg,
|
||||
siblingLeafNodeIds: siblings.map(findLeafNode),
|
||||
siblingCurrIdx: siblings.indexOf(msg.id),
|
||||
});
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
export default function ChatScreen() {
|
||||
const {
|
||||
viewingChat,
|
||||
sendMessage,
|
||||
isGenerating,
|
||||
stopGenerating,
|
||||
pendingMessages,
|
||||
canvasData,
|
||||
replaceMessageAndGenerate,
|
||||
} = useAppContext();
|
||||
|
||||
const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content());
|
||||
const extraContext = useChatExtraContext();
|
||||
useVSCodeContext(textarea, extraContext);
|
||||
|
||||
const msgListRef = useRef<HTMLDivElement>(null);
|
||||
useChatScroll(msgListRef);
|
||||
|
||||
// keep track of leaf node for rendering
|
||||
const [currNodeId, setCurrNodeId] = useState<number>(-1);
|
||||
const messages: MessageDisplay[] = useMemo(() => {
|
||||
if (!viewingChat) return [];
|
||||
else return getListMessageDisplay(viewingChat.messages, currNodeId);
|
||||
}, [currNodeId, viewingChat]);
|
||||
|
||||
const currConvId = viewingChat?.conv.id ?? null;
|
||||
const pendingMsg: PendingMessage | undefined =
|
||||
pendingMessages[currConvId ?? ''];
|
||||
|
||||
useEffect(() => {
|
||||
// reset to latest node when conversation changes
|
||||
setCurrNodeId(-1);
|
||||
// scroll to bottom when conversation changes
|
||||
scrollToBottom(false, 1);
|
||||
}, [currConvId]);
|
||||
|
||||
const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
|
||||
if (currLeafNodeId) {
|
||||
setCurrNodeId(currLeafNodeId);
|
||||
}
|
||||
// useChatScroll will handle the auto scroll
|
||||
};
|
||||
|
||||
const sendNewMessage = async () => {
|
||||
const lastInpMsg = textarea.value();
|
||||
if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? '')) {
|
||||
toast.error('Please enter a message');
|
||||
return;
|
||||
}
|
||||
textarea.setValue('');
|
||||
scrollToBottom(false);
|
||||
setCurrNodeId(-1);
|
||||
// get the last message node
|
||||
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
|
||||
if (
|
||||
!(await sendMessage(
|
||||
currConvId,
|
||||
lastMsgNodeId,
|
||||
lastInpMsg,
|
||||
extraContext.items,
|
||||
onChunk
|
||||
))
|
||||
) {
|
||||
// restore the input message if failed
|
||||
textarea.setValue(lastInpMsg);
|
||||
}
|
||||
// OK
|
||||
extraContext.clearItems();
|
||||
};
|
||||
|
||||
// for vscode context
|
||||
textarea.refOnSubmit.current = sendNewMessage;
|
||||
|
||||
const handleEditMessage = async (msg: Message, content: string) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.id);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
content,
|
||||
msg.extra,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const handleRegenerateMessage = async (msg: Message) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.parent);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
null,
|
||||
msg.extra,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const hasCanvas = !!canvasData;
|
||||
|
||||
useEffect(() => {
|
||||
if (prefilledMsg.shouldSend()) {
|
||||
// send the prefilled message if needed
|
||||
sendNewMessage();
|
||||
} else {
|
||||
// otherwise, focus on the input
|
||||
textarea.focus();
|
||||
}
|
||||
prefilledMsg.clear();
|
||||
// no need to keep track of sendNewMessage
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [textarea.ref]);
|
||||
|
||||
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
||||
const pendingMsgDisplay: MessageDisplay[] =
|
||||
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
|
||||
? [
|
||||
{
|
||||
msg: pendingMsg,
|
||||
siblingLeafNodeIds: [],
|
||||
siblingCurrIdx: 0,
|
||||
isPending: true,
|
||||
},
|
||||
]
|
||||
: [];
|
||||
|
||||
return (
|
||||
<div
|
||||
className={classNames({
|
||||
'grid lg:gap-8 grow transition-[300ms]': true,
|
||||
'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
|
||||
'grid-cols-[1fr_0fr]': !hasCanvas,
|
||||
})}
|
||||
>
|
||||
<div
|
||||
className={classNames({
|
||||
'flex flex-col w-full max-w-[900px] mx-auto': true,
|
||||
'hidden lg:flex': hasCanvas, // adapted for mobile
|
||||
flex: !hasCanvas,
|
||||
})}
|
||||
>
|
||||
{/* chat messages */}
|
||||
<div id="messages-list" className="grow" ref={msgListRef}>
|
||||
<div className="mt-auto flex flex-col items-center">
|
||||
{/* placeholder to shift the message to the bottom */}
|
||||
{viewingChat ? (
|
||||
''
|
||||
) : (
|
||||
<>
|
||||
<div className="mb-4">Send a message to start</div>
|
||||
<ServerInfo />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
{[...messages, ...pendingMsgDisplay].map((msg) => (
|
||||
<ChatMessage
|
||||
key={msg.msg.id}
|
||||
msg={msg.msg}
|
||||
siblingLeafNodeIds={msg.siblingLeafNodeIds}
|
||||
siblingCurrIdx={msg.siblingCurrIdx}
|
||||
onRegenerateMessage={handleRegenerateMessage}
|
||||
onEditMessage={handleEditMessage}
|
||||
onChangeSibling={setCurrNodeId}
|
||||
isPending={msg.isPending}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* chat input */}
|
||||
<ChatInput
|
||||
textarea={textarea}
|
||||
extraContext={extraContext}
|
||||
onSend={sendNewMessage}
|
||||
onStop={() => stopGenerating(currConvId ?? '')}
|
||||
isGenerating={isGenerating(currConvId ?? '')}
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
|
||||
{canvasData?.type === CanvasType.PY_INTERPRETER && (
|
||||
<CanvasPyInterpreter />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ServerInfo() {
|
||||
const { serverProps } = useAppContext();
|
||||
const modalities = [];
|
||||
if (serverProps?.modalities?.audio) {
|
||||
modalities.push('audio');
|
||||
}
|
||||
if (serverProps?.modalities?.vision) {
|
||||
modalities.push('vision');
|
||||
}
|
||||
return (
|
||||
<div
|
||||
className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
|
||||
tabIndex={0}
|
||||
aria-description="Server information"
|
||||
>
|
||||
<div className="card-body">
|
||||
<b>Server Info</b>
|
||||
<p>
|
||||
<b>Model</b>: {serverProps?.model_path?.split(/(\\|\/)/).pop()}
|
||||
<br />
|
||||
<b>Build</b>: {serverProps?.build_info}
|
||||
<br />
|
||||
{modalities.length > 0 ? (
|
||||
<>
|
||||
<b>Supported modalities:</b> {modalities.join(', ')}
|
||||
</>
|
||||
) : (
|
||||
''
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ChatInput({
|
||||
textarea,
|
||||
extraContext,
|
||||
onSend,
|
||||
onStop,
|
||||
isGenerating,
|
||||
}: {
|
||||
textarea: ChatTextareaApi;
|
||||
extraContext: ChatExtraContextApi;
|
||||
onSend: () => void;
|
||||
onStop: () => void;
|
||||
isGenerating: boolean;
|
||||
}) {
|
||||
const { config } = useAppContext();
|
||||
const [isDrag, setIsDrag] = useState(false);
|
||||
|
||||
return (
|
||||
<div
|
||||
role="group"
|
||||
aria-label="Chat input"
|
||||
className={classNames({
|
||||
'flex items-end pt-8 pb-6 sticky bottom-0 bg-base-100': true,
|
||||
'opacity-50': isDrag, // simply visual feedback to inform user that the file will be accepted
|
||||
})}
|
||||
>
|
||||
<Dropzone
|
||||
noClick
|
||||
onDrop={(files: File[]) => {
|
||||
setIsDrag(false);
|
||||
extraContext.onFileAdded(files);
|
||||
}}
|
||||
onDragEnter={() => setIsDrag(true)}
|
||||
onDragLeave={() => setIsDrag(false)}
|
||||
multiple={true}
|
||||
>
|
||||
{({ getRootProps, getInputProps }) => (
|
||||
<div
|
||||
className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
|
||||
// when a file is pasted to the input, we handle it here
|
||||
// if a text is pasted, and if it is long text, we will convert it to a file
|
||||
onPasteCapture={(e: ClipboardEvent<HTMLInputElement>) => {
|
||||
const text = e.clipboardData.getData('text/plain');
|
||||
if (
|
||||
text.length > 0 &&
|
||||
config.pasteLongTextToFileLen > 0 &&
|
||||
text.length > config.pasteLongTextToFileLen
|
||||
) {
|
||||
// if the text is too long, we will convert it to a file
|
||||
extraContext.addItems([
|
||||
{
|
||||
type: 'context',
|
||||
name: 'Pasted Content',
|
||||
content: text,
|
||||
},
|
||||
]);
|
||||
e.preventDefault();
|
||||
return;
|
||||
}
|
||||
|
||||
// if a file is pasted, we will handle it here
|
||||
const files = Array.from(e.clipboardData.items)
|
||||
.filter((item) => item.kind === 'file')
|
||||
.map((item) => item.getAsFile())
|
||||
.filter((file) => file !== null);
|
||||
|
||||
if (files.length > 0) {
|
||||
e.preventDefault();
|
||||
extraContext.onFileAdded(files);
|
||||
}
|
||||
}}
|
||||
{...getRootProps()}
|
||||
>
|
||||
{!isGenerating && (
|
||||
<ChatInputExtraContextItem
|
||||
items={extraContext.items}
|
||||
removeItem={extraContext.removeItem}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="flex flex-row w-full">
|
||||
<textarea
|
||||
// Default (mobile): Enable vertical resize, overflow auto for scrolling if needed
|
||||
// Large screens (lg:): Disable manual resize, apply max-height for autosize limit
|
||||
className="text-md outline-none border-none w-full resize-vertical lg:resize-none lg:max-h-48 lg:overflow-y-auto" // Adjust lg:max-h-48 as needed (e.g., lg:max-h-60)
|
||||
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||
ref={textarea.ref}
|
||||
onInput={textarea.onInput} // Hook's input handler (will only resize height on lg+ screens)
|
||||
onKeyDown={(e) => {
|
||||
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
onSend();
|
||||
}
|
||||
}}
|
||||
id="msg-input"
|
||||
dir="auto"
|
||||
// Set a base height of 2 rows for mobile views
|
||||
// On lg+ screens, the hook will calculate and set the initial height anyway
|
||||
rows={2}
|
||||
></textarea>
|
||||
|
||||
{/* buttons area */}
|
||||
<div className="flex flex-row gap-2 ml-2">
|
||||
<label
|
||||
htmlFor="file-upload"
|
||||
className={classNames({
|
||||
'btn w-8 h-8 p-0 rounded-full': true,
|
||||
'btn-disabled': isGenerating,
|
||||
})}
|
||||
aria-label="Upload file"
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
>
|
||||
<PaperClipIcon className="h-5 w-5" />
|
||||
</label>
|
||||
<input
|
||||
id="file-upload"
|
||||
type="file"
|
||||
disabled={isGenerating}
|
||||
{...getInputProps()}
|
||||
hidden
|
||||
/>
|
||||
{isGenerating ? (
|
||||
<button
|
||||
className="btn btn-neutral w-8 h-8 p-0 rounded-full"
|
||||
onClick={onStop}
|
||||
>
|
||||
<StopIcon className="h-5 w-5" />
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="btn btn-primary w-8 h-8 p-0 rounded-full"
|
||||
onClick={onSend}
|
||||
aria-label="Send message"
|
||||
>
|
||||
<ArrowUpIcon className="h-5 w-5" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Dropzone>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,92 +0,0 @@
|
|||
import { useEffect, useState } from 'react';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { classNames } from '../utils/misc';
|
||||
import daisyuiThemes from 'daisyui/theme/object';
|
||||
import { THEMES } from '../Config';
|
||||
import {
|
||||
Cog8ToothIcon,
|
||||
MoonIcon,
|
||||
Bars3Icon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
|
||||
export default function Header() {
|
||||
const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme());
|
||||
const { setShowSettings } = useAppContext();
|
||||
|
||||
const setTheme = (theme: string) => {
|
||||
StorageUtils.setTheme(theme);
|
||||
setSelectedTheme(theme);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
document.body.setAttribute('data-theme', selectedTheme);
|
||||
document.body.setAttribute(
|
||||
'data-color-scheme',
|
||||
daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto'
|
||||
);
|
||||
}, [selectedTheme]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-row items-center pt-6 pb-6 sticky top-0 z-10 bg-base-100">
|
||||
{/* open sidebar button */}
|
||||
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
|
||||
<Bars3Icon className="h-5 w-5" />
|
||||
</label>
|
||||
|
||||
<div className="grow text-2xl font-bold ml-2">llama.cpp</div>
|
||||
|
||||
{/* action buttons (top right) */}
|
||||
<div className="flex items-center">
|
||||
<div
|
||||
className="tooltip tooltip-bottom"
|
||||
data-tip="Settings"
|
||||
onClick={() => setShowSettings(true)}
|
||||
>
|
||||
<button className="btn" aria-hidden={true}>
|
||||
{/* settings button */}
|
||||
<Cog8ToothIcon className="w-5 h-5" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* theme controller is copied from https://daisyui.com/components/theme-controller/ */}
|
||||
<div className="tooltip tooltip-bottom" data-tip="Themes">
|
||||
<div className="dropdown dropdown-end dropdown-bottom">
|
||||
<div tabIndex={0} role="button" className="btn m-1">
|
||||
<MoonIcon className="w-5 h-5" />
|
||||
</div>
|
||||
<ul
|
||||
tabIndex={0}
|
||||
className="dropdown-content bg-base-300 rounded-box z-[1] w-52 p-2 shadow-2xl h-80 overflow-y-auto"
|
||||
>
|
||||
<li>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-block btn-ghost justify-start': true,
|
||||
'btn-active': selectedTheme === 'auto',
|
||||
})}
|
||||
onClick={() => setTheme('auto')}
|
||||
>
|
||||
auto
|
||||
</button>
|
||||
</li>
|
||||
{THEMES.map((theme) => (
|
||||
<li key={theme}>
|
||||
<input
|
||||
type="radio"
|
||||
name="theme-dropdown"
|
||||
className="theme-controller btn btn-sm btn-block btn-ghost justify-start"
|
||||
aria-label={theme}
|
||||
value={theme}
|
||||
checked={selectedTheme === theme}
|
||||
onChange={(e) => e.target.checked && setTheme(theme)}
|
||||
/>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,317 +0,0 @@
|
|||
import React, { useMemo, useState } from 'react';
|
||||
import Markdown, { ExtraProps } from 'react-markdown';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
import rehypeHightlight from 'rehype-highlight';
|
||||
import rehypeKatex from 'rehype-katex';
|
||||
import remarkMath from 'remark-math';
|
||||
import remarkBreaks from 'remark-breaks';
|
||||
import 'katex/dist/katex.min.css';
|
||||
import { classNames, copyStr } from '../utils/misc';
|
||||
import { ElementContent, Root } from 'hast';
|
||||
import { visit } from 'unist-util-visit';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { CanvasType } from '../utils/types';
|
||||
import { BtnWithTooltips } from '../utils/common';
|
||||
import { DocumentDuplicateIcon, PlayIcon } from '@heroicons/react/24/outline';
|
||||
|
||||
export default function MarkdownDisplay({
|
||||
content,
|
||||
isGenerating,
|
||||
}: {
|
||||
content: string;
|
||||
isGenerating?: boolean;
|
||||
}) {
|
||||
const preprocessedContent = useMemo(
|
||||
() => preprocessLaTeX(content),
|
||||
[content]
|
||||
);
|
||||
return (
|
||||
<Markdown
|
||||
remarkPlugins={[remarkGfm, remarkMath, remarkBreaks]}
|
||||
rehypePlugins={[rehypeHightlight, rehypeKatex, rehypeCustomCopyButton]}
|
||||
components={{
|
||||
button: (props) => (
|
||||
<CodeBlockButtons
|
||||
{...props}
|
||||
isGenerating={isGenerating}
|
||||
origContent={preprocessedContent}
|
||||
/>
|
||||
),
|
||||
// note: do not use "pre", "p" or other basic html elements here, it will cause the node to re-render when the message is being generated (this should be a bug with react-markdown, not sure how to fix it)
|
||||
}}
|
||||
>
|
||||
{preprocessedContent}
|
||||
</Markdown>
|
||||
);
|
||||
}
|
||||
|
||||
const CodeBlockButtons: React.ElementType<
|
||||
React.ClassAttributes<HTMLButtonElement> &
|
||||
React.HTMLAttributes<HTMLButtonElement> &
|
||||
ExtraProps & { origContent: string; isGenerating?: boolean }
|
||||
> = ({ node, origContent, isGenerating }) => {
|
||||
const { config } = useAppContext();
|
||||
const startOffset = node?.position?.start.offset ?? 0;
|
||||
const endOffset = node?.position?.end.offset ?? 0;
|
||||
|
||||
const copiedContent = useMemo(
|
||||
() =>
|
||||
origContent
|
||||
.substring(startOffset, endOffset)
|
||||
.replace(/^```[^\n]+\n/g, '')
|
||||
.replace(/```$/g, ''),
|
||||
[origContent, startOffset, endOffset]
|
||||
);
|
||||
|
||||
const codeLanguage = useMemo(
|
||||
() =>
|
||||
origContent
|
||||
.substring(startOffset, startOffset + 10)
|
||||
.match(/^```([^\n]+)\n/)?.[1] ?? '',
|
||||
[origContent, startOffset]
|
||||
);
|
||||
|
||||
const canRunCode =
|
||||
!isGenerating &&
|
||||
config.pyIntepreterEnabled &&
|
||||
codeLanguage.startsWith('py');
|
||||
|
||||
return (
|
||||
<div
|
||||
className={classNames({
|
||||
'text-right sticky top-[7em] mb-2 mr-2 h-0': true,
|
||||
'display-none': !node?.position,
|
||||
})}
|
||||
>
|
||||
<CopyButton
|
||||
className="badge btn-mini btn-soft shadow-sm"
|
||||
content={copiedContent}
|
||||
/>
|
||||
{canRunCode && (
|
||||
<RunPyCodeButton
|
||||
className="badge btn-mini shadow-sm ml-2"
|
||||
content={copiedContent}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const CopyButton = ({
|
||||
content,
|
||||
className,
|
||||
}: {
|
||||
content: string;
|
||||
className?: string;
|
||||
}) => {
|
||||
const [copied, setCopied] = useState(false);
|
||||
return (
|
||||
<BtnWithTooltips
|
||||
className={className}
|
||||
onClick={() => {
|
||||
copyStr(content);
|
||||
setCopied(true);
|
||||
}}
|
||||
onMouseLeave={() => setCopied(false)}
|
||||
tooltipsContent={copied ? 'Copied!' : 'Copy'}
|
||||
>
|
||||
<DocumentDuplicateIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
);
|
||||
};
|
||||
|
||||
export const RunPyCodeButton = ({
|
||||
content,
|
||||
className,
|
||||
}: {
|
||||
content: string;
|
||||
className?: string;
|
||||
}) => {
|
||||
const { setCanvasData } = useAppContext();
|
||||
return (
|
||||
<>
|
||||
<BtnWithTooltips
|
||||
className={className}
|
||||
onClick={() =>
|
||||
setCanvasData({
|
||||
type: CanvasType.PY_INTERPRETER,
|
||||
content,
|
||||
})
|
||||
}
|
||||
tooltipsContent="Run code"
|
||||
>
|
||||
<PlayIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* This injects the "button" element before each "pre" element.
|
||||
* The actual button will be replaced with a react component in the MarkdownDisplay.
|
||||
* We don't replace "pre" node directly because it will cause the node to re-render, which causes this bug: https://github.com/ggerganov/llama.cpp/issues/9608
|
||||
*/
|
||||
function rehypeCustomCopyButton() {
|
||||
return function (tree: Root) {
|
||||
visit(tree, 'element', function (node) {
|
||||
if (node.tagName === 'pre' && !node.properties.visited) {
|
||||
const preNode = { ...node };
|
||||
// replace current node
|
||||
preNode.properties.visited = 'true';
|
||||
node.tagName = 'div';
|
||||
node.properties = {};
|
||||
// add node for button
|
||||
const btnNode: ElementContent = {
|
||||
type: 'element',
|
||||
tagName: 'button',
|
||||
properties: {},
|
||||
children: [],
|
||||
position: node.position,
|
||||
};
|
||||
node.children = [btnNode, preNode];
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* The part below is copied and adapted from:
|
||||
* https://github.com/danny-avila/LibreChat/blob/main/client/src/utils/latex.ts
|
||||
* (MIT License)
|
||||
*/
|
||||
|
||||
// Regex to check if the processed content contains any potential LaTeX patterns
|
||||
const containsLatexRegex =
|
||||
/\\\(.*?\\\)|\\\[.*?\\\]|\$.*?\$|\\begin\{equation\}.*?\\end\{equation\}/;
|
||||
|
||||
// Regex for inline and block LaTeX expressions
|
||||
const inlineLatex = new RegExp(/\\\((.+?)\\\)/, 'g');
|
||||
const blockLatex = new RegExp(/\\\[(.*?[^\\])\\\]/, 'gs');
|
||||
|
||||
// Function to restore code blocks
|
||||
const restoreCodeBlocks = (content: string, codeBlocks: string[]) => {
|
||||
return content.replace(
|
||||
/<<CODE_BLOCK_(\d+)>>/g,
|
||||
(_, index) => codeBlocks[index]
|
||||
);
|
||||
};
|
||||
|
||||
// Regex to identify code blocks and inline code
|
||||
const codeBlockRegex = /(```[\s\S]*?```|`.*?`)/g;
|
||||
|
||||
export const processLaTeX = (_content: string) => {
|
||||
let content = _content;
|
||||
// Temporarily replace code blocks and inline code with placeholders
|
||||
const codeBlocks: string[] = [];
|
||||
let index = 0;
|
||||
content = content.replace(codeBlockRegex, (match) => {
|
||||
codeBlocks[index] = match;
|
||||
return `<<CODE_BLOCK_${index++}>>`;
|
||||
});
|
||||
|
||||
// Escape dollar signs followed by a digit or space and digit
|
||||
let processedContent = content.replace(/(\$)(?=\s?\d)/g, '\\$');
|
||||
|
||||
// If no LaTeX patterns are found, restore code blocks and return the processed content
|
||||
if (!containsLatexRegex.test(processedContent)) {
|
||||
return restoreCodeBlocks(processedContent, codeBlocks);
|
||||
}
|
||||
|
||||
// Convert LaTeX expressions to a markdown compatible format
|
||||
processedContent = processedContent
|
||||
.replace(inlineLatex, (_: string, equation: string) => `$${equation}$`) // Convert inline LaTeX
|
||||
.replace(blockLatex, (_: string, equation: string) => `$$${equation}$$`); // Convert block LaTeX
|
||||
|
||||
// Restore code blocks
|
||||
return restoreCodeBlocks(processedContent, codeBlocks);
|
||||
};
|
||||
|
||||
/**
|
||||
* Preprocesses LaTeX content by replacing delimiters and escaping certain characters.
|
||||
*
|
||||
* @param content The input string containing LaTeX expressions.
|
||||
* @returns The processed string with replaced delimiters and escaped characters.
|
||||
*/
|
||||
export function preprocessLaTeX(content: string): string {
|
||||
// Step 1: Protect code blocks
|
||||
const codeBlocks: string[] = [];
|
||||
content = content.replace(/(```[\s\S]*?```|`[^`\n]+`)/g, (_, code) => {
|
||||
codeBlocks.push(code);
|
||||
return `<<CODE_BLOCK_${codeBlocks.length - 1}>>`;
|
||||
});
|
||||
|
||||
// Step 2: Protect existing LaTeX expressions
|
||||
const latexExpressions: string[] = [];
|
||||
|
||||
// Protect block math ($$...$$), \[...\], and \(...\) as before.
|
||||
content = content.replace(
|
||||
/(\$\$[\s\S]*?\$\$|\\\[[\s\S]*?\\\]|\\\(.*?\\\))/g,
|
||||
(match) => {
|
||||
latexExpressions.push(match);
|
||||
return `<<LATEX_${latexExpressions.length - 1}>>`;
|
||||
}
|
||||
);
|
||||
|
||||
// Protect inline math ($...$) only if it does NOT match a currency pattern.
|
||||
// We assume a currency pattern is one where the inner content is purely numeric (with optional decimals).
|
||||
content = content.replace(/\$([^$]+)\$/g, (match, inner) => {
|
||||
if (/^\s*\d+(?:\.\d+)?\s*$/.test(inner)) {
|
||||
// This looks like a currency value (e.g. "$123" or "$12.34"),
|
||||
// so don't protect it.
|
||||
return match;
|
||||
} else {
|
||||
// Otherwise, treat it as a LaTeX expression.
|
||||
latexExpressions.push(match);
|
||||
return `<<LATEX_${latexExpressions.length - 1}>>`;
|
||||
}
|
||||
});
|
||||
|
||||
// Step 3: Escape dollar signs that are likely currency indicators.
|
||||
// (Now that inline math is protected, this will only escape dollars not already protected)
|
||||
content = content.replace(/\$(?=\d)/g, '\\$');
|
||||
|
||||
// Step 4: Restore LaTeX expressions
|
||||
content = content.replace(
|
||||
/<<LATEX_(\d+)>>/g,
|
||||
(_, index) => latexExpressions[parseInt(index)]
|
||||
);
|
||||
|
||||
// Step 5: Restore code blocks
|
||||
content = content.replace(
|
||||
/<<CODE_BLOCK_(\d+)>>/g,
|
||||
(_, index) => codeBlocks[parseInt(index)]
|
||||
);
|
||||
|
||||
// Step 6: Apply additional escaping functions
|
||||
content = escapeBrackets(content);
|
||||
content = escapeMhchem(content);
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
export function escapeBrackets(text: string): string {
|
||||
const pattern =
|
||||
/(```[\S\s]*?```|`.*?`)|\\\[([\S\s]*?[^\\])\\]|\\\((.*?)\\\)/g;
|
||||
return text.replace(
|
||||
pattern,
|
||||
(
|
||||
match: string,
|
||||
codeBlock: string | undefined,
|
||||
squareBracket: string | undefined,
|
||||
roundBracket: string | undefined
|
||||
): string => {
|
||||
if (codeBlock != null) {
|
||||
return codeBlock;
|
||||
} else if (squareBracket != null) {
|
||||
return `$$${squareBracket}$$`;
|
||||
} else if (roundBracket != null) {
|
||||
return `$${roundBracket}$`;
|
||||
}
|
||||
return match;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export function escapeMhchem(text: string) {
|
||||
return text.replaceAll('$\\ce{', '$\\\\ce{').replaceAll('$\\pu{', '$\\\\pu{');
|
||||
}
|
||||
|
|
@ -1,151 +0,0 @@
|
|||
import React, { createContext, useState, useContext } from 'react';
|
||||
|
||||
type ModalContextType = {
|
||||
showConfirm: (message: string) => Promise<boolean>;
|
||||
showPrompt: (
|
||||
message: string,
|
||||
defaultValue?: string
|
||||
) => Promise<string | undefined>;
|
||||
showAlert: (message: string) => Promise<void>;
|
||||
};
|
||||
const ModalContext = createContext<ModalContextType>(null!);
|
||||
|
||||
interface ModalState<T> {
|
||||
isOpen: boolean;
|
||||
message: string;
|
||||
defaultValue?: string;
|
||||
resolve: ((value: T) => void) | null;
|
||||
}
|
||||
|
||||
export function ModalProvider({ children }: { children: React.ReactNode }) {
|
||||
const [confirmState, setConfirmState] = useState<ModalState<boolean>>({
|
||||
isOpen: false,
|
||||
message: '',
|
||||
resolve: null,
|
||||
});
|
||||
const [promptState, setPromptState] = useState<
|
||||
ModalState<string | undefined>
|
||||
>({ isOpen: false, message: '', resolve: null });
|
||||
const [alertState, setAlertState] = useState<ModalState<void>>({
|
||||
isOpen: false,
|
||||
message: '',
|
||||
resolve: null,
|
||||
});
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
const showConfirm = (message: string): Promise<boolean> => {
|
||||
return new Promise((resolve) => {
|
||||
setConfirmState({ isOpen: true, message, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const showPrompt = (
|
||||
message: string,
|
||||
defaultValue?: string
|
||||
): Promise<string | undefined> => {
|
||||
return new Promise((resolve) => {
|
||||
setPromptState({ isOpen: true, message, defaultValue, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const showAlert = (message: string): Promise<void> => {
|
||||
return new Promise((resolve) => {
|
||||
setAlertState({ isOpen: true, message, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const handleConfirm = (result: boolean) => {
|
||||
confirmState.resolve?.(result);
|
||||
setConfirmState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
const handlePrompt = (result?: string) => {
|
||||
promptState.resolve?.(result);
|
||||
setPromptState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
const handleAlertClose = () => {
|
||||
alertState.resolve?.();
|
||||
setAlertState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
return (
|
||||
<ModalContext.Provider value={{ showConfirm, showPrompt, showAlert }}>
|
||||
{children}
|
||||
|
||||
{/* Confirm Modal */}
|
||||
{confirmState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{confirmState.message}</h3>
|
||||
<div className="modal-action">
|
||||
<button
|
||||
className="btn btn-ghost"
|
||||
onClick={() => handleConfirm(false)}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn btn-error"
|
||||
onClick={() => handleConfirm(true)}
|
||||
>
|
||||
Confirm
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
|
||||
{/* Prompt Modal */}
|
||||
{promptState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{promptState.message}</h3>
|
||||
<input
|
||||
type="text"
|
||||
className="input input-bordered w-full mt-2"
|
||||
defaultValue={promptState.defaultValue}
|
||||
ref={inputRef}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
handlePrompt((e.target as HTMLInputElement).value);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<div className="modal-action">
|
||||
<button className="btn btn-ghost" onClick={() => handlePrompt()}>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn btn-primary"
|
||||
onClick={() => handlePrompt(inputRef.current?.value)}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
|
||||
{/* Alert Modal */}
|
||||
{alertState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{alertState.message}</h3>
|
||||
<div className="modal-action">
|
||||
<button className="btn" onClick={handleAlertClose}>
|
||||
OK
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
</ModalContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useModals() {
|
||||
const context = useContext(ModalContext);
|
||||
if (!context) throw new Error('useModals must be used within ModalProvider');
|
||||
return context;
|
||||
}
|
||||
|
|
@ -1,553 +0,0 @@
|
|||
import { useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config';
|
||||
import { isDev } from '../Config';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { classNames, isBoolean, isNumeric, isString } from '../utils/misc';
|
||||
import {
|
||||
BeakerIcon,
|
||||
ChatBubbleOvalLeftEllipsisIcon,
|
||||
Cog6ToothIcon,
|
||||
FunnelIcon,
|
||||
HandRaisedIcon,
|
||||
SquaresPlusIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { OpenInNewTab } from '../utils/common';
|
||||
import { useModals } from './ModalProvider';
|
||||
|
||||
type SettKey = keyof typeof CONFIG_DEFAULT;
|
||||
|
||||
const BASIC_KEYS: SettKey[] = [
|
||||
'temperature',
|
||||
'top_k',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'max_tokens',
|
||||
];
|
||||
const SAMPLER_KEYS: SettKey[] = [
|
||||
'dynatemp_range',
|
||||
'dynatemp_exponent',
|
||||
'typical_p',
|
||||
'xtc_probability',
|
||||
'xtc_threshold',
|
||||
];
|
||||
const PENALTY_KEYS: SettKey[] = [
|
||||
'repeat_last_n',
|
||||
'repeat_penalty',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_penalty_last_n',
|
||||
];
|
||||
|
||||
enum SettingInputType {
|
||||
SHORT_INPUT,
|
||||
LONG_INPUT,
|
||||
CHECKBOX,
|
||||
CUSTOM,
|
||||
}
|
||||
|
||||
interface SettingFieldInput {
|
||||
type: Exclude<SettingInputType, SettingInputType.CUSTOM>;
|
||||
label: string | React.ReactElement;
|
||||
help?: string | React.ReactElement;
|
||||
key: SettKey;
|
||||
}
|
||||
|
||||
interface SettingFieldCustom {
|
||||
type: SettingInputType.CUSTOM;
|
||||
key: SettKey;
|
||||
component:
|
||||
| string
|
||||
| React.FC<{
|
||||
value: string | boolean | number;
|
||||
onChange: (value: string) => void;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface SettingSection {
|
||||
title: React.ReactElement;
|
||||
fields: (SettingFieldInput | SettingFieldCustom)[];
|
||||
}
|
||||
|
||||
const ICON_CLASSNAME = 'w-4 h-4 mr-1 inline';
|
||||
|
||||
const SETTING_SECTIONS: SettingSection[] = [
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<Cog6ToothIcon className={ICON_CLASSNAME} />
|
||||
General
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'API Key',
|
||||
key: 'apiKey',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.LONG_INPUT,
|
||||
label: 'System Message (will be disabled if left empty)',
|
||||
key: 'systemMessage',
|
||||
},
|
||||
...BASIC_KEYS.map(
|
||||
(key) =>
|
||||
({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
}) as SettingFieldInput
|
||||
),
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'Paste length to file',
|
||||
key: 'pasteLongTextToFileLen',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Parse PDF as image instead of text',
|
||||
key: 'pdfAsImage',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<FunnelIcon className={ICON_CLASSNAME} />
|
||||
Samplers
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'Samplers queue',
|
||||
key: 'samplers',
|
||||
},
|
||||
...SAMPLER_KEYS.map(
|
||||
(key) =>
|
||||
({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
}) as SettingFieldInput
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<HandRaisedIcon className={ICON_CLASSNAME} />
|
||||
Penalties
|
||||
</>
|
||||
),
|
||||
fields: PENALTY_KEYS.map((key) => ({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
})),
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<ChatBubbleOvalLeftEllipsisIcon className={ICON_CLASSNAME} />
|
||||
Reasoning
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Expand thought process by default when generating messages',
|
||||
key: 'showThoughtInProgress',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label:
|
||||
'Exclude thought process when sending requests to API (Recommended for DeepSeek-R1)',
|
||||
key: 'excludeThoughtOnReq',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<SquaresPlusIcon className={ICON_CLASSNAME} />
|
||||
Advanced
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CUSTOM,
|
||||
key: 'custom', // dummy key, won't be used
|
||||
component: () => {
|
||||
const debugImportDemoConv = async () => {
|
||||
const res = await fetch('/demo-conversation.json');
|
||||
const demoConv = await res.json();
|
||||
StorageUtils.remove(demoConv.id);
|
||||
for (const msg of demoConv.messages) {
|
||||
StorageUtils.appendMsg(demoConv.id, msg);
|
||||
}
|
||||
};
|
||||
return (
|
||||
<button className="btn" onClick={debugImportDemoConv}>
|
||||
(debug) Import demo conversation
|
||||
</button>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Show tokens per second',
|
||||
key: 'showTokensPerSecond',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.LONG_INPUT,
|
||||
label: (
|
||||
<>
|
||||
Custom JSON config (For more info, refer to{' '}
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/blob/master/tools/server/README.md">
|
||||
server documentation
|
||||
</OpenInNewTab>
|
||||
)
|
||||
</>
|
||||
),
|
||||
key: 'custom',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<BeakerIcon className={ICON_CLASSNAME} />
|
||||
Experimental
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CUSTOM,
|
||||
key: 'custom', // dummy key, won't be used
|
||||
component: () => (
|
||||
<>
|
||||
<p className="mb-8">
|
||||
Experimental features are not guaranteed to work correctly.
|
||||
<br />
|
||||
<br />
|
||||
If you encounter any problems, create a{' '}
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/new?template=019-bug-misc.yml">
|
||||
Bug (misc.)
|
||||
</OpenInNewTab>{' '}
|
||||
report on Github. Please also specify <b>webui/experimental</b> on
|
||||
the report title and include screenshots.
|
||||
<br />
|
||||
<br />
|
||||
Some features may require packages downloaded from CDN, so they
|
||||
need internet connection.
|
||||
</p>
|
||||
</>
|
||||
),
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: (
|
||||
<>
|
||||
<b>Enable Python interpreter</b>
|
||||
<br />
|
||||
<small className="text-xs">
|
||||
This feature uses{' '}
|
||||
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
|
||||
downloaded from CDN. To use this feature, ask the LLM to generate
|
||||
Python code inside a Markdown code block. You will see a "Run"
|
||||
button on the code block, near the "Copy" button.
|
||||
</small>
|
||||
</>
|
||||
),
|
||||
key: 'pyIntepreterEnabled',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export default function SettingDialog({
|
||||
show,
|
||||
onClose,
|
||||
}: {
|
||||
show: boolean;
|
||||
onClose: () => void;
|
||||
}) {
|
||||
const { config, saveConfig } = useAppContext();
|
||||
const [sectionIdx, setSectionIdx] = useState(0);
|
||||
|
||||
// clone the config object to prevent direct mutation
|
||||
const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
|
||||
JSON.parse(JSON.stringify(config))
|
||||
);
|
||||
const { showConfirm, showAlert } = useModals();
|
||||
|
||||
const resetConfig = async () => {
|
||||
if (await showConfirm('Are you sure you want to reset all settings?')) {
|
||||
setLocalConfig(CONFIG_DEFAULT);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
// copy the local config to prevent direct mutation
|
||||
const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
|
||||
JSON.stringify(localConfig)
|
||||
);
|
||||
// validate the config
|
||||
for (const key in newConfig) {
|
||||
const value = newConfig[key as SettKey];
|
||||
const mustBeBoolean = isBoolean(CONFIG_DEFAULT[key as SettKey]);
|
||||
const mustBeString = isString(CONFIG_DEFAULT[key as SettKey]);
|
||||
const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
|
||||
if (mustBeString) {
|
||||
if (!isString(value)) {
|
||||
await showAlert(`Value for ${key} must be string`);
|
||||
return;
|
||||
}
|
||||
} else if (mustBeNumeric) {
|
||||
const trimmedValue = value.toString().trim();
|
||||
const numVal = Number(trimmedValue);
|
||||
if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) {
|
||||
await showAlert(`Value for ${key} must be numeric`);
|
||||
return;
|
||||
}
|
||||
// force conversion to number
|
||||
// @ts-expect-error this is safe
|
||||
newConfig[key] = numVal;
|
||||
} else if (mustBeBoolean) {
|
||||
if (!isBoolean(value)) {
|
||||
await showAlert(`Value for ${key} must be boolean`);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
console.error(`Unknown default type for key ${key}`);
|
||||
}
|
||||
}
|
||||
if (isDev) console.log('Saving config', newConfig);
|
||||
saveConfig(newConfig);
|
||||
onClose();
|
||||
};
|
||||
|
||||
const onChange = (key: SettKey) => (value: string | boolean) => {
|
||||
// note: we do not perform validation here, because we may get incomplete value as user is still typing it
|
||||
setLocalConfig({ ...localConfig, [key]: value });
|
||||
};
|
||||
|
||||
return (
|
||||
<dialog
|
||||
className={classNames({ modal: true, 'modal-open': show })}
|
||||
aria-label="Settings dialog"
|
||||
>
|
||||
<div className="modal-box w-11/12 max-w-3xl">
|
||||
<h3 className="text-lg font-bold mb-6">Settings</h3>
|
||||
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
|
||||
{/* Left panel, showing sections - Desktop version */}
|
||||
<div
|
||||
className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200"
|
||||
role="complementary"
|
||||
aria-description="Settings sections"
|
||||
tabIndex={0}
|
||||
>
|
||||
{SETTING_SECTIONS.map((section, idx) => (
|
||||
<button
|
||||
key={idx}
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start font-normal w-44 mb-1': true,
|
||||
'btn-active': sectionIdx === idx,
|
||||
})}
|
||||
onClick={() => setSectionIdx(idx)}
|
||||
dir="auto"
|
||||
>
|
||||
{section.title}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Left panel, showing sections - Mobile version */}
|
||||
{/* This menu is skipped on a11y, otherwise it's repeated the desktop version */}
|
||||
<div
|
||||
className="md:hidden flex flex-row gap-2 mb-4"
|
||||
aria-disabled={true}
|
||||
>
|
||||
<details className="dropdown">
|
||||
<summary className="btn bt-sm w-full m-1">
|
||||
{SETTING_SECTIONS[sectionIdx].title}
|
||||
</summary>
|
||||
<ul className="menu dropdown-content bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
|
||||
{SETTING_SECTIONS.map((section, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start font-normal': true,
|
||||
'btn-active': sectionIdx === idx,
|
||||
})}
|
||||
onClick={() => setSectionIdx(idx)}
|
||||
dir="auto"
|
||||
>
|
||||
{section.title}
|
||||
</div>
|
||||
))}
|
||||
</ul>
|
||||
</details>
|
||||
</div>
|
||||
|
||||
{/* Right panel, showing setting fields */}
|
||||
<div className="grow overflow-y-auto px-4">
|
||||
{SETTING_SECTIONS[sectionIdx].fields.map((field, idx) => {
|
||||
const key = `${sectionIdx}-${idx}`;
|
||||
if (field.type === SettingInputType.SHORT_INPUT) {
|
||||
return (
|
||||
<SettingsModalShortInput
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={localConfig[field.key]}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.LONG_INPUT) {
|
||||
return (
|
||||
<SettingsModalLongInput
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={localConfig[field.key].toString()}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.CHECKBOX) {
|
||||
return (
|
||||
<SettingsModalCheckbox
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={!!localConfig[field.key]}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.CUSTOM) {
|
||||
return (
|
||||
<div key={key} className="mb-2">
|
||||
{typeof field.component === 'string'
|
||||
? field.component
|
||||
: field.component({
|
||||
value: localConfig[field.key],
|
||||
onChange: onChange(field.key),
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
})}
|
||||
|
||||
<p className="opacity-40 mb-6 text-sm mt-8">
|
||||
Settings are saved in browser's localStorage
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="modal-action">
|
||||
<button className="btn" onClick={resetConfig}>
|
||||
Reset to default
|
||||
</button>
|
||||
<button className="btn" onClick={onClose}>
|
||||
Close
|
||||
</button>
|
||||
<button className="btn btn-primary" onClick={handleSave}>
|
||||
Save
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalLongInput({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
label?: string;
|
||||
}) {
|
||||
return (
|
||||
<label className="form-control">
|
||||
<div className="label inline text-sm">{label || configKey}</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-24 mb-2"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalShortInput({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
value: any;
|
||||
onChange: (value: string) => void;
|
||||
label?: string;
|
||||
}) {
|
||||
const helpMsg = CONFIG_INFO[configKey];
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* on mobile, we simply show the help message here */}
|
||||
{helpMsg && (
|
||||
<div className="block mb-1 opacity-75">
|
||||
<p className="text-xs">{helpMsg}</p>
|
||||
</div>
|
||||
)}
|
||||
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
|
||||
<div className="dropdown dropdown-hover">
|
||||
<div tabIndex={0} role="button" className="font-bold hidden md:block">
|
||||
{label || configKey}
|
||||
</div>
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
className="grow"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalCheckbox({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
value: boolean;
|
||||
onChange: (value: boolean) => void;
|
||||
label: string;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex flex-row items-center mb-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="toggle"
|
||||
checked={value}
|
||||
onChange={(e) => onChange(e.target.checked)}
|
||||
/>
|
||||
<span className="ml-4">{label || configKey}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,369 +0,0 @@
|
|||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { classNames } from '../utils/misc';
|
||||
import { Conversation } from '../utils/types';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useNavigate, useParams } from 'react-router';
|
||||
import {
|
||||
ArrowDownTrayIcon,
|
||||
EllipsisVerticalIcon,
|
||||
PencilIcon,
|
||||
PencilSquareIcon,
|
||||
TrashIcon,
|
||||
XMarkIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { BtnWithTooltips } from '../utils/common';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import toast from 'react-hot-toast';
|
||||
import { useModals } from './ModalProvider';
|
||||
|
||||
export default function Sidebar() {
|
||||
const params = useParams();
|
||||
const navigate = useNavigate();
|
||||
|
||||
const { isGenerating } = useAppContext();
|
||||
|
||||
const [conversations, setConversations] = useState<Conversation[]>([]);
|
||||
const [currConv, setCurrConv] = useState<Conversation | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv);
|
||||
}, [params.convId]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleConversationChange = async () => {
|
||||
setConversations(await StorageUtils.getAllConversations());
|
||||
};
|
||||
StorageUtils.onConversationChanged(handleConversationChange);
|
||||
handleConversationChange();
|
||||
return () => {
|
||||
StorageUtils.offConversationChanged(handleConversationChange);
|
||||
};
|
||||
}, []);
|
||||
const { showConfirm, showPrompt } = useModals();
|
||||
|
||||
const groupedConv = useMemo(
|
||||
() => groupConversationsByDate(conversations),
|
||||
[conversations]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
id="toggle-drawer"
|
||||
type="checkbox"
|
||||
className="drawer-toggle"
|
||||
aria-label="Toggle sidebar"
|
||||
defaultChecked
|
||||
/>
|
||||
|
||||
<div
|
||||
className="drawer-side h-screen lg:h-screen z-50 lg:max-w-64"
|
||||
role="complementary"
|
||||
aria-label="Sidebar"
|
||||
tabIndex={0}
|
||||
>
|
||||
<label
|
||||
htmlFor="toggle-drawer"
|
||||
aria-label="Close sidebar"
|
||||
className="drawer-overlay"
|
||||
></label>
|
||||
|
||||
<a
|
||||
href="#main-scroll"
|
||||
className="absolute -left-80 top-0 w-1 h-1 overflow-hidden"
|
||||
>
|
||||
Skip to main content
|
||||
</a>
|
||||
|
||||
<div className="flex flex-col bg-base-200 min-h-full max-w-64 py-4 px-4">
|
||||
<div className="flex flex-row items-center justify-between mb-4 mt-4">
|
||||
<h2 className="font-bold ml-4" role="heading">
|
||||
Conversations
|
||||
</h2>
|
||||
|
||||
{/* close sidebar button */}
|
||||
<label
|
||||
htmlFor="toggle-drawer"
|
||||
className="btn btn-ghost lg:hidden"
|
||||
aria-label="Close sidebar"
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
>
|
||||
<XMarkIcon className="w-5 h-5" />
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{/* new conversation button */}
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start px-2': true,
|
||||
'btn-soft': !currConv,
|
||||
})}
|
||||
onClick={() => navigate('/')}
|
||||
aria-label="New conversation"
|
||||
>
|
||||
<PencilSquareIcon className="w-5 h-5" />
|
||||
New conversation
|
||||
</button>
|
||||
|
||||
{/* list of conversations */}
|
||||
{groupedConv.map((group, i) => (
|
||||
<div key={i} role="group">
|
||||
{/* group name (by date) */}
|
||||
{group.title ? (
|
||||
// we use btn class here to make sure that the padding/margin are aligned with the other items
|
||||
<b
|
||||
className="btn btn-ghost btn-xs bg-none btn-disabled block text-xs text-base-content text-start px-2 mb-0 mt-6 font-bold"
|
||||
role="note"
|
||||
aria-description={group.title}
|
||||
tabIndex={0}
|
||||
>
|
||||
{group.title}
|
||||
</b>
|
||||
) : (
|
||||
<div className="h-2" />
|
||||
)}
|
||||
|
||||
{group.conversations.map((conv) => (
|
||||
<ConversationItem
|
||||
key={conv.id}
|
||||
conv={conv}
|
||||
isCurrConv={currConv?.id === conv.id}
|
||||
onSelect={() => {
|
||||
navigate(`/chat/${conv.id}`);
|
||||
}}
|
||||
onDelete={async () => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot delete conversation while generating'
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (
|
||||
await showConfirm(
|
||||
'Are you sure to delete this conversation?'
|
||||
)
|
||||
) {
|
||||
toast.success('Conversation deleted');
|
||||
StorageUtils.remove(conv.id);
|
||||
navigate('/');
|
||||
}
|
||||
}}
|
||||
onDownload={() => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot download conversation while generating'
|
||||
);
|
||||
return;
|
||||
}
|
||||
const conversationJson = JSON.stringify(conv, null, 2);
|
||||
const blob = new Blob([conversationJson], {
|
||||
type: 'application/json',
|
||||
});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = `conversation_${conv.id}.json`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
}}
|
||||
onRename={async () => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot rename conversation while generating'
|
||||
);
|
||||
return;
|
||||
}
|
||||
const newName = await showPrompt(
|
||||
'Enter new name for the conversation',
|
||||
conv.name
|
||||
);
|
||||
if (newName && newName.trim().length > 0) {
|
||||
StorageUtils.updateConversationName(conv.id, newName);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
<div className="text-center text-xs opacity-40 mt-auto mx-4 pt-8">
|
||||
Conversations are saved to browser's IndexedDB
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function ConversationItem({
|
||||
conv,
|
||||
isCurrConv,
|
||||
onSelect,
|
||||
onDelete,
|
||||
onDownload,
|
||||
onRename,
|
||||
}: {
|
||||
conv: Conversation;
|
||||
isCurrConv: boolean;
|
||||
onSelect: () => void;
|
||||
onDelete: () => void;
|
||||
onDownload: () => void;
|
||||
onRename: () => void;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="menuitem"
|
||||
tabIndex={0}
|
||||
aria-label={conv.name}
|
||||
className={classNames({
|
||||
'group flex flex-row btn btn-ghost justify-start items-center font-normal px-2 h-9':
|
||||
true,
|
||||
'btn-soft': isCurrConv,
|
||||
})}
|
||||
>
|
||||
<button
|
||||
key={conv.id}
|
||||
className="w-full overflow-hidden truncate text-start"
|
||||
onClick={onSelect}
|
||||
dir="auto"
|
||||
>
|
||||
{conv.name}
|
||||
</button>
|
||||
<div tabIndex={0} className="dropdown dropdown-end h-5">
|
||||
<BtnWithTooltips
|
||||
// on mobile, we always show the ellipsis icon
|
||||
// on desktop, we only show it when the user hovers over the conversation item
|
||||
// we use opacity instead of hidden to avoid layout shift
|
||||
className="cursor-pointer opacity-100 md:opacity-0 group-hover:opacity-100"
|
||||
onClick={() => {}}
|
||||
tooltipsContent="More"
|
||||
>
|
||||
<EllipsisVerticalIcon className="w-5 h-5" />
|
||||
</BtnWithTooltips>
|
||||
{/* dropdown menu */}
|
||||
<ul
|
||||
aria-label="More options"
|
||||
tabIndex={0}
|
||||
className="dropdown-content menu bg-base-100 rounded-box z-[1] p-2 shadow"
|
||||
>
|
||||
<li onClick={onRename} tabIndex={0}>
|
||||
<a>
|
||||
<PencilIcon className="w-4 h-4" />
|
||||
Rename
|
||||
</a>
|
||||
</li>
|
||||
<li onClick={onDownload} tabIndex={0}>
|
||||
<a>
|
||||
<ArrowDownTrayIcon className="w-4 h-4" />
|
||||
Download
|
||||
</a>
|
||||
</li>
|
||||
<li className="text-error" onClick={onDelete} tabIndex={0}>
|
||||
<a>
|
||||
<TrashIcon className="w-4 h-4" />
|
||||
Delete
|
||||
</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
|
||||
export interface GroupedConversations {
|
||||
title?: string;
|
||||
conversations: Conversation[];
|
||||
}
|
||||
|
||||
// TODO @ngxson : add test for this function
|
||||
// Group conversations by date
|
||||
// - "Previous 7 Days"
|
||||
// - "Previous 30 Days"
|
||||
// - "Month Year" (e.g., "April 2023")
|
||||
export function groupConversationsByDate(
|
||||
conversations: Conversation[]
|
||||
): GroupedConversations[] {
|
||||
const now = new Date();
|
||||
const today = new Date(now.getFullYear(), now.getMonth(), now.getDate()); // Start of today
|
||||
|
||||
const sevenDaysAgo = new Date(today);
|
||||
sevenDaysAgo.setDate(today.getDate() - 7);
|
||||
|
||||
const thirtyDaysAgo = new Date(today);
|
||||
thirtyDaysAgo.setDate(today.getDate() - 30);
|
||||
|
||||
const groups: { [key: string]: Conversation[] } = {
|
||||
Today: [],
|
||||
'Previous 7 Days': [],
|
||||
'Previous 30 Days': [],
|
||||
};
|
||||
const monthlyGroups: { [key: string]: Conversation[] } = {}; // Key format: "Month Year" e.g., "April 2023"
|
||||
|
||||
// Sort conversations by lastModified date in descending order (newest first)
|
||||
// This helps when adding to groups, but the final output order of groups is fixed.
|
||||
const sortedConversations = [...conversations].sort(
|
||||
(a, b) => b.lastModified - a.lastModified
|
||||
);
|
||||
|
||||
for (const conv of sortedConversations) {
|
||||
const convDate = new Date(conv.lastModified);
|
||||
|
||||
if (convDate >= today) {
|
||||
groups['Today'].push(conv);
|
||||
} else if (convDate >= sevenDaysAgo) {
|
||||
groups['Previous 7 Days'].push(conv);
|
||||
} else if (convDate >= thirtyDaysAgo) {
|
||||
groups['Previous 30 Days'].push(conv);
|
||||
} else {
|
||||
const monthName = convDate.toLocaleString('default', { month: 'long' });
|
||||
const year = convDate.getFullYear();
|
||||
const monthYearKey = `${monthName} ${year}`;
|
||||
if (!monthlyGroups[monthYearKey]) {
|
||||
monthlyGroups[monthYearKey] = [];
|
||||
}
|
||||
monthlyGroups[monthYearKey].push(conv);
|
||||
}
|
||||
}
|
||||
|
||||
const result: GroupedConversations[] = [];
|
||||
|
||||
if (groups['Today'].length > 0) {
|
||||
result.push({
|
||||
title: undefined, // no title for Today
|
||||
conversations: groups['Today'],
|
||||
});
|
||||
}
|
||||
|
||||
if (groups['Previous 7 Days'].length > 0) {
|
||||
result.push({
|
||||
title: 'Previous 7 Days',
|
||||
conversations: groups['Previous 7 Days'],
|
||||
});
|
||||
}
|
||||
|
||||
if (groups['Previous 30 Days'].length > 0) {
|
||||
result.push({
|
||||
title: 'Previous 30 Days',
|
||||
conversations: groups['Previous 30 Days'],
|
||||
});
|
||||
}
|
||||
|
||||
// Sort monthly groups by date (most recent month first)
|
||||
const sortedMonthKeys = Object.keys(monthlyGroups).sort((a, b) => {
|
||||
const dateA = new Date(a); // "Month Year" can be parsed by Date constructor
|
||||
const dateB = new Date(b);
|
||||
return dateB.getTime() - dateA.getTime();
|
||||
});
|
||||
|
||||
for (const monthKey of sortedMonthKeys) {
|
||||
if (monthlyGroups[monthKey].length > 0) {
|
||||
result.push({ title: monthKey, conversations: monthlyGroups[monthKey] });
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
@ -1,371 +0,0 @@
|
|||
import { useState } from 'react';
|
||||
import { MessageExtra } from '../utils/types';
|
||||
import toast from 'react-hot-toast';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import * as pdfjs from 'pdfjs-dist';
|
||||
import pdfjsWorkerSrc from 'pdfjs-dist/build/pdf.worker.min.mjs?url';
|
||||
import { TextContent, TextItem } from 'pdfjs-dist/types/src/display/api';
|
||||
|
||||
pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
|
||||
|
||||
// This file handles uploading extra context items (a.k.a files)
|
||||
// It allows processing these kinds of files:
|
||||
// - image files (converted to base64)
|
||||
// - audio files (converted to base64)
|
||||
// - text files (including code files)
|
||||
// - pdf (converted to text)
|
||||
|
||||
// Interface describing the API returned by the hook
|
||||
export interface ChatExtraContextApi {
|
||||
items?: MessageExtra[]; // undefined if empty, similar to Message['extra']
|
||||
addItems: (items: MessageExtra[]) => void;
|
||||
removeItem: (idx: number) => void;
|
||||
clearItems: () => void;
|
||||
onFileAdded: (files: File[]) => void; // used by "upload" button
|
||||
}
|
||||
|
||||
export function useChatExtraContext(): ChatExtraContextApi {
|
||||
const { serverProps, config } = useAppContext();
|
||||
const [items, setItems] = useState<MessageExtra[]>([]);
|
||||
|
||||
const addItems = (newItems: MessageExtra[]) => {
|
||||
setItems((prev) => [...prev, ...newItems]);
|
||||
};
|
||||
|
||||
const removeItem = (idx: number) => {
|
||||
setItems((prev) => prev.filter((_, i) => i !== idx));
|
||||
};
|
||||
|
||||
const clearItems = () => {
|
||||
setItems([]);
|
||||
};
|
||||
|
||||
const isSupportVision = serverProps?.modalities?.vision;
|
||||
|
||||
const onFileAdded = async (files: File[]) => {
|
||||
try {
|
||||
for (const file of files) {
|
||||
const mimeType = file.type;
|
||||
|
||||
// this limit is only to prevent accidental uploads of huge files
|
||||
// it can potentially crashes the browser because we read the file as base64
|
||||
if (file.size > 500 * 1024 * 1024) {
|
||||
toast.error('File is too large. Maximum size is 500MB.');
|
||||
break;
|
||||
}
|
||||
|
||||
if (mimeType.startsWith('image/')) {
|
||||
if (!isSupportVision) {
|
||||
toast.error('Multimodal is not supported by this server or model.');
|
||||
break;
|
||||
}
|
||||
|
||||
let base64Url = await getFileAsBase64(file);
|
||||
if (mimeType === 'image/svg+xml') {
|
||||
// Convert SVG to PNG
|
||||
base64Url = await svgBase64UrlToPngDataURL(base64Url);
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
},
|
||||
]);
|
||||
} else if (mimeType.startsWith('video/')) {
|
||||
toast.error('Video files are not supported yet.');
|
||||
break;
|
||||
} else if (mimeType.startsWith('audio/')) {
|
||||
if (!/mpeg|wav/.test(mimeType)) {
|
||||
toast.error('Only mp3 and wav audio files are supported.');
|
||||
break;
|
||||
}
|
||||
|
||||
// plain base64, not a data URL
|
||||
const base64Data = await getFileAsBase64(file, false);
|
||||
addItems([
|
||||
{
|
||||
type: 'audioFile',
|
||||
name: file.name,
|
||||
mimeType,
|
||||
base64Data,
|
||||
},
|
||||
]);
|
||||
} else if (mimeType.startsWith('application/pdf')) {
|
||||
if (config.pdfAsImage && !isSupportVision) {
|
||||
toast(
|
||||
'Multimodal is not supported, PDF will be converted to text instead of image.'
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
if (config.pdfAsImage && isSupportVision) {
|
||||
// Convert PDF to images
|
||||
const base64Urls = await convertPDFToImage(file);
|
||||
addItems(
|
||||
base64Urls.map((base64Url) => ({
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
}))
|
||||
);
|
||||
} else {
|
||||
// Convert PDF to text
|
||||
const content = await convertPDFToText(file);
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
name: file.name,
|
||||
content,
|
||||
},
|
||||
]);
|
||||
if (isSupportVision) {
|
||||
toast.success(
|
||||
'PDF file converted to text. You can also convert it to image, see in Settings.'
|
||||
);
|
||||
}
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
// Because there can be many text file types (like code file), we will not check the mime type
|
||||
// and will just check if the file is not binary.
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
const content = event.target.result as string;
|
||||
if (!isLikelyNotBinary(content)) {
|
||||
toast.error('File is binary. Please upload a text file.');
|
||||
return;
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
name: file.name,
|
||||
content,
|
||||
},
|
||||
]);
|
||||
}
|
||||
};
|
||||
reader.readAsText(file);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const errorMessage = `Error processing file: ${message}`;
|
||||
toast.error(errorMessage);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
items: items.length > 0 ? items : undefined,
|
||||
addItems,
|
||||
removeItem,
|
||||
clearItems,
|
||||
onFileAdded,
|
||||
};
|
||||
}
|
||||
|
||||
async function getFileAsBase64(file: File, outputUrl = true): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
let result = event.target.result as string;
|
||||
if (!outputUrl) {
|
||||
// remove base64 url prefix and correct characters
|
||||
result = result.substring(result.indexOf(',') + 1);
|
||||
}
|
||||
resolve(result);
|
||||
} else {
|
||||
reject(new Error('Failed to read file.'));
|
||||
}
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
resolve(event.target.result as ArrayBuffer);
|
||||
} else {
|
||||
reject(new Error('Failed to read file.'));
|
||||
}
|
||||
};
|
||||
reader.readAsArrayBuffer(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function convertPDFToText(file: File): Promise<string> {
|
||||
const buffer = await getFileAsBuffer(file);
|
||||
const pdf = await pdfjs.getDocument(buffer).promise;
|
||||
const numPages = pdf.numPages;
|
||||
const textContentPromises: Promise<TextContent>[] = [];
|
||||
for (let i = 1; i <= numPages; i++) {
|
||||
textContentPromises.push(
|
||||
pdf.getPage(i).then((page) => page.getTextContent())
|
||||
);
|
||||
}
|
||||
const textContents = await Promise.all(textContentPromises);
|
||||
const textItems = textContents.flatMap((textContent: TextContent) =>
|
||||
textContent.items.map((item) => (item as TextItem).str ?? '')
|
||||
);
|
||||
return textItems.join('\n');
|
||||
}
|
||||
|
||||
// returns list of base64 images
|
||||
async function convertPDFToImage(file: File): Promise<string[]> {
|
||||
const buffer = await getFileAsBuffer(file);
|
||||
const doc = await pdfjs.getDocument(buffer).promise;
|
||||
const pages: Promise<string>[] = [];
|
||||
|
||||
for (let i = 1; i <= doc.numPages; i++) {
|
||||
const page = await doc.getPage(i);
|
||||
const viewport = page.getViewport({ scale: 1.5 });
|
||||
const canvas = document.createElement('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
canvas.width = viewport.width;
|
||||
canvas.height = viewport.height;
|
||||
if (!ctx) {
|
||||
throw new Error('Failed to get 2D context from canvas');
|
||||
}
|
||||
const task = page.render({ canvasContext: ctx, viewport: viewport });
|
||||
pages.push(
|
||||
task.promise.then(() => {
|
||||
return canvas.toDataURL();
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return await Promise.all(pages);
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
// This code is a heuristic to determine if a string is likely not binary.
|
||||
// It is necessary because input file can have various mime types which we don't have time to investigate.
|
||||
// For example, a python file can be text/plain, application/x-python, etc.
|
||||
function isLikelyNotBinary(str: string): boolean {
|
||||
const options = {
|
||||
prefixLength: 1024 * 10, // Check the first 10KB of the string
|
||||
suspiciousCharThresholdRatio: 0.15, // Allow up to 15% suspicious chars
|
||||
maxAbsoluteNullBytes: 2,
|
||||
};
|
||||
|
||||
if (!str) {
|
||||
return true; // Empty string is considered "not binary" or trivially text.
|
||||
}
|
||||
|
||||
const sampleLength = Math.min(str.length, options.prefixLength);
|
||||
if (sampleLength === 0) {
|
||||
return true; // Effectively an empty string after considering prefixLength.
|
||||
}
|
||||
|
||||
let suspiciousCharCount = 0;
|
||||
let nullByteCount = 0;
|
||||
|
||||
for (let i = 0; i < sampleLength; i++) {
|
||||
const charCode = str.charCodeAt(i);
|
||||
|
||||
// 1. Check for Unicode Replacement Character (U+FFFD)
|
||||
// This is a strong indicator if the string was created from decoding bytes as UTF-8.
|
||||
if (charCode === 0xfffd) {
|
||||
suspiciousCharCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2. Check for Null Bytes (U+0000)
|
||||
if (charCode === 0x0000) {
|
||||
nullByteCount++;
|
||||
// We also count nulls towards the general suspicious character count,
|
||||
// as they are less common in typical text files.
|
||||
suspiciousCharCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// 3. Check for C0 Control Characters (U+0001 to U+001F)
|
||||
// Exclude common text control characters: TAB (9), LF (10), CR (13).
|
||||
// We can also be a bit lenient with BEL (7) and BS (8) which sometimes appear in logs.
|
||||
if (charCode < 32) {
|
||||
if (
|
||||
charCode !== 9 && // TAB
|
||||
charCode !== 10 && // LF
|
||||
charCode !== 13 && // CR
|
||||
charCode !== 7 && // BEL (Bell) - sometimes in logs
|
||||
charCode !== 8 // BS (Backspace) - less common, but possible
|
||||
) {
|
||||
suspiciousCharCount++;
|
||||
}
|
||||
}
|
||||
// Characters from 32 (space) up to 126 (~) are printable ASCII.
|
||||
// Characters 127 (DEL) is a control character.
|
||||
// Characters >= 128 are extended ASCII / multi-byte Unicode.
|
||||
// If they resulted in U+FFFD, we caught it. Otherwise, they are valid
|
||||
// (though perhaps unusual) Unicode characters from JS's perspective.
|
||||
// The main concern is if those higher characters came from misinterpreting
|
||||
// a single-byte encoding as UTF-8, which again, U+FFFD would usually flag.
|
||||
}
|
||||
|
||||
// Check absolute null byte count
|
||||
if (nullByteCount > options.maxAbsoluteNullBytes) {
|
||||
return false; // Too many null bytes is a strong binary indicator
|
||||
}
|
||||
|
||||
// Check ratio of suspicious characters
|
||||
const ratio = suspiciousCharCount / sampleLength;
|
||||
return ratio <= options.suspiciousCharThresholdRatio;
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
// Converts a Base64URL encoded SVG string to a PNG Data URL using browser Canvas API.
|
||||
function svgBase64UrlToPngDataURL(base64UrlSvg: string): Promise<string> {
|
||||
const backgroundColor = 'white'; // Default background color for PNG
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
const img = new Image();
|
||||
|
||||
img.onload = () => {
|
||||
const canvas = document.createElement('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
|
||||
if (!ctx) {
|
||||
reject(new Error('Failed to get 2D canvas context.'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Use provided dimensions or SVG's natural dimensions, with fallbacks
|
||||
// Fallbacks (e.g., 300x300) are for SVGs without explicit width/height
|
||||
// or when naturalWidth/Height might be 0 before full processing.
|
||||
const targetWidth = img.naturalWidth || 300;
|
||||
const targetHeight = img.naturalHeight || 300;
|
||||
|
||||
canvas.width = targetWidth;
|
||||
canvas.height = targetHeight;
|
||||
|
||||
if (backgroundColor) {
|
||||
ctx.fillStyle = backgroundColor;
|
||||
ctx.fillRect(0, 0, canvas.width, canvas.height);
|
||||
}
|
||||
|
||||
ctx.drawImage(img, 0, 0, targetWidth, targetHeight);
|
||||
resolve(canvas.toDataURL('image/png'));
|
||||
};
|
||||
|
||||
img.onerror = () => {
|
||||
reject(
|
||||
new Error('Failed to load SVG image. Ensure the SVG data is valid.')
|
||||
);
|
||||
};
|
||||
|
||||
// Load SVG string into an Image element
|
||||
img.src = base64UrlSvg;
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const errorMessage = `Error converting SVG to PNG: ${message}`;
|
||||
toast.error(errorMessage);
|
||||
reject(new Error(errorMessage));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
import React, { useEffect } from 'react';
|
||||
import { throttle } from '../utils/misc';
|
||||
|
||||
export const scrollToBottom = (requiresNearBottom: boolean, delay?: number) => {
|
||||
const mainScrollElem = document.getElementById('main-scroll');
|
||||
if (!mainScrollElem) return;
|
||||
const spaceToBottom =
|
||||
mainScrollElem.scrollHeight -
|
||||
mainScrollElem.scrollTop -
|
||||
mainScrollElem.clientHeight;
|
||||
if (!requiresNearBottom || spaceToBottom < 100) {
|
||||
setTimeout(
|
||||
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
|
||||
delay ?? 80
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const scrollToBottomThrottled = throttle(scrollToBottom, 80);
|
||||
|
||||
export function useChatScroll(msgListRef: React.RefObject<HTMLDivElement>) {
|
||||
useEffect(() => {
|
||||
if (!msgListRef.current) return;
|
||||
|
||||
const resizeObserver = new ResizeObserver((_) => {
|
||||
scrollToBottomThrottled(true, 10);
|
||||
});
|
||||
|
||||
resizeObserver.observe(msgListRef.current);
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [msgListRef]);
|
||||
}
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
import { useEffect, useRef, useState, useCallback } from 'react';
|
||||
import { throttle } from '../utils/misc';
|
||||
|
||||
// Media Query for detecting "large" screens (matching Tailwind's lg: breakpoint)
|
||||
const LARGE_SCREEN_MQ = '(min-width: 1024px)';
|
||||
|
||||
// Calculates and sets the textarea height based on its scrollHeight
|
||||
const adjustTextareaHeight = throttle(
|
||||
(textarea: HTMLTextAreaElement | null) => {
|
||||
if (!textarea) return;
|
||||
|
||||
// Only perform auto-sizing on large screens
|
||||
if (!window.matchMedia(LARGE_SCREEN_MQ).matches) {
|
||||
// On small screens, reset inline height and max-height styles.
|
||||
// This allows CSS (e.g., `rows` attribute or classes) to control the height,
|
||||
// and enables manual resizing if `resize-vertical` is set.
|
||||
textarea.style.height = ''; // Use 'auto' or '' to reset
|
||||
textarea.style.maxHeight = '';
|
||||
return; // Do not adjust height programmatically on small screens
|
||||
}
|
||||
|
||||
const computedStyle = window.getComputedStyle(textarea);
|
||||
// Get the max-height specified by CSS (e.g., from `lg:max-h-48`)
|
||||
const currentMaxHeight = computedStyle.maxHeight;
|
||||
|
||||
// Temporarily remove max-height to allow scrollHeight to be calculated correctly
|
||||
textarea.style.maxHeight = 'none';
|
||||
// Reset height to 'auto' to measure the actual scrollHeight needed
|
||||
textarea.style.height = 'auto';
|
||||
// Set the height to the calculated scrollHeight
|
||||
textarea.style.height = `${textarea.scrollHeight}px`;
|
||||
// Re-apply the original max-height from CSS to enforce the limit
|
||||
textarea.style.maxHeight = currentMaxHeight;
|
||||
},
|
||||
100
|
||||
); // Throttle to prevent excessive calls
|
||||
|
||||
// Interface describing the API returned by the hook
|
||||
export interface ChatTextareaApi {
|
||||
value: () => string;
|
||||
setValue: (value: string) => void;
|
||||
focus: () => void;
|
||||
ref: React.RefObject<HTMLTextAreaElement>;
|
||||
refOnSubmit: React.MutableRefObject<(() => void) | null>; // Submit handler
|
||||
onInput: (event: React.FormEvent<HTMLTextAreaElement>) => void; // Input handler
|
||||
}
|
||||
|
||||
// This is a workaround to prevent the textarea from re-rendering when the inner content changes
|
||||
// See https://github.com/ggml-org/llama.cpp/pull/12299
|
||||
// combined now with auto-sizing logic.
|
||||
export function useChatTextarea(initValue: string): ChatTextareaApi {
|
||||
const [savedInitValue, setSavedInitValue] = useState<string>(initValue);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const onSubmitRef = useRef<(() => void) | null>(null);
|
||||
|
||||
// Effect to set initial value and height on mount or when initValue changes
|
||||
useEffect(() => {
|
||||
const textarea = textareaRef.current;
|
||||
if (textarea) {
|
||||
if (typeof savedInitValue === 'string' && savedInitValue.length > 0) {
|
||||
textarea.value = savedInitValue;
|
||||
// Call adjustTextareaHeight - it will check screen size internally
|
||||
setTimeout(() => adjustTextareaHeight(textarea), 0);
|
||||
setSavedInitValue(''); // Reset after applying
|
||||
} else {
|
||||
// Adjust height even if there's no initial value (for initial render)
|
||||
setTimeout(() => adjustTextareaHeight(textarea), 0);
|
||||
}
|
||||
}
|
||||
}, [textareaRef, savedInitValue]); // Depend on ref and savedInitValue
|
||||
|
||||
// On input change, we adjust the height of the textarea
|
||||
const handleInput = useCallback(
|
||||
(event: React.FormEvent<HTMLTextAreaElement>) => {
|
||||
// Call adjustTextareaHeight on every input - it will decide whether to act
|
||||
adjustTextareaHeight(event.currentTarget);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
return {
|
||||
// Method to get the current value directly from the textarea
|
||||
value: () => {
|
||||
return textareaRef.current?.value ?? '';
|
||||
},
|
||||
// Method to programmatically set the value and trigger height adjustment
|
||||
setValue: (value: string) => {
|
||||
const textarea = textareaRef.current;
|
||||
if (textarea) {
|
||||
textarea.value = value;
|
||||
// Call adjustTextareaHeight - it will check screen size internally
|
||||
setTimeout(() => adjustTextareaHeight(textarea), 0);
|
||||
}
|
||||
},
|
||||
focus: () => {
|
||||
if (textareaRef.current) {
|
||||
textareaRef.current.focus();
|
||||
}
|
||||
},
|
||||
ref: textareaRef,
|
||||
refOnSubmit: onSubmitRef,
|
||||
onInput: handleInput, // for adjusting height on input
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
import { describe, it, expect } from 'vitest';
|
||||
|
||||
describe('sum test', () => {
|
||||
it('adds 1 + 2 to equal 3', () => {
|
||||
expect(1 + 2).toBe(3);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
@use 'sass:meta';
|
||||
@use 'tailwindcss';
|
||||
|
||||
@plugin 'daisyui' {
|
||||
themes: all;
|
||||
}
|
||||
|
||||
html {
|
||||
scrollbar-gutter: auto;
|
||||
}
|
||||
|
||||
.markdown {
|
||||
h1,
|
||||
h2,
|
||||
h3,
|
||||
h4,
|
||||
h5,
|
||||
h6,
|
||||
ul,
|
||||
ol,
|
||||
li {
|
||||
all: revert;
|
||||
}
|
||||
pre {
|
||||
@apply whitespace-pre-wrap rounded-lg p-2 mb-3;
|
||||
border: 1px solid currentColor;
|
||||
}
|
||||
p {
|
||||
@apply mb-2;
|
||||
}
|
||||
hr {
|
||||
@apply my-4 border-base-content/20 border-1;
|
||||
}
|
||||
table {
|
||||
@apply w-full border-collapse text-sm font-sans my-4 text-base-content;
|
||||
}
|
||||
thead {
|
||||
@apply bg-base-200 text-base-content;
|
||||
}
|
||||
th {
|
||||
@apply border border-base-300 px-4 py-2 text-left font-semibold;
|
||||
}
|
||||
td {
|
||||
@apply border border-base-300 px-4 py-2 align-top;
|
||||
}
|
||||
tbody tr:nth-child(even) {
|
||||
@apply bg-base-100;
|
||||
}
|
||||
tbody tr:hover {
|
||||
@apply bg-base-200;
|
||||
}
|
||||
}
|
||||
|
||||
.btn-mini {
|
||||
@apply cursor-pointer;
|
||||
}
|
||||
.chat-screen {
|
||||
max-width: 900px;
|
||||
}
|
||||
|
||||
.chat-bubble {
|
||||
@apply break-words;
|
||||
}
|
||||
|
||||
.chat-bubble-base-300 {
|
||||
--tw-bg-opacity: 1;
|
||||
--tw-text-opacity: 1;
|
||||
@apply bg-base-300 text-base-content;
|
||||
}
|
||||
|
||||
/* Highlight.js */
|
||||
[data-color-scheme='light'] {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-light');
|
||||
.dark-color {
|
||||
@apply bg-base-content text-base-100;
|
||||
}
|
||||
}
|
||||
[data-color-scheme='dark'] {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
|
||||
}
|
||||
[data-color-scheme='auto'] {
|
||||
@media (prefers-color-scheme: light) {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-light');
|
||||
.dark-color {
|
||||
@apply bg-base-content text-base-100;
|
||||
}
|
||||
}
|
||||
@media (prefers-color-scheme: dark) {
|
||||
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
|
||||
}
|
||||
}
|
||||
.hljs {
|
||||
background: transparent !important;
|
||||
padding: 0.5em !important;
|
||||
}
|
||||
|
||||
.katex-display {
|
||||
margin: 0 0 !important;
|
||||
}
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { formatFileSize, getFileTypeLabel, getPreviewText } from '$lib/utils/file-preview';
|
||||
import { FileTypeCategory, MimeTypeText } from '$lib/enums/files';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
id: string;
|
||||
onClick?: (event?: MouseEvent) => void;
|
||||
onRemove?: (id: string) => void;
|
||||
name: string;
|
||||
readonly?: boolean;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
type: string;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
id,
|
||||
onClick,
|
||||
onRemove,
|
||||
name,
|
||||
readonly = false,
|
||||
size,
|
||||
textContent,
|
||||
type
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
{#if type === MimeTypeText.PLAIN || type === FileTypeCategory.TEXT}
|
||||
{#if readonly}
|
||||
<!-- Readonly mode (ChatMessage) -->
|
||||
<button
|
||||
class="cursor-pointer rounded-lg border border-border bg-muted p-3 transition-shadow hover:shadow-md {className} w-full max-w-2xl"
|
||||
onclick={onClick}
|
||||
aria-label={`Preview ${name}`}
|
||||
type="button"
|
||||
>
|
||||
<div class="flex items-start gap-3">
|
||||
<div class="flex min-w-0 flex-1 flex-col items-start text-left">
|
||||
<span class="w-full truncate text-sm font-medium text-foreground">{name}</span>
|
||||
|
||||
{#if size}
|
||||
<span class="text-xs text-muted-foreground">{formatFileSize(size)}</span>
|
||||
{/if}
|
||||
|
||||
{#if textContent && type === 'text'}
|
||||
<div class="relative mt-2 w-full">
|
||||
<div
|
||||
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
|
||||
>
|
||||
{getPreviewText(textContent)}
|
||||
</div>
|
||||
|
||||
{#if textContent.length > 150}
|
||||
<div
|
||||
class="pointer-events-none absolute right-0 bottom-0 left-0 h-6 bg-gradient-to-t from-muted to-transparent"
|
||||
></div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
{:else}
|
||||
<!-- Non-readonly mode (ChatForm) -->
|
||||
<div class="relative rounded-lg border border-border bg-muted p-3 {className} w-64">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="absolute top-2 right-2 h-6 w-6 bg-white/20 p-0 hover:bg-white/30"
|
||||
onclick={() => onRemove?.(id)}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
|
||||
<div class="pr-8">
|
||||
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
|
||||
|
||||
{#if textContent}
|
||||
<div class="relative">
|
||||
<div
|
||||
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
|
||||
style="max-height: 3.6em; line-height: 1.2em;"
|
||||
>
|
||||
{getPreviewText(textContent)}
|
||||
</div>
|
||||
|
||||
{#if textContent.length > 150}
|
||||
<div
|
||||
class="pointer-events-none absolute right-0 bottom-0 left-0 h-4 bg-gradient-to-t from-muted to-transparent"
|
||||
></div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{:else}
|
||||
<button
|
||||
class="flex items-center gap-2 gap-3 rounded-lg border border-border bg-muted p-3 {className}"
|
||||
onclick={onClick}
|
||||
>
|
||||
<div
|
||||
class="flex h-8 w-8 items-center justify-center rounded bg-primary/10 text-xs font-medium text-primary"
|
||||
>
|
||||
{getFileTypeLabel(type)}
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<span class="max-w-36 truncate text-sm font-medium text-foreground md:max-w-72">
|
||||
{name}
|
||||
</span>
|
||||
|
||||
{#if size}
|
||||
<span class="text-left text-xs text-muted-foreground">{formatFileSize(size)}</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if !readonly}
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 p-0"
|
||||
onclick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
name: string;
|
||||
preview: string;
|
||||
readonly?: boolean;
|
||||
onRemove?: (id: string) => void;
|
||||
onClick?: (event?: MouseEvent) => void;
|
||||
class?: string;
|
||||
// Customizable size props
|
||||
width?: string;
|
||||
height?: string;
|
||||
imageClass?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
id,
|
||||
name,
|
||||
preview,
|
||||
readonly = false,
|
||||
onRemove,
|
||||
onClick,
|
||||
class: className = '',
|
||||
// Default to small size for form previews
|
||||
width = 'w-auto',
|
||||
height = 'h-24',
|
||||
imageClass = ''
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative overflow-hidden rounded-lg border border-border bg-muted {className}">
|
||||
{#if onClick}
|
||||
<button
|
||||
type="button"
|
||||
class="block h-full w-full rounded-lg focus:ring-2 focus:ring-primary focus:ring-offset-2 focus:outline-none"
|
||||
onclick={onClick}
|
||||
aria-label="Preview {name}"
|
||||
>
|
||||
<img
|
||||
src={preview}
|
||||
alt={name}
|
||||
class="{height} {width} cursor-pointer object-cover {imageClass}"
|
||||
/>
|
||||
</button>
|
||||
{:else}
|
||||
<img
|
||||
src={preview}
|
||||
alt={name}
|
||||
class="{height} {width} cursor-pointer object-cover {imageClass}"
|
||||
/>
|
||||
{/if}
|
||||
|
||||
{#if !readonly}
|
||||
<div
|
||||
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity hover:opacity-100"
|
||||
>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 text-white hover:bg-white/30"
|
||||
onclick={() => onRemove?.(id)}
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
|
@ -0,0 +1,305 @@
|
|||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { FileText, Image, Music, FileIcon, Eye } from '@lucide/svelte';
|
||||
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
|
||||
import { convertPDFToImage } from '$lib/utils/pdf-processing';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import { formatFileSize } from '$lib/utils/file-preview';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
// Either an uploaded file or a stored attachment
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
// For uploaded files
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(),
|
||||
uploadedFile,
|
||||
attachment,
|
||||
preview,
|
||||
name,
|
||||
type,
|
||||
size,
|
||||
textContent
|
||||
}: Props = $props();
|
||||
|
||||
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
|
||||
|
||||
let displayPreview = $derived(
|
||||
uploadedFile?.preview || (attachment?.type === 'imageFile' ? attachment.base64Url : preview)
|
||||
);
|
||||
|
||||
let displayType = $derived(
|
||||
uploadedFile?.type ||
|
||||
(attachment?.type === 'imageFile'
|
||||
? 'image'
|
||||
: attachment?.type === 'textFile'
|
||||
? 'text'
|
||||
: attachment?.type === 'audioFile'
|
||||
? attachment.mimeType || 'audio'
|
||||
: attachment?.type === 'pdfFile'
|
||||
? MimeTypeApplication.PDF
|
||||
: type || 'unknown')
|
||||
);
|
||||
|
||||
let displaySize = $derived(uploadedFile?.size || size);
|
||||
|
||||
let displayTextContent = $derived(
|
||||
uploadedFile?.textContent ||
|
||||
(attachment?.type === 'textFile'
|
||||
? attachment.content
|
||||
: attachment?.type === 'pdfFile'
|
||||
? attachment.content
|
||||
: textContent)
|
||||
);
|
||||
|
||||
let isAudio = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.AUDIO || displayType === 'audio'
|
||||
);
|
||||
|
||||
let isImage = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.IMAGE || displayType === 'image'
|
||||
);
|
||||
|
||||
let isPdf = $derived(displayType === MimeTypeApplication.PDF);
|
||||
|
||||
let isText = $derived(
|
||||
getFileTypeCategory(displayType) === FileTypeCategory.TEXT || displayType === 'text'
|
||||
);
|
||||
|
||||
let IconComponent = $derived(() => {
|
||||
if (isImage) return Image;
|
||||
if (isText || isPdf) return FileText;
|
||||
if (isAudio) return Music;
|
||||
|
||||
return FileIcon;
|
||||
});
|
||||
|
||||
let pdfViewMode = $state<'text' | 'pages'>('pages');
|
||||
|
||||
let pdfImages = $state<string[]>([]);
|
||||
|
||||
let pdfImagesLoading = $state(false);
|
||||
|
||||
let pdfImagesError = $state<string | null>(null);
|
||||
|
||||
async function loadPdfImages() {
|
||||
if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;
|
||||
|
||||
pdfImagesLoading = true;
|
||||
pdfImagesError = null;
|
||||
|
||||
try {
|
||||
let file: File | null = null;
|
||||
|
||||
if (uploadedFile?.file) {
|
||||
file = uploadedFile.file;
|
||||
} else if (attachment?.type === 'pdfFile') {
|
||||
// Check if we have pre-processed images
|
||||
if (attachment.images && Array.isArray(attachment.images)) {
|
||||
pdfImages = attachment.images;
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert base64 back to File for processing
|
||||
if (attachment.base64Data) {
|
||||
const base64Data = attachment.base64Data;
|
||||
const byteCharacters = atob(base64Data);
|
||||
const byteNumbers = new Array(byteCharacters.length);
|
||||
for (let i = 0; i < byteCharacters.length; i++) {
|
||||
byteNumbers[i] = byteCharacters.charCodeAt(i);
|
||||
}
|
||||
const byteArray = new Uint8Array(byteNumbers);
|
||||
file = new File([byteArray], displayName, { type: MimeTypeApplication.PDF });
|
||||
}
|
||||
}
|
||||
|
||||
if (file) {
|
||||
pdfImages = await convertPDFToImage(file);
|
||||
} else {
|
||||
throw new Error('No PDF file available for conversion');
|
||||
}
|
||||
} catch (error) {
|
||||
pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
|
||||
} finally {
|
||||
pdfImagesLoading = false;
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (open && isPdf && pdfViewMode === 'pages') {
|
||||
loadPdfImages();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden !p-10 sm:w-auto sm:max-w-6xl">
|
||||
<Dialog.Header class="flex-shrink-0">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center gap-3">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="h-5 w-5 text-muted-foreground" />
|
||||
{/if}
|
||||
|
||||
<div>
|
||||
<Dialog.Title class="text-left">{displayName}</Dialog.Title>
|
||||
|
||||
<div class="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<span>{displayType}</span>
|
||||
|
||||
{#if displaySize}
|
||||
<span>•</span>
|
||||
|
||||
<span>{formatFileSize(displaySize)}</span>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isPdf}
|
||||
<div class="flex items-center gap-2">
|
||||
<Button
|
||||
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => (pdfViewMode = 'text')}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
<FileText class="mr-1 h-4 w-4" />
|
||||
|
||||
Text
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
|
||||
size="sm"
|
||||
onclick={() => {
|
||||
pdfViewMode = 'pages';
|
||||
loadPdfImages();
|
||||
}}
|
||||
disabled={pdfImagesLoading}
|
||||
>
|
||||
{#if pdfImagesLoading}
|
||||
<div
|
||||
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
|
||||
></div>
|
||||
{:else}
|
||||
<Eye class="mr-1 h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
Pages
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="flex-1 overflow-auto">
|
||||
{#if isImage && displayPreview}
|
||||
<div class="flex items-center justify-center">
|
||||
<img
|
||||
src={displayPreview}
|
||||
alt={displayName}
|
||||
class="max-h-full rounded-lg object-contain shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{:else if isPdf && pdfViewMode === 'pages'}
|
||||
{#if pdfImagesLoading}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<div
|
||||
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
|
||||
></div>
|
||||
|
||||
<p class="text-muted-foreground">Converting PDF to images...</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImagesError}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
|
||||
|
||||
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
|
||||
|
||||
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if pdfImages.length > 0}
|
||||
<div class="max-h-[70vh] space-y-4 overflow-auto">
|
||||
{#each pdfImages as image, index (image)}
|
||||
<div class="text-center">
|
||||
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
|
||||
|
||||
<img
|
||||
src={image}
|
||||
alt="PDF Page {index + 1}"
|
||||
class="mx-auto max-w-full rounded-lg shadow-lg"
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
|
||||
<div
|
||||
class="max-h-[60vh] overflow-auto rounded-lg bg-muted p-4 font-mono text-sm break-words whitespace-pre-wrap"
|
||||
>
|
||||
{displayTextContent}
|
||||
</div>
|
||||
{:else if isAudio}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="w-full max-w-md text-center">
|
||||
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
|
||||
{#if attachment?.type === 'audioFile'}
|
||||
<audio
|
||||
controls
|
||||
class="mb-4 w-full"
|
||||
src="data:{attachment.mimeType};base64,{attachment.base64Data}"
|
||||
>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else if uploadedFile?.preview}
|
||||
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
{:else}
|
||||
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
|
||||
{/if}
|
||||
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{displayName}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex items-center justify-center p-8">
|
||||
<div class="text-center">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||
{/if}
|
||||
|
||||
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
|
|
@ -0,0 +1,185 @@
|
|||
<script lang="ts">
|
||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
// For ChatMessage - stored attachments
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
// For ChatForm - pending uploads
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
// Image size customization
|
||||
imageClass?: string;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
uploadedFiles = $bindable([]),
|
||||
// Default to small size for form previews
|
||||
imageClass = '',
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto'
|
||||
}: Props = $props();
|
||||
|
||||
let displayItems = $derived(getDisplayItems());
|
||||
|
||||
// Preview dialog state
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewItem = $state<{
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
} | null>(null);
|
||||
|
||||
function getDisplayItems() {
|
||||
const items: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
size?: number;
|
||||
preview?: string;
|
||||
type: string;
|
||||
isImage: boolean;
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
attachmentIndex?: number;
|
||||
textContent?: string;
|
||||
}> = [];
|
||||
|
||||
// Add uploaded files (ChatForm)
|
||||
for (const file of uploadedFiles) {
|
||||
items.push({
|
||||
id: file.id,
|
||||
name: file.name,
|
||||
size: file.size,
|
||||
preview: file.preview,
|
||||
type: file.type,
|
||||
isImage: getFileTypeCategory(file.type) === FileTypeCategory.IMAGE,
|
||||
uploadedFile: file,
|
||||
textContent: file.textContent
|
||||
});
|
||||
}
|
||||
|
||||
// Add stored attachments (ChatMessage)
|
||||
for (const [index, attachment] of attachments.entries()) {
|
||||
if (attachment.type === 'imageFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
preview: attachment.base64Url,
|
||||
type: 'image',
|
||||
isImage: true,
|
||||
attachment,
|
||||
attachmentIndex: index
|
||||
});
|
||||
} else if (attachment.type === 'textFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'text',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
} else if (attachment.type === 'audioFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: attachment.mimeType || 'audio',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index
|
||||
});
|
||||
} else if (attachment.type === 'pdfFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'application/pdf',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return items;
|
||||
}
|
||||
|
||||
function openPreview(item: (typeof displayItems)[0], event?: Event) {
|
||||
if (event) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
}
|
||||
|
||||
previewItem = {
|
||||
uploadedFile: item.uploadedFile,
|
||||
attachment: item.attachment,
|
||||
preview: item.preview,
|
||||
name: item.name,
|
||||
type: item.type,
|
||||
size: item.size,
|
||||
textContent: item.textContent
|
||||
};
|
||||
previewDialogOpen = true;
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if displayItems.length > 0}
|
||||
<div class="flex flex-wrap items-start {readonly ? 'justify-end' : ''} gap-3 {className}">
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isImage && item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentFilePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if previewItem}
|
||||
<ChatAttachmentPreviewDialog
|
||||
bind:open={previewDialogOpen}
|
||||
uploadedFile={previewItem.uploadedFile}
|
||||
attachment={previewItem.attachment}
|
||||
preview={previewItem.preview}
|
||||
name={previewItem.name}
|
||||
type={previewItem.type}
|
||||
size={previewItem.size}
|
||||
textContent={previewItem.textContent}
|
||||
/>
|
||||
{/if}
|
||||
|
|
@ -0,0 +1,259 @@
|
|||
<script lang="ts">
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import {
|
||||
ChatAttachmentsList,
|
||||
ChatFormActions,
|
||||
ChatFormFileInputInvisible,
|
||||
ChatFormHelperText,
|
||||
ChatFormTextarea
|
||||
} from '$lib/components/app';
|
||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
|
||||
import {
|
||||
AudioRecorder,
|
||||
convertToWav,
|
||||
createAudioFile,
|
||||
isAudioRecordingSupported
|
||||
} from '$lib/utils/audio-recording';
|
||||
import { onMount } from 'svelte';
|
||||
import {
|
||||
FileExtensionAudio,
|
||||
FileExtensionImage,
|
||||
FileExtensionPdf,
|
||||
FileExtensionText,
|
||||
MimeTypeAudio,
|
||||
MimeTypeImage,
|
||||
MimeTypeText
|
||||
} from '$lib/enums/files';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
isLoading?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
onFileUpload?: (files: File[]) => void;
|
||||
onSend?: (message: string, files?: ChatUploadedFile[]) => Promise<boolean>;
|
||||
onStop?: () => void;
|
||||
showHelperText?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
}
|
||||
|
||||
let {
|
||||
class: className,
|
||||
disabled = false,
|
||||
isLoading = false,
|
||||
onFileRemove,
|
||||
onFileUpload,
|
||||
onSend,
|
||||
onStop,
|
||||
showHelperText = true,
|
||||
uploadedFiles = $bindable([])
|
||||
}: Props = $props();
|
||||
|
||||
let audioRecorder: AudioRecorder | undefined;
|
||||
let currentConfig = $derived(config());
|
||||
let fileAcceptString = $state<string | undefined>(undefined);
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let isRecording = $state(false);
|
||||
let message = $state('');
|
||||
let pasteLongTextToFileLength = $derived(Number(currentConfig.pasteLongTextToFileLen) || 2500);
|
||||
let previousIsLoading = $state(isLoading);
|
||||
let recordingSupported = $state(false);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
function getAcceptStringForFileType(fileType: FileTypeCategory): string {
|
||||
switch (fileType) {
|
||||
case FileTypeCategory.IMAGE:
|
||||
return [...Object.values(FileExtensionImage), ...Object.values(MimeTypeImage)].join(',');
|
||||
case FileTypeCategory.AUDIO:
|
||||
return [...Object.values(FileExtensionAudio), ...Object.values(MimeTypeAudio)].join(',');
|
||||
case FileTypeCategory.PDF:
|
||||
return [...Object.values(FileExtensionPdf), ...Object.values(MimeTypeApplication)].join(
|
||||
','
|
||||
);
|
||||
case FileTypeCategory.TEXT:
|
||||
return [...Object.values(FileExtensionText), MimeTypeText.PLAIN].join(',');
|
||||
default:
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
function handleFileSelect(files: File[]) {
|
||||
onFileUpload?.(files);
|
||||
}
|
||||
|
||||
function handleFileUpload(fileType?: FileTypeCategory) {
|
||||
if (fileType) {
|
||||
fileAcceptString = getAcceptStringForFileType(fileType);
|
||||
} else {
|
||||
fileAcceptString = undefined;
|
||||
}
|
||||
|
||||
// Use setTimeout to ensure the accept attribute is applied before opening dialog
|
||||
setTimeout(() => {
|
||||
fileInputRef?.click();
|
||||
}, 10);
|
||||
}
|
||||
|
||||
async function handleKeydown(event: KeyboardEvent) {
|
||||
if (event.key === 'Enter' && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return;
|
||||
|
||||
const messageToSend = message.trim();
|
||||
const filesToSend = [...uploadedFiles];
|
||||
|
||||
message = '';
|
||||
uploadedFiles = [];
|
||||
|
||||
textareaRef?.resetHeight();
|
||||
|
||||
const success = await onSend?.(messageToSend, filesToSend);
|
||||
|
||||
if (!success) {
|
||||
message = messageToSend;
|
||||
uploadedFiles = filesToSend;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handlePaste(event: ClipboardEvent) {
|
||||
if (!event.clipboardData) return;
|
||||
|
||||
const files = Array.from(event.clipboardData.items)
|
||||
.filter((item) => item.kind === 'file')
|
||||
.map((item) => item.getAsFile())
|
||||
.filter((file): file is File => file !== null);
|
||||
|
||||
if (files.length > 0) {
|
||||
event.preventDefault();
|
||||
onFileUpload?.(files);
|
||||
return;
|
||||
}
|
||||
|
||||
const text = event.clipboardData.getData(MimeTypeText.PLAIN);
|
||||
|
||||
if (
|
||||
text.length > 0 &&
|
||||
pasteLongTextToFileLength > 0 &&
|
||||
text.length > pasteLongTextToFileLength
|
||||
) {
|
||||
event.preventDefault();
|
||||
|
||||
const textFile = new File([text], 'Pasted', {
|
||||
type: MimeTypeText.PLAIN
|
||||
});
|
||||
|
||||
onFileUpload?.([textFile]);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleMicClick() {
|
||||
if (!audioRecorder || !recordingSupported) {
|
||||
console.warn('Audio recording not supported');
|
||||
return;
|
||||
}
|
||||
|
||||
if (isRecording) {
|
||||
try {
|
||||
const audioBlob = await audioRecorder.stopRecording();
|
||||
const wavBlob = await convertToWav(audioBlob);
|
||||
const audioFile = createAudioFile(wavBlob);
|
||||
|
||||
onFileUpload?.([audioFile]);
|
||||
isRecording = false;
|
||||
} catch (error) {
|
||||
console.error('Failed to stop recording:', error);
|
||||
isRecording = false;
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
await audioRecorder.startRecording();
|
||||
isRecording = true;
|
||||
} catch (error) {
|
||||
console.error('Failed to start recording:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStop() {
|
||||
onStop?.();
|
||||
}
|
||||
|
||||
async function handleSubmit(event: SubmitEvent) {
|
||||
event.preventDefault();
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return;
|
||||
|
||||
const messageToSend = message.trim();
|
||||
const filesToSend = [...uploadedFiles];
|
||||
|
||||
message = '';
|
||||
uploadedFiles = [];
|
||||
|
||||
textareaRef?.resetHeight();
|
||||
|
||||
const success = await onSend?.(messageToSend, filesToSend);
|
||||
|
||||
if (!success) {
|
||||
message = messageToSend;
|
||||
uploadedFiles = filesToSend;
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
recordingSupported = isAudioRecordingSupported();
|
||||
audioRecorder = new AudioRecorder();
|
||||
});
|
||||
|
||||
afterNavigate(() => {
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (previousIsLoading && !isLoading) {
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
}
|
||||
|
||||
previousIsLoading = isLoading;
|
||||
});
|
||||
</script>
|
||||
|
||||
<ChatFormFileInputInvisible
|
||||
bind:this={fileInputRef}
|
||||
bind:accept={fileAcceptString}
|
||||
onFileSelect={handleFileSelect}
|
||||
/>
|
||||
|
||||
<form
|
||||
onsubmit={handleSubmit}
|
||||
class="{INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {className}"
|
||||
>
|
||||
<ChatAttachmentsList bind:uploadedFiles {onFileRemove} class="mb-3 px-5 pt-5" />
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl px-5 py-3 shadow-sm transition-all focus-within:shadow-md"
|
||||
onpaste={handlePaste}
|
||||
>
|
||||
<ChatFormTextarea
|
||||
bind:this={textareaRef}
|
||||
bind:value={message}
|
||||
onKeydown={handleKeydown}
|
||||
{disabled}
|
||||
/>
|
||||
|
||||
<ChatFormActions
|
||||
canSend={message.trim().length > 0 || uploadedFiles.length > 0}
|
||||
{disabled}
|
||||
{isLoading}
|
||||
{isRecording}
|
||||
onFileUpload={handleFileUpload}
|
||||
onMicClick={handleMicClick}
|
||||
onStop={handleStop}
|
||||
/>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<ChatFormHelperText show={showHelperText} />
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
<script lang="ts">
|
||||
import { Paperclip, Image, FileText, File, Volume2 } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { TOOLTIP_DELAY_DURATION } from '$lib/constants/tooltip-config';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { supportsAudio, supportsVision } from '$lib/stores/server.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
onFileUpload?: (fileType?: FileTypeCategory) => void;
|
||||
}
|
||||
|
||||
let { class: className = '', disabled = false, onFileUpload }: Props = $props();
|
||||
|
||||
const fileUploadTooltipText = $derived.by(() => {
|
||||
return !supportsVision()
|
||||
? 'Text files and PDFs supported. Images, audio, and video require vision models.'
|
||||
: 'Attach files';
|
||||
});
|
||||
|
||||
function handleFileUpload(fileType?: FileTypeCategory) {
|
||||
onFileUpload?.(fileType);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<DropdownMenu.Root>
|
||||
<DropdownMenu.Trigger name="Attach files">
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger>
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full bg-transparent p-0 text-muted-foreground hover:bg-foreground/10 hover:text-foreground"
|
||||
{disabled}
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">Attach files</span>
|
||||
|
||||
<Paperclip class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{fileUploadTooltipText}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content align="start" class="w-48">
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="images-button flex cursor-pointer items-center gap-2"
|
||||
disabled={!supportsVision()}
|
||||
onclick={() => handleFileUpload(FileTypeCategory.IMAGE)}
|
||||
>
|
||||
<Image class="h-4 w-4" />
|
||||
|
||||
<span>Images</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
{#if !supportsVision()}
|
||||
<Tooltip.Content>
|
||||
<p>Images require vision models to be processed</p>
|
||||
</Tooltip.Content>
|
||||
{/if}
|
||||
</Tooltip.Root>
|
||||
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="audio-button flex cursor-pointer items-center gap-2"
|
||||
disabled={!supportsAudio()}
|
||||
onclick={() => handleFileUpload(FileTypeCategory.AUDIO)}
|
||||
>
|
||||
<Volume2 class="h-4 w-4" />
|
||||
|
||||
<span>Audio Files</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
{#if !supportsAudio()}
|
||||
<Tooltip.Content>
|
||||
<p>Audio files require audio models to be processed</p>
|
||||
</Tooltip.Content>
|
||||
{/if}
|
||||
</Tooltip.Root>
|
||||
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => handleFileUpload(FileTypeCategory.TEXT)}
|
||||
>
|
||||
<FileText class="h-4 w-4" />
|
||||
|
||||
<span>Text Files</span>
|
||||
</DropdownMenu.Item>
|
||||
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => handleFileUpload(FileTypeCategory.PDF)}
|
||||
>
|
||||
<File class="h-4 w-4" />
|
||||
|
||||
<span>PDF Files</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
{#if !supportsVision()}
|
||||
<Tooltip.Content>
|
||||
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
|
||||
</Tooltip.Content>
|
||||
{/if}
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
<script lang="ts">
|
||||
import { Mic } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { supportsAudio } from '$lib/stores/server.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
isLoading?: boolean;
|
||||
isRecording?: boolean;
|
||||
onMicClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
isLoading = false,
|
||||
isRecording = false,
|
||||
onMicClick
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<Tooltip.Root delayDuration={100}>
|
||||
<Tooltip.Trigger>
|
||||
<Button
|
||||
class="h-8 w-8 rounded-full p-0 {isRecording
|
||||
? 'animate-pulse bg-red-500 text-white hover:bg-red-600'
|
||||
: 'bg-transparent text-muted-foreground hover:bg-foreground/10 hover:text-foreground'} {!supportsAudio()
|
||||
? 'cursor-not-allowed opacity-50'
|
||||
: ''}"
|
||||
disabled={disabled || isLoading || !supportsAudio()}
|
||||
onclick={onMicClick}
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">{isRecording ? 'Stop recording' : 'Start recording'}</span>
|
||||
|
||||
<Mic class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
{#if !supportsAudio()}
|
||||
<Tooltip.Content>
|
||||
<p>Current model does not support audio</p>
|
||||
</Tooltip.Content>
|
||||
{/if}
|
||||
</Tooltip.Root>
|
||||
</div>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue