Merge remote-tracking branch 'origin/master' into kimi-k2.5
This commit is contained in:
commit
16010cba64
|
|
@ -295,6 +295,7 @@ jobs:
|
||||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||||
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
|
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||||
|
|
||||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
||||||
|
|
||||||
- name: Build (no OpenMP)
|
- name: Build (no OpenMP)
|
||||||
|
|
@ -307,6 +308,7 @@ jobs:
|
||||||
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
|
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||||
-DGGML_OPENMP=OFF
|
-DGGML_OPENMP=OFF
|
||||||
|
|
||||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,6 @@ on:
|
||||||
description: 'Commit SHA1 to build'
|
description: 'Commit SHA1 to build'
|
||||||
required: false
|
required: false
|
||||||
type: string
|
type: string
|
||||||
slow_tests:
|
|
||||||
description: 'Run slow tests'
|
|
||||||
required: true
|
|
||||||
type: boolean
|
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
|
|
@ -101,119 +97,3 @@ jobs:
|
||||||
if: ${{ always() && steps.playwright.conclusion == 'success' }}
|
if: ${{ always() && steps.playwright.conclusion == 'success' }}
|
||||||
run: npm run test:e2e
|
run: npm run test:e2e
|
||||||
working-directory: tools/server/webui
|
working-directory: tools/server/webui
|
||||||
|
|
||||||
server-build:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
|
|
||||||
build_type: [RelWithDebInfo]
|
|
||||||
include:
|
|
||||||
- build_type: Release
|
|
||||||
sanitizer: ""
|
|
||||||
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
run: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get -y install \
|
|
||||||
build-essential \
|
|
||||||
xxd \
|
|
||||||
git \
|
|
||||||
cmake \
|
|
||||||
curl \
|
|
||||||
wget \
|
|
||||||
language-pack-en \
|
|
||||||
libssl-dev
|
|
||||||
|
|
||||||
- name: Clone
|
|
||||||
id: checkout
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
|
||||||
|
|
||||||
- name: Python setup
|
|
||||||
id: setup_python
|
|
||||||
uses: actions/setup-python@v6
|
|
||||||
with:
|
|
||||||
python-version: '3.11'
|
|
||||||
|
|
||||||
- name: Tests dependencies
|
|
||||||
id: test_dependencies
|
|
||||||
run: |
|
|
||||||
pip install -r tools/server/tests/requirements.txt
|
|
||||||
|
|
||||||
- name: Setup Node.js for WebUI
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22"
|
|
||||||
cache: "npm"
|
|
||||||
cache-dependency-path: "tools/server/webui/package-lock.json"
|
|
||||||
|
|
||||||
- name: Install WebUI dependencies
|
|
||||||
run: npm ci
|
|
||||||
working-directory: tools/server/webui
|
|
||||||
|
|
||||||
- name: Build WebUI
|
|
||||||
run: npm run build
|
|
||||||
working-directory: tools/server/webui
|
|
||||||
|
|
||||||
- name: Build (no OpenMP)
|
|
||||||
id: cmake_build_no_openmp
|
|
||||||
if: ${{ matrix.sanitizer == 'THREAD' }}
|
|
||||||
run: |
|
|
||||||
cmake -B build \
|
|
||||||
-DGGML_NATIVE=OFF \
|
|
||||||
-DLLAMA_BUILD_SERVER=ON \
|
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
|
||||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
|
||||||
-DGGML_OPENMP=OFF ;
|
|
||||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
|
||||||
|
|
||||||
- name: Build (sanitizers)
|
|
||||||
id: cmake_build_sanitizers
|
|
||||||
if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
|
|
||||||
run: |
|
|
||||||
cmake -B build \
|
|
||||||
-DGGML_NATIVE=OFF \
|
|
||||||
-DLLAMA_BUILD_SERVER=ON \
|
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
|
||||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
|
|
||||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
|
||||||
|
|
||||||
- name: Build (sanitizers)
|
|
||||||
id: cmake_build
|
|
||||||
if: ${{ matrix.sanitizer == '' }}
|
|
||||||
run: |
|
|
||||||
cmake -B build \
|
|
||||||
-DGGML_NATIVE=OFF \
|
|
||||||
-DLLAMA_BUILD_SERVER=ON \
|
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
|
|
||||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
|
||||||
|
|
||||||
- name: Tests
|
|
||||||
id: server_integration_tests
|
|
||||||
if: ${{ matrix.sanitizer == '' }}
|
|
||||||
env:
|
|
||||||
GITHUB_ACTIONS: "true"
|
|
||||||
run: |
|
|
||||||
cd tools/server/tests
|
|
||||||
./tests.sh
|
|
||||||
|
|
||||||
- name: Tests (sanitizers)
|
|
||||||
id: server_integration_tests_sanitizers
|
|
||||||
if: ${{ matrix.sanitizer != '' }}
|
|
||||||
run: |
|
|
||||||
cd tools/server/tests
|
|
||||||
LLAMA_SANITIZE=1 ./tests.sh
|
|
||||||
|
|
||||||
- name: Slow tests
|
|
||||||
id: server_integration_tests_slow
|
|
||||||
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
|
|
||||||
run: |
|
|
||||||
cd tools/server/tests
|
|
||||||
SLOW_TESTS=1 ./tests.sh
|
|
||||||
|
|
|
||||||
|
|
@ -81,18 +81,14 @@ jobs:
|
||||||
-DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
|
-DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
|
||||||
-DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
|
-DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
|
||||||
-DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }}
|
-DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }}
|
||||||
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||||
|
|
||||||
- name: Python setup
|
- name: Python setup
|
||||||
id: setup_python
|
id: setup_python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
|
pip-install: -r tools/server/tests/requirements.txt
|
||||||
- name: Tests dependencies
|
|
||||||
id: test_dependencies
|
|
||||||
run: |
|
|
||||||
pip install -r tools/server/tests/requirements.txt
|
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
id: server_integration_tests
|
id: server_integration_tests
|
||||||
|
|
@ -102,6 +98,14 @@ jobs:
|
||||||
export ${{ matrix.extra_args }}
|
export ${{ matrix.extra_args }}
|
||||||
pytest -v -x -m "not slow"
|
pytest -v -x -m "not slow"
|
||||||
|
|
||||||
|
- name: Slow tests
|
||||||
|
id: server_integration_tests_slow
|
||||||
|
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
|
||||||
|
run: |
|
||||||
|
cd tools/server/tests
|
||||||
|
export ${{ matrix.extra_args }}
|
||||||
|
SLOW_TESTS=1 pytest -v -x
|
||||||
|
|
||||||
server-windows:
|
server-windows:
|
||||||
runs-on: windows-2022
|
runs-on: windows-2022
|
||||||
|
|
||||||
|
|
@ -124,11 +128,7 @@ jobs:
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
|
pip-install: -r tools/server/tests/requirements.txt
|
||||||
- name: Tests dependencies
|
|
||||||
id: test_dependencies
|
|
||||||
run: |
|
|
||||||
pip install -r tools/server/tests/requirements.txt
|
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
id: server_integration_tests
|
id: server_integration_tests
|
||||||
|
|
|
||||||
|
|
@ -43,10 +43,15 @@ static __device__ void rope_yarn(
|
||||||
template <bool forward, bool has_ff, typename T, typename D>
|
template <bool forward, bool has_ff, typename T, typename D>
|
||||||
static __global__ void rope_norm(const T * x,
|
static __global__ void rope_norm(const T * x,
|
||||||
D * dst,
|
D * dst,
|
||||||
const int ne0,
|
const int ne00,
|
||||||
const int ne1,
|
const int ne01,
|
||||||
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
const int s1,
|
const int s1,
|
||||||
const int s2,
|
const int s2,
|
||||||
|
const int s3,
|
||||||
const int n_dims,
|
const int n_dims,
|
||||||
const int32_t * pos,
|
const int32_t * pos,
|
||||||
const float freq_scale,
|
const float freq_scale,
|
||||||
|
|
@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x,
|
||||||
const int set_rows_stride) {
|
const int set_rows_stride) {
|
||||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||||
|
|
||||||
if (i0 >= ne0) {
|
if (i0 >= ne00) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
const int row_x = row_dst % ne1;
|
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||||
const int channel_x = row_dst / ne1;
|
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||||
|
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||||
int idst = row_dst * ne0 + i0;
|
|
||||||
const int ix = channel_x*s2 + row_x*s1 + i0;
|
|
||||||
|
|
||||||
|
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||||
|
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||||
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
||||||
if (set_rows_stride != 0) {
|
if (set_rows_stride != 0) {
|
||||||
idst = row_x * ne0 + i0;
|
idst = i1 * s1 + i0;
|
||||||
idst += row_indices[channel_x] * set_rows_stride;
|
idst += row_indices[i2] * set_rows_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & store_coaelsced = [&](float x0, float x1) {
|
const auto & store_coaelsced = [&](float x0, float x1) {
|
||||||
|
|
@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
||||||
|
|
||||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||||
|
|
||||||
|
|
@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x,
|
||||||
template <bool forward, bool has_ff, typename T, typename D>
|
template <bool forward, bool has_ff, typename T, typename D>
|
||||||
static __global__ void rope_neox(const T * x,
|
static __global__ void rope_neox(const T * x,
|
||||||
D * dst,
|
D * dst,
|
||||||
const int ne0,
|
const int ne00,
|
||||||
const int ne1,
|
const int ne01,
|
||||||
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
const int s1,
|
const int s1,
|
||||||
const int s2,
|
const int s2,
|
||||||
|
const int s3,
|
||||||
const int n_dims,
|
const int n_dims,
|
||||||
const int32_t * pos,
|
const int32_t * pos,
|
||||||
const float freq_scale,
|
const float freq_scale,
|
||||||
|
|
@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x,
|
||||||
const int set_rows_stride) {
|
const int set_rows_stride) {
|
||||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||||
|
|
||||||
if (i0 >= ne0) {
|
if (i0 >= ne00) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
const int row_x = row_dst % ne1;
|
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||||
const int channel_x = row_dst / ne1;
|
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||||
|
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||||
|
|
||||||
int idst = row_dst * ne0 + i0 / 2;
|
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||||
|
|
||||||
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||||
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
|
||||||
if (set_rows_stride != 0) {
|
if (set_rows_stride != 0) {
|
||||||
idst = row_x * ne0 + i0 / 2;
|
idst = i1 * s1 + i0 / 2;
|
||||||
idst += row_indices[channel_x] * set_rows_stride;
|
idst += row_indices[i2] * set_rows_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i0 >= n_dims) {
|
if (i0 >= n_dims) {
|
||||||
|
|
@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
||||||
|
|
||||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||||
|
|
||||||
|
|
@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x,
|
||||||
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<bool forward, bool has_ff, typename T>
|
template <bool forward, bool has_ff, typename T>
|
||||||
static __global__ void rope_multi(
|
static __global__ void rope_multi(const T * x,
|
||||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
T * dst,
|
||||||
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
const int ne00,
|
||||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
|
const int ne01,
|
||||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
|
const int s1,
|
||||||
|
const int s2,
|
||||||
|
const int s3,
|
||||||
|
const int n_dims,
|
||||||
|
const int32_t * pos,
|
||||||
|
const float freq_scale,
|
||||||
|
const float ext_factor,
|
||||||
|
const float attn_factor,
|
||||||
|
const rope_corr_dims corr_dims,
|
||||||
|
const float theta_scale,
|
||||||
|
const float * freq_factors,
|
||||||
|
const mrope_sections sections,
|
||||||
|
const bool is_imrope) {
|
||||||
|
const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
|
||||||
|
|
||||||
if (i0 >= ne0) {
|
if (i0 >= ne00) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
const int row_x = row_dst % ne1;
|
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||||
const int channel_x = row_dst / ne1;
|
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||||
|
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||||
|
|
||||||
const int idst = row_dst*ne0 + i0/2;
|
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||||
|
|
||||||
if (i0 >= n_dims) {
|
if (i0 >= n_dims) {
|
||||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||||
|
|
@ -200,27 +229,24 @@ static __global__ void rope_multi(
|
||||||
|
|
||||||
float theta_base = 0.0;
|
float theta_base = 0.0;
|
||||||
if (is_imrope) {
|
if (is_imrope) {
|
||||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
||||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
|
||||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
||||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
|
||||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
||||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
|
||||||
} else {
|
} else {
|
||||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (sector < sections.v[0]) {
|
if (sector < sections.v[0]) {
|
||||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
|
||||||
}
|
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
|
||||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||||
}
|
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
|
||||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
} else if (sector >= sec_w + sections.v[2]) {
|
||||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
|
||||||
}
|
|
||||||
else if (sector >= sec_w + sections.v[2]) {
|
|
||||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -238,37 +264,53 @@ static __global__ void rope_multi(
|
||||||
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<bool forward, bool has_ff, typename T>
|
template <bool forward, bool has_ff, typename T>
|
||||||
static __global__ void rope_vision(
|
static __global__ void rope_vision(const T * x,
|
||||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
|
T * dst,
|
||||||
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
const int ne00,
|
||||||
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
|
const int ne01,
|
||||||
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
|
const int s1,
|
||||||
|
const int s2,
|
||||||
|
const int s3,
|
||||||
|
const int n_dims,
|
||||||
|
const int32_t * pos,
|
||||||
|
const float freq_scale,
|
||||||
|
const float ext_factor,
|
||||||
|
const float attn_factor,
|
||||||
|
const rope_corr_dims corr_dims,
|
||||||
|
const float theta_scale,
|
||||||
|
const float * freq_factors,
|
||||||
|
const mrope_sections sections) {
|
||||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||||
|
|
||||||
if (i0 >= ne0) {
|
if (i0 >= ne00) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
const int row_x = row_dst % ne1;
|
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||||
const int channel_x = row_dst / ne1;
|
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||||
|
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||||
|
|
||||||
const int idst = row_dst*ne0 + i0/2;
|
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||||
|
|
||||||
const int sect_dims = sections.v[0] + sections.v[1];
|
const int sect_dims = sections.v[0] + sections.v[1];
|
||||||
const int sec_w = sections.v[1] + sections.v[0];
|
const int sec_w = sections.v[1] + sections.v[0];
|
||||||
const int sector = (i0 / 2) % sect_dims;
|
const int sector = (i0 / 2) % sect_dims;
|
||||||
|
|
||||||
float theta_base = 0.0;
|
float theta_base = 0.0;
|
||||||
if (sector < sections.v[0]) {
|
if (sector < sections.v[0]) {
|
||||||
const int p = sector;
|
const int p = sector;
|
||||||
theta_base = pos[channel_x]*powf(theta_scale, p);
|
theta_base = pos[i2] * powf(theta_scale, p);
|
||||||
}
|
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
|
||||||
const int p = sector - sections.v[0];
|
const int p = sector - sections.v[0];
|
||||||
theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
|
theta_base = pos[i2 + ne02] * powf(theta_scale, p);
|
||||||
}
|
}
|
||||||
|
|
||||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||||
|
|
@ -288,10 +330,15 @@ static __global__ void rope_vision(
|
||||||
template <bool forward, typename T, typename D>
|
template <bool forward, typename T, typename D>
|
||||||
static void rope_norm_cuda(const T * x,
|
static void rope_norm_cuda(const T * x,
|
||||||
D * dst,
|
D * dst,
|
||||||
const int ne0,
|
const int ne00,
|
||||||
const int ne1,
|
const int ne01,
|
||||||
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
const int s1,
|
const int s1,
|
||||||
const int s2,
|
const int s2,
|
||||||
|
const int s3,
|
||||||
const int n_dims,
|
const int n_dims,
|
||||||
const int nr,
|
const int nr,
|
||||||
const int32_t * pos,
|
const int32_t * pos,
|
||||||
|
|
@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x,
|
||||||
const int64_t * row_indices,
|
const int64_t * row_indices,
|
||||||
const int set_rows_stride,
|
const int set_rows_stride,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
GGML_ASSERT(ne00 % 2 == 0);
|
||||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||||
|
|
||||||
if (freq_factors == nullptr) {
|
if (freq_factors == nullptr) {
|
||||||
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
freq_factors, row_indices, set_rows_stride);
|
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||||
} else {
|
} else {
|
||||||
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
freq_factors, row_indices, set_rows_stride);
|
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool forward, typename T, typename D>
|
template <bool forward, typename T, typename D>
|
||||||
static void rope_neox_cuda(const T * x,
|
static void rope_neox_cuda(const T * x,
|
||||||
D * dst,
|
D * dst,
|
||||||
const int ne0,
|
const int ne00,
|
||||||
const int ne1,
|
const int ne01,
|
||||||
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
const int s1,
|
const int s1,
|
||||||
const int s2,
|
const int s2,
|
||||||
|
const int s3,
|
||||||
const int n_dims,
|
const int n_dims,
|
||||||
const int nr,
|
const int nr,
|
||||||
const int32_t * pos,
|
const int32_t * pos,
|
||||||
|
|
@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x,
|
||||||
const int64_t * row_indices,
|
const int64_t * row_indices,
|
||||||
const int set_rows_stride,
|
const int set_rows_stride,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
GGML_ASSERT(ne00 % 2 == 0);
|
||||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||||
|
|
||||||
if (freq_factors == nullptr) {
|
if (freq_factors == nullptr) {
|
||||||
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
freq_factors, row_indices, set_rows_stride);
|
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||||
} else {
|
} else {
|
||||||
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
freq_factors, row_indices, set_rows_stride);
|
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<bool forward, typename T>
|
template <bool forward, typename T>
|
||||||
static void rope_multi_cuda(
|
static void rope_multi_cuda(const T * x,
|
||||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
T * dst,
|
||||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
const int ne00,
|
||||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
|
const int ne01,
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
|
const int s1,
|
||||||
|
const int s2,
|
||||||
|
const int s3,
|
||||||
|
const int n_dims,
|
||||||
|
const int nr,
|
||||||
|
const int32_t * pos,
|
||||||
|
const float freq_scale,
|
||||||
|
const float freq_base,
|
||||||
|
const float ext_factor,
|
||||||
|
const float attn_factor,
|
||||||
|
const rope_corr_dims corr_dims,
|
||||||
|
const float * freq_factors,
|
||||||
|
const mrope_sections sections,
|
||||||
|
const bool is_imrope,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ne00 % 2 == 0);
|
||||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||||
|
|
||||||
if (freq_factors == nullptr) {
|
if (freq_factors == nullptr) {
|
||||||
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||||
} else {
|
} else {
|
||||||
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<bool forward, typename T>
|
template <bool forward, typename T>
|
||||||
static void rope_vision_cuda(
|
static void rope_vision_cuda(const T * x,
|
||||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
T * dst,
|
||||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
const int ne00,
|
||||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
const int ne01,
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
const int ne02,
|
||||||
|
const int s01,
|
||||||
|
const int s02,
|
||||||
|
const int s03,
|
||||||
|
const int s1,
|
||||||
|
const int s2,
|
||||||
|
const int s3,
|
||||||
|
const int n_dims,
|
||||||
|
const int nr,
|
||||||
|
const int32_t * pos,
|
||||||
|
const float freq_scale,
|
||||||
|
const float freq_base,
|
||||||
|
const float ext_factor,
|
||||||
|
const float attn_factor,
|
||||||
|
const rope_corr_dims corr_dims,
|
||||||
|
const float * freq_factors,
|
||||||
|
const mrope_sections sections,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ne00 % 2 == 0);
|
||||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
|
||||||
const dim3 block_nums(nr, n_blocks_x, 1);
|
const dim3 block_nums(nr, n_blocks_x, 1);
|
||||||
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
|
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
|
||||||
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
|
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
|
||||||
|
|
@ -398,11 +487,11 @@ static void rope_vision_cuda(
|
||||||
|
|
||||||
if (freq_factors == nullptr) {
|
if (freq_factors == nullptr) {
|
||||||
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||||
} else {
|
} else {
|
||||||
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
|
||||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||||
|
|
||||||
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
||||||
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
||||||
|
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
|
||||||
|
|
||||||
|
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
|
||||||
|
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
|
||||||
|
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
|
|
@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
|
||||||
// compute
|
// compute
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||||
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||||
freq_factors, row_indices, set_rows_stride, stream);
|
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||||
|
set_rows_stride, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||||
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||||
freq_factors, row_indices, set_rows_stride, stream);
|
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||||
|
set_rows_stride, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||||
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||||
freq_factors, row_indices, set_rows_stride, stream);
|
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||||
|
set_rows_stride, stream);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
} else if (is_mrope && !is_vision) {
|
} else if (is_mrope && !is_vision) {
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
rope_multi_cuda<forward>(
|
rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
corr_dims, freq_factors, sections, is_imrope, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16) {
|
||||||
rope_multi_cuda<forward>(
|
rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
corr_dims, freq_factors, sections, is_imrope, stream);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
} else if (is_vision) {
|
} else if (is_vision) {
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
rope_vision_cuda<forward>(
|
rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
corr_dims, freq_factors, sections, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16) {
|
||||||
rope_vision_cuda<forward>(
|
rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
|
||||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
|
||||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
corr_dims, freq_factors, sections, stream);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||||
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
|
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||||
freq_factors, row_indices, set_rows_stride, stream);
|
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||||
|
set_rows_stride, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||||
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
|
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||||
freq_factors, row_indices, set_rows_stride, stream);
|
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||||
|
set_rows_stride, stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||||
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
|
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
|
||||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||||
freq_factors, row_indices, set_rows_stride, stream);
|
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||||
|
set_rows_stride, stream);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -394,7 +394,7 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
|
||||||
[encoder endEncoding];
|
[encoder endEncoding];
|
||||||
|
|
||||||
ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
|
ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
|
||||||
ggml_metal_event_record(ctx_src, ev_cpy);
|
ggml_metal_event_encode_signal(ev_cpy, cmd_buf);
|
||||||
|
|
||||||
[cmd_buf commit];
|
[cmd_buf commit];
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1392,34 +1392,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
|
||||||
GGML_UNUSED(op);
|
GGML_UNUSED(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
|
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
|
||||||
ggml_metal_library_t lib,
|
|
||||||
ggml_op op,
|
|
||||||
int32_t n_fuse,
|
|
||||||
bool row) {
|
|
||||||
char base[256];
|
char base[256];
|
||||||
char name[256];
|
char name[256];
|
||||||
|
|
||||||
const char * op_str = "undefined";
|
int op_num = -1;
|
||||||
switch (op) {
|
|
||||||
case GGML_OP_ADD: op_str = "add"; break;
|
switch (op->op) {
|
||||||
case GGML_OP_SUB: op_str = "sub"; break;
|
case GGML_OP_ADD: op_num = 0; break;
|
||||||
case GGML_OP_MUL: op_str = "mul"; break;
|
case GGML_OP_SUB: op_num = 1; break;
|
||||||
case GGML_OP_DIV: op_str = "div"; break;
|
case GGML_OP_MUL: op_num = 2; break;
|
||||||
|
case GGML_OP_DIV: op_num = 3; break;
|
||||||
default: GGML_ABORT("fatal error");
|
default: GGML_ABORT("fatal error");
|
||||||
};
|
};
|
||||||
|
|
||||||
if (row) {
|
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||||
snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
|
const char * t1_str = ggml_type_name(op->src[1]->type);
|
||||||
} else {
|
const char * t_str = ggml_type_name(op->type);
|
||||||
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
|
|
||||||
}
|
|
||||||
|
|
||||||
snprintf(name, 256, "%s", base);
|
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
|
||||||
|
|
||||||
|
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
|
||||||
|
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
|
||||||
|
|
||||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
if (!res.pipeline) {
|
if (!res.pipeline) {
|
||||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||||
|
|
||||||
|
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
||||||
|
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
|
||||||
|
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||||
|
|
||||||
|
ggml_metal_cv_free(cv);
|
||||||
|
}
|
||||||
|
|
||||||
|
res.c4 = is_c4;
|
||||||
|
res.cnt = is_rb;
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
int op_num = -1;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case GGML_OP_ADD: op_num = 0; break;
|
||||||
|
case GGML_OP_SUB: op_num = 1; break;
|
||||||
|
case GGML_OP_MUL: op_num = 2; break;
|
||||||
|
case GGML_OP_DIV: op_num = 3; break;
|
||||||
|
default: GGML_ABORT("fatal error");
|
||||||
|
};
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
|
||||||
|
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (!res.pipeline) {
|
||||||
|
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||||
|
|
||||||
|
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
||||||
|
ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
|
||||||
|
ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||||
|
|
||||||
|
ggml_metal_cv_free(cv);
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params {
|
||||||
int nr1;
|
int nr1;
|
||||||
|
|
||||||
size_t smem;
|
size_t smem;
|
||||||
|
|
||||||
|
bool c4;
|
||||||
|
bool cnt;
|
||||||
};
|
};
|
||||||
|
|
||||||
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
|
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
|
||||||
|
|
@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
|
||||||
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||||
|
|
|
||||||
|
|
@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
|
||||||
|
|
||||||
struct ggml_metal_pipeline_with_params res = {
|
struct ggml_metal_pipeline_with_params res = {
|
||||||
/*.pipeline =*/ nil,
|
/*.pipeline =*/ nil,
|
||||||
|
/*.nsg =*/ 0,
|
||||||
/*.nr0 =*/ 0,
|
/*.nr0 =*/ 0,
|
||||||
/*.nr1 =*/ 0,
|
/*.nr1 =*/ 0,
|
||||||
/*.nsg =*/ 0,
|
|
||||||
/*.smem =*/ 0,
|
/*.smem =*/ 0,
|
||||||
|
/*.c4 =*/ false,
|
||||||
|
/*.cnt =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||||
|
|
@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||||
struct ggml_metal_pipeline_with_params res = {
|
struct ggml_metal_pipeline_with_params res = {
|
||||||
/*.pipeline =*/ nil,
|
/*.pipeline =*/ nil,
|
||||||
|
/*.nsg =*/ 0,
|
||||||
/*.nr0 =*/ 0,
|
/*.nr0 =*/ 0,
|
||||||
/*.nr1 =*/ 0,
|
/*.nr1 =*/ 0,
|
||||||
/*.nsg =*/ 0,
|
|
||||||
/*.smem =*/ 0,
|
/*.smem =*/ 0,
|
||||||
|
/*.c4 =*/ false,
|
||||||
|
/*.cnt =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
[lib->lock lock];
|
[lib->lock lock];
|
||||||
|
|
@ -1054,7 +1058,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_ADD_ID:
|
case GGML_OP_ADD_ID:
|
||||||
return op->src[0]->type == GGML_TYPE_F32;
|
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,7 @@
|
||||||
#define FC_SSM_CONV 900
|
#define FC_SSM_CONV 900
|
||||||
#define FC_SOLVE_TRI 1000
|
#define FC_SOLVE_TRI 1000
|
||||||
#define FC_COUNT_EQUAL 1100
|
#define FC_COUNT_EQUAL 1100
|
||||||
|
#define FC_BIN 1200
|
||||||
|
|
||||||
// op-specific constants
|
// op-specific constants
|
||||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||||
|
|
|
||||||
|
|
@ -707,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||||
/*.o1 =*/ { 0 },
|
/*.o1 =*/ { 0 },
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
|
|
@ -2895,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
||||||
|
|
||||||
bool bcast_row = false;
|
|
||||||
|
|
||||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||||
|
|
@ -2990,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||||
|
|
||||||
struct ggml_metal_pipeline_with_params pipeline;
|
struct ggml_metal_pipeline_with_params pipeline;
|
||||||
|
|
||||||
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
|
||||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
|
||||||
|
|
||||||
// src1 is a row
|
|
||||||
GGML_ASSERT(ne11 == 1);
|
|
||||||
|
|
||||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
|
|
||||||
|
|
||||||
bcast_row = true;
|
|
||||||
} else {
|
|
||||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_fuse > 1) {
|
if (n_fuse > 1) {
|
||||||
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
||||||
|
|
@ -3015,20 +3002,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (pipeline.c4) {
|
||||||
|
args.ne00 = ne00/4;
|
||||||
|
args.ne10 = ne10/4;
|
||||||
|
args.ne0 = ne0/4;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
||||||
|
|
||||||
if (bcast_row) {
|
if (pipeline.cnt) {
|
||||||
const int64_t n = ggml_nelements(op)/4;
|
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
||||||
|
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||||
} else {
|
} else {
|
||||||
int nth = 32;
|
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||||
|
|
||||||
while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
int nth = 1;
|
||||||
|
|
||||||
|
while (2*nth < args.ne0 && nth < nth_max) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -895,11 +895,13 @@ enum ggml_sort_order {
|
||||||
GGML_SORT_ORDER_DESC,
|
GGML_SORT_ORDER_DESC,
|
||||||
};
|
};
|
||||||
|
|
||||||
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
||||||
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
||||||
// cons: not very efficient
|
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
||||||
template <int F>
|
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
|
||||||
kernel void kernel_add_fuse_impl(
|
|
||||||
|
template <typename T0, typename T1, typename T>
|
||||||
|
kernel void kernel_bin_fuse_impl(
|
||||||
constant ggml_metal_kargs_bin & args,
|
constant ggml_metal_kargs_bin & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
|
|
@ -907,138 +909,152 @@ kernel void kernel_add_fuse_impl(
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
const int i03 = tgpig.z;
|
#define FC_OP FC_bin_op
|
||||||
const int i02 = tgpig.y;
|
#define FC_F FC_bin_f
|
||||||
const int i01 = tgpig.x;
|
#define FC_RB FC_bin_rb
|
||||||
|
|
||||||
const int i13 = i03%args.ne13;
|
if (FC_RB) {
|
||||||
const int i12 = i02%args.ne12;
|
// row broadcast
|
||||||
const int i11 = i01%args.ne11;
|
const uint i0 = tgpig.x;
|
||||||
|
const uint i1 = i0%args.ne10;
|
||||||
|
|
||||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
device const T0 * src0_row = (device const T0 *) (src0);
|
||||||
device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
device T * dst_row = (device T *) (dst);
|
||||||
|
|
||||||
device const float * src1_ptr[F];
|
if (FC_F == 1) {
|
||||||
for (short j = 0; j < F; ++j) {
|
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
|
||||||
src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
if (FC_OP == 0) {
|
||||||
const int i10 = i0%args.ne10;
|
dst_row[i0] = src0_row[i0] + src1_row[i1];
|
||||||
|
}
|
||||||
|
|
||||||
float res = src0_ptr[i0];
|
if (FC_OP == 1) {
|
||||||
|
dst_row[i0] = src0_row[i0] - src1_row[i1];
|
||||||
|
}
|
||||||
|
|
||||||
#pragma unroll
|
if (FC_OP == 2) {
|
||||||
for (short j = 0; j < F; ++j) {
|
dst_row[i0] = src0_row[i0] * src1_row[i1];
|
||||||
res += src1_ptr[j][i10];
|
}
|
||||||
}
|
|
||||||
|
|
||||||
dst_ptr[i0] = res;
|
if (FC_OP == 3) {
|
||||||
}
|
dst_row[i0] = src0_row[i0] / src1_row[i1];
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
T0 res = src0_row[i0];
|
||||||
|
|
||||||
typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
|
if (FC_OP == 0) {
|
||||||
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
|
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
if (FC_OP == 1) {
|
||||||
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
|
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||||
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
|
}
|
||||||
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
|
}
|
||||||
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
|
|
||||||
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
|
|
||||||
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
|
|
||||||
|
|
||||||
kernel void kernel_sub_fuse_1(
|
if (FC_OP == 2) {
|
||||||
constant ggml_metal_kargs_bin & args,
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
device const char * src0,
|
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||||
device const char * src1,
|
}
|
||||||
device char * dst,
|
}
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
||||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
||||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
||||||
const int i03 = tgpig.z;
|
|
||||||
const int i02 = tgpig.y;
|
|
||||||
const int i01 = tgpig.x;
|
|
||||||
|
|
||||||
const int i13 = i03%args.ne13;
|
if (FC_OP == 3) {
|
||||||
const int i12 = i02%args.ne12;
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
const int i11 = i01%args.ne11;
|
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
dst_row[i0] = res;
|
||||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
||||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
||||||
|
|
||||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
||||||
const int i10 = i0%args.ne10;
|
|
||||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
kernel void kernel_mul_fuse_1(
|
|
||||||
constant ggml_metal_kargs_bin & args,
|
|
||||||
device const char * src0,
|
|
||||||
device const char * src1,
|
|
||||||
device char * dst,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
||||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
||||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
||||||
const int i03 = tgpig.z;
|
|
||||||
const int i02 = tgpig.y;
|
|
||||||
const int i01 = tgpig.x;
|
|
||||||
|
|
||||||
const int i13 = i03%args.ne13;
|
|
||||||
const int i12 = i02%args.ne12;
|
|
||||||
const int i11 = i01%args.ne11;
|
|
||||||
|
|
||||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
|
||||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
||||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
||||||
|
|
||||||
if (args.ne10 == 1) {
|
|
||||||
const float x = *((device float *)(src1_ptr));
|
|
||||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
||||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
const int i03 = tgpig.z;
|
||||||
const int i10 = i0%args.ne10;
|
const int i02 = tgpig.y;
|
||||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
const int i01 = tgpig.x;
|
||||||
|
|
||||||
|
if (i01 >= args.ne01) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int i13 = i03%args.ne13;
|
||||||
|
const int i12 = i02%args.ne12;
|
||||||
|
const int i11 = i01%args.ne11;
|
||||||
|
|
||||||
|
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
||||||
|
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
||||||
|
|
||||||
|
if (FC_F == 1) {
|
||||||
|
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||||
|
|
||||||
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||||
|
const int i10 = i0%args.ne10;
|
||||||
|
|
||||||
|
if (FC_OP == 0) {
|
||||||
|
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FC_OP == 1) {
|
||||||
|
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FC_OP == 2) {
|
||||||
|
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FC_OP == 3) {
|
||||||
|
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device const T1 * src1_ptr[8];
|
||||||
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
|
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||||
|
const int i10 = i0%args.ne10;
|
||||||
|
|
||||||
|
T res = src0_ptr[i0];
|
||||||
|
|
||||||
|
if (FC_OP == 0) {
|
||||||
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
|
res += src1_ptr[j][i10];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FC_OP == 1) {
|
||||||
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
|
res -= src1_ptr[j][i10];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FC_OP == 2) {
|
||||||
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
|
res *= src1_ptr[j][i10];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (FC_OP == 3) {
|
||||||
|
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||||
|
res /= src1_ptr[j][i10];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst_ptr[i0] = res;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#undef FC_OP
|
||||||
|
#undef FC_F
|
||||||
|
#undef FC_RB
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_div_fuse_1(
|
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
|
||||||
constant ggml_metal_kargs_bin & args,
|
|
||||||
device const char * src0,
|
|
||||||
device const char * src1,
|
|
||||||
device char * dst,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
||||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
||||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
||||||
const int i03 = tgpig.z;
|
|
||||||
const int i02 = tgpig.y;
|
|
||||||
const int i01 = tgpig.x;
|
|
||||||
|
|
||||||
const int i13 = i03%args.ne13;
|
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
|
||||||
const int i12 = i02%args.ne12;
|
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
|
||||||
const int i11 = i01%args.ne11;
|
|
||||||
|
|
||||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
|
||||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
||||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
||||||
|
|
||||||
if (args.ne10 == 1) {
|
|
||||||
const float x = 1.0f / *((device float *)(src1_ptr));
|
|
||||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
||||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
||||||
const int i10 = i0%args.ne10;
|
|
||||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
kernel void kernel_add_id(
|
kernel void kernel_add_id(
|
||||||
constant ggml_metal_kargs_add_id & args,
|
constant ggml_metal_kargs_add_id & args,
|
||||||
|
|
@ -1057,7 +1073,7 @@ kernel void kernel_add_id(
|
||||||
const size_t nb1 = args.ne0 * sizeof(float);
|
const size_t nb1 = args.ne0 * sizeof(float);
|
||||||
const size_t nb2 = args.ne1 * nb1;
|
const size_t nb2 = args.ne1 * nb1;
|
||||||
|
|
||||||
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
||||||
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
|
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
|
||||||
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
|
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
|
||||||
|
|
||||||
|
|
@ -1098,141 +1114,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
|
||||||
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
||||||
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
||||||
|
|
||||||
// assumption: src1 is a row
|
|
||||||
// broadcast src1 into src0
|
|
||||||
template <short F>
|
|
||||||
kernel void kernel_add_row_c4_fuse_impl(
|
|
||||||
constant ggml_metal_kargs_bin & args,
|
|
||||||
device const char * src0,
|
|
||||||
device const char * src1,
|
|
||||||
device char * dst,
|
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
|
||||||
const uint nb = args.ne00/4;
|
|
||||||
const uint i = tpig % nb;
|
|
||||||
|
|
||||||
device const float4 * src0_row = (device const float4 *) (src0);
|
|
||||||
device float4 * dst_row = (device float4 *) (dst);
|
|
||||||
|
|
||||||
float4 res = src0_row[tpig];
|
|
||||||
|
|
||||||
#pragma unroll(F)
|
|
||||||
for (short j = 0; j < F; ++j) {
|
|
||||||
res += ((device const float4 *) (src1 + args.o1[j]))[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_row[tpig] = res;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
|
|
||||||
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
|
|
||||||
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
|
|
||||||
|
|
||||||
template <short F>
|
|
||||||
kernel void kernel_sub_row_c4_fuse_impl(
|
|
||||||
constant ggml_metal_kargs_bin & args,
|
|
||||||
device const char * src0,
|
|
||||||
device const char * src1,
|
|
||||||
device char * dst,
|
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
const uint nb = args.ne00/4;
|
|
||||||
const uint i = tpig % nb;
|
|
||||||
|
|
||||||
device const float4 * src0_row = (device const float4 *) (src0);
|
|
||||||
device float4 * dst_row = (device float4 *) (dst);
|
|
||||||
|
|
||||||
device const float4 * src1_row[F];
|
|
||||||
for (short j = 0; j < F; ++j) {
|
|
||||||
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
||||||
}
|
|
||||||
|
|
||||||
float4 res = src0_row[tpig];
|
|
||||||
|
|
||||||
#pragma unroll(F)
|
|
||||||
for (short j = 0; j < F; ++j) {
|
|
||||||
res -= src1_row[j][i];
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_row[tpig] = res;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
|
|
||||||
|
|
||||||
template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
|
|
||||||
|
|
||||||
template <short F>
|
|
||||||
kernel void kernel_mul_row_c4_fuse_impl(
|
|
||||||
constant ggml_metal_kargs_bin & args,
|
|
||||||
device const char * src0,
|
|
||||||
device const char * src1,
|
|
||||||
device char * dst,
|
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
const uint nb = args.ne00/4;
|
|
||||||
const uint i = tpig % nb;
|
|
||||||
|
|
||||||
device const float4 * src0_row = (device const float4 *) (src0);
|
|
||||||
device float4 * dst_row = (device float4 *) (dst);
|
|
||||||
|
|
||||||
device const float4 * src1_row[F];
|
|
||||||
for (short j = 0; j < F; ++j) {
|
|
||||||
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
||||||
}
|
|
||||||
|
|
||||||
float4 res = src0_row[tpig];
|
|
||||||
|
|
||||||
#pragma unroll(F)
|
|
||||||
for (short j = 0; j < F; ++j) {
|
|
||||||
res *= src1_row[j][i];
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_row[tpig] = res;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
|
|
||||||
|
|
||||||
template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
|
|
||||||
|
|
||||||
template <short F>
|
|
||||||
kernel void kernel_div_row_c4_fuse_impl(
|
|
||||||
constant ggml_metal_kargs_bin & args,
|
|
||||||
device const char * src0,
|
|
||||||
device const char * src1,
|
|
||||||
device char * dst,
|
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
const uint nb = args.ne00/4;
|
|
||||||
const uint i = tpig % nb;
|
|
||||||
|
|
||||||
device const float4 * src0_row = (device const float4 *) (src0);
|
|
||||||
device float4 * dst_row = (device float4 *) (dst);
|
|
||||||
|
|
||||||
device const float4 * src1_row[F];
|
|
||||||
for (short j = 0; j < F; ++j) {
|
|
||||||
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
||||||
}
|
|
||||||
|
|
||||||
float4 res = src0_row[tpig];
|
|
||||||
|
|
||||||
#pragma unroll(F)
|
|
||||||
for (short j = 0; j < F; ++j) {
|
|
||||||
res /= src1_row[j][i];
|
|
||||||
}
|
|
||||||
|
|
||||||
dst_row[tpig] = res;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
|
|
||||||
|
|
||||||
template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
|
|
||||||
|
|
||||||
kernel void kernel_scale_f32(
|
kernel void kernel_scale_f32(
|
||||||
constant ggml_metal_kargs_scale & args,
|
constant ggml_metal_kargs_scale & args,
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
|
|
|
||||||
|
|
@ -119,27 +119,48 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
static void usage(const char * executable) {
|
static void usage(const char * executable) {
|
||||||
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
|
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
|
||||||
printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--tensor-type-file] [--prune-layers] [--keep-split] [--override-kv]\n");
|
printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--tensor-type-file]\n");
|
||||||
|
printf(" [--prune-layers] [--keep-split] [--override-kv]\n");
|
||||||
printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
|
printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
|
||||||
printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
|
printf(" --allow-requantize\n");
|
||||||
printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
|
printf(" allow requantizing tensors that have already been quantized\n");
|
||||||
printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
|
printf(" WARNING: this can severely reduce quality compared to quantizing\n");
|
||||||
printf(" --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n");
|
printf(" from 16bit or 32bit!\n");
|
||||||
printf(" --include-weights tensor_name: use importance matrix for this/these tensor(s)\n");
|
printf(" --leave-output-tensor\n");
|
||||||
printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
|
printf(" leave output.weight un(re)quantized\n");
|
||||||
printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
|
printf(" increases model size but may also increase quality, especially when requantizing\n");
|
||||||
printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
|
printf(" --pure\n");
|
||||||
printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
|
printf(" disable k-quant mixtures and quantize all tensors to the same type\n");
|
||||||
printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n");
|
printf(" --imatrix file_name\n");
|
||||||
printf(" --tensor-type-file tensor_type.txt: list of tensors to quantize to specific ggml_type. example: --tensor-type-file tensor_type_list.txt\n");
|
printf(" use data in file_name as importance matrix for quant optimizations\n");
|
||||||
printf(" Advanced option to selectively quantize a long list of tensors. Format to be tensor_name=ggml_type, separated by spaces/newline.\n");
|
printf(" --include-weights tensor_name\n");
|
||||||
printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
|
printf(" use importance matrix for this/these tensor(s)\n");
|
||||||
printf(" Advanced option to remove all tensors from the given layers\n");
|
printf(" --exclude-weights tensor_name\n");
|
||||||
printf(" --keep-split: will generate quantized model in the same shards as input\n");
|
printf(" do not use importance matrix for this/these tensor(s)\n");
|
||||||
|
printf(" --output-tensor-type ggml_type\n");
|
||||||
|
printf(" use this ggml_type for the output.weight tensor\n");
|
||||||
|
printf(" --token-embedding-type ggml_type\n");
|
||||||
|
printf(" use this ggml_type for the token embeddings tensor\n");
|
||||||
|
printf(" --tensor-type tensor_name=ggml_type\n");
|
||||||
|
printf(" quantize this tensor to this ggml_type\n");
|
||||||
|
printf(" this is an advanced option to selectively quantize tensors. may be specified multiple times.\n");
|
||||||
|
printf(" example: --tensor-type attn_q=q8_0\n");
|
||||||
|
printf(" --tensor-type-file tensor_types.txt\n");
|
||||||
|
printf(" list of tensors to quantize to a specific ggml_type\n");
|
||||||
|
printf(" this is an advanced option to selectively quantize a long list of tensors.\n");
|
||||||
|
printf(" the file should use the same format as above, separated by spaces or newlines.\n");
|
||||||
|
printf(" --prune-layers L0,L1,L2...\n");
|
||||||
|
printf(" comma-separated list of layer numbers to prune from the model\n");
|
||||||
|
printf(" WARNING: this is an advanced option, use with care.\n");
|
||||||
|
printf(" --keep-split\n");
|
||||||
|
printf(" generate quantized model in the same shards as input\n");
|
||||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||||
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
|
printf(" override model metadata by key in the quantized model. may be specified multiple times.\n");
|
||||||
printf("Note: --include-weights and --exclude-weights cannot be used together\n");
|
printf(" WARNING: this is an advanced option, use with care.\n\n");
|
||||||
printf("\nAllowed quantization types:\n");
|
printf("note: --include-weights and --exclude-weights cannot be used together\n\n");
|
||||||
|
printf("-----------------------------------------------------------------------------\n");
|
||||||
|
printf(" allowed quantization types\n");
|
||||||
|
printf("-----------------------------------------------------------------------------\n\n");
|
||||||
for (const auto & it : QUANT_OPTIONS) {
|
for (const auto & it : QUANT_OPTIONS) {
|
||||||
if (it.name != "COPY") {
|
if (it.name != "COPY") {
|
||||||
printf(" %2d or ", it.ftype);
|
printf(" %2d or ", it.ftype);
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,7 @@
|
||||||
#if defined(_MSC_VER)
|
|
||||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "ggml-rpc.h"
|
#include "ggml-rpc.h"
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
# define NOMINMAX
|
# define NOMINMAX
|
||||||
# define DIRECTORY_SEPARATOR '\\'
|
# define DIRECTORY_SEPARATOR '\\'
|
||||||
# include <locale>
|
|
||||||
# include <windows.h>
|
# include <windows.h>
|
||||||
# include <fcntl.h>
|
# include <fcntl.h>
|
||||||
# include <io.h>
|
# include <io.h>
|
||||||
|
|
@ -15,23 +10,43 @@
|
||||||
# include <unistd.h>
|
# include <unistd.h>
|
||||||
# include <sys/stat.h>
|
# include <sys/stat.h>
|
||||||
#endif
|
#endif
|
||||||
#include <codecvt>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <filesystem>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
|
||||||
namespace fs = std::filesystem;
|
#if defined(__linux__)
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <pwd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// NOTE: this is copied from common.cpp to avoid linking with libcommon
|
||||||
|
#ifdef _WIN32
|
||||||
|
static std::wstring utf8_to_wstring(const std::string & str) {
|
||||||
|
if (str.empty()) {
|
||||||
|
return std::wstring();
|
||||||
|
}
|
||||||
|
|
||||||
|
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
|
||||||
|
|
||||||
|
if (size <= 0) {
|
||||||
|
return std::wstring();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::wstring wstr(size, 0);
|
||||||
|
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
|
||||||
|
|
||||||
|
return wstr;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// NOTE: this is copied from common.cpp to avoid linking with libcommon
|
// NOTE: this is copied from common.cpp to avoid linking with libcommon
|
||||||
// returns true if successful, false otherwise
|
// returns true if successful, false otherwise
|
||||||
static bool fs_create_directory_with_parents(const std::string & path) {
|
static bool fs_create_directory_with_parents(const std::string & path) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
std::wstring wpath = utf8_to_wstring(path);
|
||||||
std::wstring wpath = converter.from_bytes(path);
|
|
||||||
|
|
||||||
// if the path already exists, check whether it's a directory
|
// if the path already exists, check whether it's a directory
|
||||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||||
|
|
@ -44,9 +59,16 @@ static bool fs_create_directory_with_parents(const std::string & path) {
|
||||||
// process path from front to back, procedurally creating directories
|
// process path from front to back, procedurally creating directories
|
||||||
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
|
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
|
||||||
const std::wstring subpath = wpath.substr(0, pos_slash);
|
const std::wstring subpath = wpath.substr(0, pos_slash);
|
||||||
const wchar_t * test = subpath.c_str();
|
|
||||||
|
|
||||||
const bool success = CreateDirectoryW(test, NULL);
|
pos_slash += 1;
|
||||||
|
|
||||||
|
// skip the drive letter, in some systems it can return an access denied error
|
||||||
|
if (subpath.length() == 2 && subpath[1] == ':') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool success = CreateDirectoryW(subpath.c_str(), NULL);
|
||||||
|
|
||||||
if (!success) {
|
if (!success) {
|
||||||
const DWORD error = GetLastError();
|
const DWORD error = GetLastError();
|
||||||
|
|
||||||
|
|
@ -60,8 +82,6 @@ static bool fs_create_directory_with_parents(const std::string & path) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pos_slash += 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -115,13 +135,27 @@ static std::string fs_get_cache_directory() {
|
||||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||||
if (std::getenv("XDG_CACHE_HOME")) {
|
if (std::getenv("XDG_CACHE_HOME")) {
|
||||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||||
} else {
|
} else if (std::getenv("HOME")) {
|
||||||
cache_directory = std::getenv("HOME") + std::string("/.cache/");
|
cache_directory = std::getenv("HOME") + std::string("/.cache/");
|
||||||
|
} else {
|
||||||
|
#if defined(__linux__)
|
||||||
|
/* no $HOME is defined, fallback to getpwuid */
|
||||||
|
struct passwd *pw = getpwuid(getuid());
|
||||||
|
if ((!pw) || (!pw->pw_dir)) {
|
||||||
|
throw std::runtime_error("Failed to find $HOME directory");
|
||||||
|
}
|
||||||
|
|
||||||
|
cache_directory = std::string(pw->pw_dir) + std::string("/.cache/");
|
||||||
|
#else /* defined(__linux__) */
|
||||||
|
throw std::runtime_error("Failed to find $HOME directory");
|
||||||
|
#endif /* defined(__linux__) */
|
||||||
}
|
}
|
||||||
#elif defined(__APPLE__)
|
#elif defined(__APPLE__)
|
||||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||||
#elif defined(_WIN32)
|
#elif defined(_WIN32)
|
||||||
cache_directory = std::getenv("LOCALAPPDATA");
|
cache_directory = std::getenv("LOCALAPPDATA");
|
||||||
|
#elif defined(__EMSCRIPTEN__)
|
||||||
|
GGML_ABORT("not implemented on this platform");
|
||||||
#else
|
#else
|
||||||
# error Unknown architecture
|
# error Unknown architecture
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -2507,7 +2507,8 @@ private:
|
||||||
slot.n_prompt_tokens_processed++;
|
slot.n_prompt_tokens_processed++;
|
||||||
|
|
||||||
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
||||||
if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
|
const int n_last = std::min(n_batch, 512);
|
||||||
|
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue