Compare commits
19 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 4301e27319 | |
| | a2c199e479 | |
| | 15dd67d869 | |
| | bde461de8c | |
| | 8faa87db02 | |
| | 6f1f6a961a | |
| | 669696e00d | |
| | 982060fadc | |
| | 6853bee680 | |
| | 487674fbb3 | |
| | acec774ef6 | |
| | 5c0d18881e | |
| | 4b2a4778f8 | |
| | 58062860af | |
| | 2973a65ecb | |
| | d0794e89d9 | |
| | 9dcac6cf9f | |
| | 0e49a7b8b4 | |
| | 4164596c76 | |
@@ -86,6 +86,7 @@ body:
       description: >
         If applicable, please copy and paste any relevant log output, including any generated text.
         This will be automatically formatted into code, so no need for backticks.
+        If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
       render: shell
     validations:
       required: false
@@ -31,9 +31,10 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
-  webui-setup:
-    name: WebUI Setup
+  webui-check:
+    name: WebUI Checks
     runs-on: ubuntu-latest
+    continue-on-error: true
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -42,137 +43,66 @@ jobs:
          ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
 
      - name: Setup Node.js
+        id: node
        uses: actions/setup-node@v4
        with:
          node-version: "22"
          cache: "npm"
          cache-dependency-path: "tools/server/webui/package-lock.json"
 
-      - name: Cache node_modules
-        uses: actions/cache@v4
-        id: cache-node-modules
-        with:
-          path: tools/server/webui/node_modules
-          key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
-          restore-keys: |
-            ${{ runner.os }}-node-modules-
-
      - name: Install dependencies
-        if: steps.cache-node-modules.outputs.cache-hit != 'true'
+        id: setup
+        if: ${{ steps.node.conclusion == 'success' }}
        run: npm ci
        working-directory: tools/server/webui
 
-  webui-check:
-    needs: webui-setup
-    name: WebUI Check
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-          ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
-
-      - name: Setup Node.js
-        uses: actions/setup-node@v4
-        with:
-          node-version: "22"
-
-      - name: Restore node_modules cache
-        uses: actions/cache@v4
-        with:
-          path: tools/server/webui/node_modules
-          key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
-          restore-keys: |
-            ${{ runner.os }}-node-modules-
-
      - name: Run type checking
+        if: ${{ always() && steps.setup.conclusion == 'success' }}
        run: npm run check
        working-directory: tools/server/webui
 
      - name: Run linting
+        if: ${{ always() && steps.setup.conclusion == 'success' }}
        run: npm run lint
        working-directory: tools/server/webui
 
-  webui-build:
-    needs: webui-check
-    name: WebUI Build
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-          ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
-
-      - name: Setup Node.js
-        uses: actions/setup-node@v4
-        with:
-          node-version: "22"
-
-      - name: Restore node_modules cache
-        uses: actions/cache@v4
-        with:
-          path: tools/server/webui/node_modules
-          key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
-          restore-keys: |
-            ${{ runner.os }}-node-modules-
-
      - name: Build application
+        if: ${{ always() && steps.setup.conclusion == 'success' }}
        run: npm run build
        working-directory: tools/server/webui
 
-  webui-tests:
-    needs: webui-build
-    name: Run WebUI tests
-    permissions:
-      contents: read
-
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-
-      - name: Setup Node.js
-        uses: actions/setup-node@v4
-        with:
-          node-version: "22"
-
-      - name: Restore node_modules cache
-        uses: actions/cache@v4
-        with:
-          path: tools/server/webui/node_modules
-          key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
-          restore-keys: |
-            ${{ runner.os }}-node-modules-
-
      - name: Install Playwright browsers
+        id: playwright
+        if: ${{ always() && steps.setup.conclusion == 'success' }}
        run: npx playwright install --with-deps
        working-directory: tools/server/webui
 
      - name: Build Storybook
+        if: ${{ always() && steps.playwright.conclusion == 'success' }}
        run: npm run build-storybook
        working-directory: tools/server/webui
 
      - name: Run Client tests
+        if: ${{ always() && steps.playwright.conclusion == 'success' }}
        run: npm run test:client
        working-directory: tools/server/webui
 
-      - name: Run Server tests
-        run: npm run test:server
+      - name: Run Unit tests
+        if: ${{ always() && steps.playwright.conclusion == 'success' }}
+        run: npm run test:unit
        working-directory: tools/server/webui
 
      - name: Run UI tests
+        if: ${{ always() && steps.playwright.conclusion == 'success' }}
        run: npm run test:ui -- --testTimeout=60000
        working-directory: tools/server/webui
 
      - name: Run E2E tests
+        if: ${{ always() && steps.playwright.conclusion == 'success' }}
        run: npm run test:e2e
        working-directory: tools/server/webui
 
  server-build:
-    needs: [webui-tests]
    runs-on: ubuntu-latest
 
    strategy:
@@ -32,7 +32,7 @@
 /examples/export-docs/ @ggerganov
 /examples/gen-docs/ @ggerganov
 /examples/gguf/ @ggerganov
-/examples/llama.android/ @ggerganov
+/examples/llama.android/ @ggerganov @hanyin-arm @naco-siren
 /examples/llama.swiftui/ @ggerganov
 /examples/llama.vim @ggerganov
 /examples/lookahead/ @ggerganov
@@ -190,6 +190,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
 - Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
 - Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
+- Android: [llama.android](/examples/llama.android)
 
 </details>
common/arg.cpp (120 changes)
@@ -420,6 +420,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
        }
    };
 
+    std::set<std::string> seen_args;
+
    for (int i = 1; i < argc; i++) {
        const std::string arg_prefix = "--";
@@ -430,6 +432,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
        if (arg_to_options.find(arg) == arg_to_options.end()) {
            throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
        }
+        if (!seen_args.insert(arg).second) {
+            LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
+        }
        auto & tmp = arg_to_options[arg];
        auto opt = *tmp.first;
        bool is_positive = tmp.second;
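For reference, a minimal standalone sketch of the `seen_args` duplicate detection introduced above. The harness around it is illustrative only (it is not the llama.cpp parser); the one load-bearing detail is that `std::set::insert` returns an `{iterator, bool}` pair whose `bool` is `false` when the key was already present:

```cpp
#include <cstdio>
#include <set>
#include <string>

int main() {
    // stand-in for argv: the same flag passed twice
    const char * argv_like[] = { "--lora", "a.gguf", "--lora", "b.gguf" };
    std::set<std::string> seen_args;
    for (const char * s : argv_like) {
        const std::string arg = s;
        if (arg.rfind("--", 0) != 0) {
            continue; // only flags are tracked, not their values
        }
        if (!seen_args.insert(arg).second) {
            // insert() returned {it, false}: the flag was seen before
            fprintf(stderr, "warning: '%s' specified multiple times\n", arg.c_str());
        }
    }
    return 0;
}
```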
@@ -750,6 +755,8 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
        }
    };
 
+    std::set<std::string> seen_args;
+
    for (int i = 1; i < argc; i++) {
        const std::string arg_prefix = "--";
@@ -760,6 +767,9 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
        if (arg_to_options.find(arg) == arg_to_options.end()) {
            throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
        }
+        if (!seen_args.insert(arg).second) {
+            LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
+        }
        auto opt = *arg_to_options[arg];
        std::string val;
        if (opt.value_hint != nullptr) {
@@ -1140,7 +1150,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params) {
            params.kv_unified = true;
        }
-    ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY}));
    add_opt(common_arg(
        {"--context-shift"},
        {"--no-context-shift"},
@@ -1226,13 +1236,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION}));
    add_opt(common_arg(
        {"--in-file"}, "FNAME",
-        "an input file (repeat to specify multiple files)",
+        "an input file (use comma-separated values to specify multiple files)",
        [](common_params & params, const std::string & value) {
-            std::ifstream file(value);
-            if (!file) {
-                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            for (const auto & item : string_split<std::string>(value, ',')) {
+                std::ifstream file(item);
+                if (!file) {
+                    throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
+                }
+                params.in_files.push_back(item);
            }
-            params.in_files.push_back(value);
        }
    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
    add_opt(common_arg(
|
||||||
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
|
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--image", "--audio"}, "FILE",
|
{"--image", "--audio"}, "FILE",
|
||||||
"path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
|
"path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.image.emplace_back(value);
|
for (const auto & item : string_split<std::string>(value, ',')) {
|
||||||
|
params.image.emplace_back(item);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI}));
|
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI}));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@@ -2218,12 +2232,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        }
    ));
    add_opt(common_arg(
-        {"--override-kv"}, "KEY=TYPE:VALUE",
-        "advanced option to override model metadata by key. may be specified multiple times.\n"
-        "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false",
+        {"--override-kv"}, "KEY=TYPE:VALUE,...",
+        "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
+        "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
        [](common_params & params, const std::string & value) {
-            if (!string_parse_kv_override(value.c_str(), params.kv_overrides)) {
-                throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", value.c_str()));
+            std::vector<std::string> kv_overrides;
+
+            std::string current;
+            bool escaping = false;
+
+            for (const char c : value) {
+                if (escaping) {
+                    current.push_back(c);
+                    escaping = false;
+                } else if (c == '\\') {
+                    escaping = true;
+                } else if (c == ',') {
+                    kv_overrides.push_back(current);
+                    current.clear();
+                } else {
+                    current.push_back(c);
+                }
+            }
+
+            if (escaping) {
+                current.push_back('\\');
+            }
+
+            kv_overrides.push_back(current);
+
+            for (const auto & kv_override : kv_overrides) {
+                if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
+                    throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
+                }
            }
        }
    ));
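A self-contained sketch of the backslash-escaped split above, with a worked input; `split_escaped` is a hypothetical stand-in for the inline loop in the diff:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// split on unescaped commas so a value like "str:foo\,bar"
// keeps its literal comma
static std::vector<std::string> split_escaped(const std::string & value) {
    std::vector<std::string> out;
    std::string current;
    bool escaping = false;
    for (const char c : value) {
        if (escaping) {
            current.push_back(c); // the escaped character is taken literally
            escaping = false;
        } else if (c == '\\') {
            escaping = true;
        } else if (c == ',') {
            out.push_back(current); // an unescaped comma ends one override
            current.clear();
        } else {
            current.push_back(c);
        }
    }
    if (escaping) {
        current.push_back('\\'); // a trailing lone backslash is kept as-is
    }
    out.push_back(current);
    return out;
}

int main() {
    // prints two overrides; the escaped comma in "foo\,bar" survives intact
    const std::string value = "tokenizer.ggml.add_bos_token=bool:false,general.name=str:foo\\,bar";
    for (const auto & item : split_escaped(value)) {
        printf("override: %s\n", item.c_str());
    }
    return 0;
}
```

Splitting only on unescaped commas is what lets an override value itself contain a comma, written as `\,` on the command line.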
@@ -2237,33 +2278,50 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ));
    add_opt(common_arg(
        {"--lora"}, "FNAME",
-        "path to LoRA adapter (can be repeated to use multiple adapters)",
+        "path to LoRA adapter (use comma-separated values to load multiple adapters)",
        [](common_params & params, const std::string & value) {
-            params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
+            for (const auto & item : string_split<std::string>(value, ',')) {
+                params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
+            }
        }
        // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
    ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
    add_opt(common_arg(
-        {"--lora-scaled"}, "FNAME", "SCALE",
-        "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
-        [](common_params & params, const std::string & fname, const std::string & scale) {
-            params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
+        {"--lora-scaled"}, "FNAME:SCALE,...",
+        "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
+        "note: use comma-separated values",
+        [](common_params & params, const std::string & value) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
+                auto parts = string_split<std::string>(item, ':');
+                if (parts.size() != 2) {
+                    throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
+                }
+                params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr });
+            }
        }
        // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
    ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
    add_opt(common_arg(
        {"--control-vector"}, "FNAME",
-        "add a control vector\nnote: this argument can be repeated to add multiple control vectors",
+        "add a control vector\nnote: use comma-separated values to add multiple control vectors",
        [](common_params & params, const std::string & value) {
-            params.control_vectors.push_back({ 1.0f, value, });
+            for (const auto & item : string_split<std::string>(value, ',')) {
+                params.control_vectors.push_back({ 1.0f, item, });
+            }
        }
    ));
    add_opt(common_arg(
-        {"--control-vector-scaled"}, "FNAME", "SCALE",
+        {"--control-vector-scaled"}, "FNAME:SCALE,...",
        "add a control vector with user defined scaling SCALE\n"
-        "note: this argument can be repeated to add multiple scaled control vectors",
-        [](common_params & params, const std::string & fname, const std::string & scale) {
-            params.control_vectors.push_back({ std::stof(scale), fname });
+        "note: use comma-separated values (format: FNAME:SCALE,...)",
+        [](common_params & params, const std::string & value) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
+                auto parts = string_split<std::string>(item, ':');
+                if (parts.size() != 2) {
+                    throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
+                }
+                params.control_vectors.push_back({ std::stof(parts[1]), parts[0] });
+            }
        }
    ));
    add_opt(common_arg(
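A small sketch of the `FNAME:SCALE` parsing now shared by `--lora-scaled` and `--control-vector-scaled` above; the helper names are hypothetical and the colon split is approximated with `std::getline`:

```cpp
#include <cstdio>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

static std::vector<std::string> split_on(const std::string & s, char sep) {
    std::vector<std::string> out;
    std::stringstream ss(s);
    std::string item;
    while (std::getline(ss, item, sep)) {
        out.push_back(item);
    }
    return out;
}

int main() {
    // simulates: --lora-scaled adapter_a.gguf:0.5,adapter_b.gguf:1.5
    for (const auto & item : split_on("adapter_a.gguf:0.5,adapter_b.gguf:1.5", ',')) {
        const auto parts = split_on(item, ':');
        if (parts.size() != 2) {
            throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
        }
        printf("adapter %s with scale %.2f\n", parts[0].c_str(), std::stof(parts[1]));
    }
    return 0;
}
```

One design consequence of splitting each item on every `:`: a path that itself contains a colon (for example a Windows drive prefix) would produce more than two parts and be rejected by the format check.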
@@ -2353,13 +2411,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_env("HF_TOKEN"));
    add_opt(common_arg(
        {"--context-file"}, "FNAME",
-        "file to load context from (repeat to specify multiple files)",
+        "file to load context from (use comma-separated values to specify multiple files)",
        [](common_params & params, const std::string & value) {
-            std::ifstream file(value, std::ios::binary);
-            if (!file) {
-                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            for (const auto & item : string_split<std::string>(value, ',')) {
+                std::ifstream file(item, std::ios::binary);
+                if (!file) {
+                    throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
+                }
+                params.context_files.push_back(item);
            }
-            params.context_files.push_back(value);
        }
    ).set_examples({LLAMA_EXAMPLE_RETRIEVAL}));
    add_opt(common_arg(
@@ -1092,7 +1092,7 @@ common_init_result::common_init_result(common_params & params) :
    auto cparams = common_context_params_to_llama(params);
 
    if (params.fit_params) {
-        LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
+        LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
        llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
            params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
            params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
@@ -104,10 +104,9 @@ struct ring_buffer {
 struct common_sampler {
    common_params_sampling params;
 
+    struct llama_sampler * grmr;
    struct llama_sampler * chain;
 
-    bool grammar;
-
    ring_buffer<llama_token> prev;
 
    std::vector<llama_token_data> cur;
@@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
    lparams.no_perf = params.no_perf;
 
+    llama_sampler * grmr  = nullptr;
    llama_sampler * chain = llama_sampler_chain_init(lparams);
 
-    bool grammar = false;
    std::vector<llama_sampler *> samplers;
 
    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
-        grammar = true;
+        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
 #else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
    if (!params.grammar.empty()) {
        if (params.grammar_lazy) {
-            samplers.push_back(
-                llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                    trigger_patterns_c.data(), trigger_patterns_c.size(),
-                    trigger_tokens.data(), trigger_tokens.size()));
+            grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                trigger_patterns_c.data(), trigger_patterns_c.size(),
+                trigger_tokens.data(), trigger_tokens.size());
        } else {
-            samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+            grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
        }
 
-        grammar = true;
    }
 }
@@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
    auto * result = new common_sampler {
        /* .params = */ params,
+        /* .grmr = */ grmr,
        /* .chain = */ chain,
-        /* .grammar = */ grammar,
        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur = */ {},
        /* .cur_p = */ {},
@@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
 void common_sampler_free(struct common_sampler * gsmpl) {
    if (gsmpl) {
+        llama_sampler_free(gsmpl->grmr);
        llama_sampler_free(gsmpl->chain);
 
        delete gsmpl;
@@ -324,25 +320,12 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
    const auto tm = gsmpl->tm();
 
-    if (gsmpl->grammar) {
-        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
-
-        for (int i = 0; i < n_smpl; i++) {
-            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-
-            // the grammar sampler is always the first one
-            if (i == 0) {
-                if (accept_grammar) {
-                    llama_sampler_accept(smpl, token);
-                }
-            } else {
-                llama_sampler_accept(smpl, token);
-            }
-        }
-    } else {
-        llama_sampler_accept(gsmpl->chain, token);
+    if (gsmpl->grmr && accept_grammar) {
+        llama_sampler_accept(gsmpl->grmr, token);
    }
 
+    llama_sampler_accept(gsmpl->chain, token);
+
    gsmpl->prev.push_back(token);
 }
@@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
        /* .params = */ gsmpl->params,
+        /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain = */ llama_sampler_clone(gsmpl->chain),
-        /* .grammar = */ gsmpl->grammar,
        /* .prev = */ gsmpl->prev,
        /* .cur = */ gsmpl->cur,
        /* .cur_p = */ gsmpl->cur_p,
@@ -410,7 +393,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
    return gsmpl->chain;
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    llama_synchronize(ctx);
 
    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -418,11 +401,42 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
    llama_token id = LLAMA_TOKEN_NULL;
 
+    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
    gsmpl->set_logits(ctx, idx);
 
+    if (grammar_first) {
+        llama_sampler_apply(grmr, &cur_p);
+    }
+
+    llama_sampler_apply(chain, &cur_p);
+
+    id = cur_p.data[cur_p.selected].id;
+
+    if (grammar_first) {
+        return id;
+    }
+
+    // check if the sampled token fits the grammar (grammar-based rejection sampling)
+    {
+        llama_token_data single_token_data = { id, 1.0f, 0.0f };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
+
+        llama_sampler_apply(grmr, &single_token_data_array);
+
+        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+        if (is_valid) {
+            return id;
+        }
+    }
+
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+    gsmpl->set_logits(ctx, idx);
+
+    llama_sampler_apply(grmr, &cur_p);
    llama_sampler_apply(chain, &cur_p);
 
    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
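The rejection-sampling fast path above can be illustrated with a toy, self-contained example in which a validity predicate stands in for the grammar sampler; none of these names are llama.cpp API:

```cpp
#include <cstdio>
#include <random>
#include <vector>

// toy "grammar": only even token ids are valid
static bool grammar_allows(int token) {
    return token % 2 == 0;
}

static int sample(const std::vector<float> & probs, std::mt19937 & rng) {
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return dist(rng);
}

int main() {
    std::mt19937 rng(42);
    std::vector<float> probs = { 0.1f, 0.5f, 0.1f, 0.2f, 0.1f };

    // fast path: sample unconstrained, then validate only the sampled token
    int id = sample(probs, rng);
    if (!grammar_allows(id)) {
        // slow path: mask invalid tokens first (grammar applied before
        // sampling), then resample from the constrained distribution
        std::vector<float> masked = probs;
        for (size_t i = 0; i < masked.size(); i++) {
            if (!grammar_allows((int) i)) {
                masked[i] = 0.0f;
            }
        }
        id = sample(masked, rng);
    }
    printf("sampled token: %d\n", id);
    return 0;
}
```

The point of the fast path is that the grammar only has to score one token per step; the full candidate set is constrained only when that token is rejected, or when the caller asks for `grammar_first`.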
@@ -432,7 +446,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
    return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
    std::vector<llama_token> result;
@@ -440,7 +454,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
    size_t i = 0;
    for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
 
        common_sampler_accept(gsmpl, id, true);
 
@@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
    }
 
    if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
 
        common_sampler_accept(gsmpl, id, true);
 
@@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
    return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    std::vector<int> idxs(draft.size() + 1);
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = i;
    }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -57,7 +57,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 // - check if the token fits the grammar (if any)
 // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
+// if grammar_first is true, the grammar is applied before the samplers (slower)
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+//
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
 // generalized version of common_sampler_sample
 //
@@ -75,10 +78,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
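Because `grammar_first` is declared with a `= false` default in the header above, the signature change is source-compatible; a small sketch of the pattern, with illustrative names only:

```cpp
#include <cstdio>

// adding a defaulted trailing parameter extends an API without touching
// existing call sites; they implicitly keep the old behavior
static int sample(int idx, bool grammar_first = false) {
    // placeholder body standing in for the real sampler
    return grammar_first ? -idx : idx;
}

int main() {
    printf("%d\n", sample(3));       // pre-existing call site: grammar_first == false
    printf("%d\n", sample(3, true)); // new call site opting in, as the speculative decoder does below
    return 0;
}
```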
@@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);
 
-        common_sampler_sample(smpl, ctx_dft, 0);
+        common_sampler_sample(smpl, ctx_dft, 0, true);
 
        const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -1,6 +1,26 @@
 
 # Android
 
+## Build with Android Studio
+
+Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project.
+
+
+This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices.
+It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration.
+
+A minimal Android app frontend is included to showcase the binding’s core functionalities:
+1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` or a local `File`.
+2. **Obtain a `TierDetection` or `InferenceEngine`** instance through the high-level facade APIs.
+3. **Send a raw user prompt** for automatic template formatting, prefill, and decoding. Then collect the generated tokens in a Kotlin `Flow`.
+
+For a production-ready experience that leverages advanced features such as system prompts and benchmarks, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.
+This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups:
+
+|  |  |  |
+|:------------------------------------------------------:|:----------------------------------------------------:|:--------------------------------------------------------:|
+| Home screen | System prompt | "Haiku" |
+
 ## Build on Android using Termux
 
 [Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid.
@@ -1,16 +1,18 @@
 plugins {
-    id("com.android.application")
-    id("org.jetbrains.kotlin.android")
+    alias(libs.plugins.android.application)
+    alias(libs.plugins.jetbrains.kotlin.android)
 }
 
 android {
    namespace = "com.example.llama"
-    compileSdk = 34
+    compileSdk = 36
 
    defaultConfig {
-        applicationId = "com.example.llama"
+        applicationId = "com.example.llama.aichat"
 
        minSdk = 33
-        targetSdk = 34
+        targetSdk = 36
 
        versionCode = 1
        versionName = "1.0"
@@ -21,8 +23,17 @@ android {
    }
 
    buildTypes {
+        debug {
+            isMinifyEnabled = true
+            isShrinkResources = true
+            proguardFiles(
+                getDefaultProguardFile("proguard-android.txt"),
+                "proguard-rules.pro"
+            )
+        }
        release {
-            isMinifyEnabled = false
+            isMinifyEnabled = true
+            isShrinkResources = true
            proguardFiles(
                getDefaultProguardFile("proguard-android-optimize.txt"),
                "proguard-rules.pro"
@@ -36,30 +47,15 @@ android {
    kotlinOptions {
        jvmTarget = "1.8"
    }
-    buildFeatures {
-        compose = true
-    }
-    composeOptions {
-        kotlinCompilerExtensionVersion = "1.5.1"
-    }
 }
 
 dependencies {
+    implementation(libs.bundles.androidx)
+    implementation(libs.material)
 
-    implementation("androidx.core:core-ktx:1.12.0")
-    implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
-    implementation("androidx.activity:activity-compose:1.8.2")
-    implementation(platform("androidx.compose:compose-bom:2023.08.00"))
-    implementation("androidx.compose.ui:ui")
-    implementation("androidx.compose.ui:ui-graphics")
-    implementation("androidx.compose.ui:ui-tooling-preview")
-    implementation("androidx.compose.material3:material3")
-    implementation(project(":llama"))
-    testImplementation("junit:junit:4.13.2")
-    androidTestImplementation("androidx.test.ext:junit:1.1.5")
-    androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
-    androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00"))
-    androidTestImplementation("androidx.compose.ui:ui-test-junit4")
-    debugImplementation("androidx.compose.ui:ui-tooling")
-    debugImplementation("androidx.compose.ui:ui-test-manifest")
+    implementation(project(":lib"))
+
+    testImplementation(libs.junit)
+    androidTestImplementation(libs.androidx.junit)
+    androidTestImplementation(libs.androidx.espresso.core)
 }
@@ -19,3 +19,11 @@
 # If you keep the line number information, uncomment this to
 # hide the original source file name.
 #-renamesourcefileattribute SourceFile
+
+-keep class com.arm.aichat.* { *; }
+-keep class com.arm.aichat.gguf.* { *; }
+
+-assumenosideeffects class android.util.Log {
+    public static int v(...);
+    public static int d(...);
+}
@@ -1,24 +1,21 @@
 <?xml version="1.0" encoding="utf-8"?>
-<manifest xmlns:android="http://schemas.android.com/apk/res/android"
-    xmlns:tools="http://schemas.android.com/tools">
-
-    <uses-permission android:name="android.permission.INTERNET" />
+<manifest xmlns:android="http://schemas.android.com/apk/res/android">
 
    <application
        android:allowBackup="true"
        android:dataExtractionRules="@xml/data_extraction_rules"
+        android:extractNativeLibs="true"
        android:fullBackupContent="@xml/backup_rules"
-        android:icon="@mipmap/ic_launcher"
+        android:icon="@mipmap/ic_launcher_round"
        android:label="@string/app_name"
        android:roundIcon="@mipmap/ic_launcher_round"
        android:supportsRtl="true"
-        android:theme="@style/Theme.LlamaAndroid"
+        android:theme="@style/Theme.AiChatSample"
        >
 
        <activity
            android:name=".MainActivity"
-            android:exported="true"
-            android:theme="@style/Theme.LlamaAndroid">
+            android:exported="true">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
@@ -1,119 +0,0 @@
-package com.example.llama
-
-import android.app.DownloadManager
-import android.net.Uri
-import android.util.Log
-import androidx.compose.material3.Button
-import androidx.compose.material3.Text
-import androidx.compose.runtime.Composable
-import androidx.compose.runtime.getValue
-import androidx.compose.runtime.mutableDoubleStateOf
-import androidx.compose.runtime.mutableStateOf
-import androidx.compose.runtime.remember
-import androidx.compose.runtime.rememberCoroutineScope
-import androidx.compose.runtime.setValue
-import androidx.core.database.getLongOrNull
-import androidx.core.net.toUri
-import kotlinx.coroutines.delay
-import kotlinx.coroutines.launch
-import java.io.File
-
-data class Downloadable(val name: String, val source: Uri, val destination: File) {
-    companion object {
-        @JvmStatic
-        private val tag: String? = this::class.qualifiedName
-
-        sealed interface State
-        data object Ready: State
-        data class Downloading(val id: Long): State
-        data class Downloaded(val downloadable: Downloadable): State
-        data class Error(val message: String): State
-
-        @JvmStatic
-        @Composable
-        fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) {
-            var status: State by remember {
-                mutableStateOf(
-                    if (item.destination.exists()) Downloaded(item)
-                    else Ready
-                )
-            }
-            var progress by remember { mutableDoubleStateOf(0.0) }
-
-            val coroutineScope = rememberCoroutineScope()
-
-            suspend fun waitForDownload(result: Downloading, item: Downloadable): State {
-                while (true) {
-                    val cursor = dm.query(DownloadManager.Query().setFilterById(result.id))
-
-                    if (cursor == null) {
-                        Log.e(tag, "dm.query() returned null")
-                        return Error("dm.query() returned null")
-                    }
-
-                    if (!cursor.moveToFirst() || cursor.count < 1) {
-                        cursor.close()
-                        Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?")
-                        return Ready
-                    }
-
-                    val pix = cursor.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR)
-                    val tix = cursor.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES)
-                    val sofar = cursor.getLongOrNull(pix) ?: 0
-                    val total = cursor.getLongOrNull(tix) ?: 1
-                    cursor.close()
-
-                    if (sofar == total) {
-                        return Downloaded(item)
-                    }
-
-                    progress = (sofar * 1.0) / total
-
-                    delay(1000L)
-                }
-            }
-
-            fun onClick() {
-                when (val s = status) {
-                    is Downloaded -> {
-                        viewModel.load(item.destination.path)
-                    }
-
-                    is Downloading -> {
-                        coroutineScope.launch {
-                            status = waitForDownload(s, item)
-                        }
-                    }
-
-                    else -> {
-                        item.destination.delete()
-
-                        val request = DownloadManager.Request(item.source).apply {
-                            setTitle("Downloading model")
-                            setDescription("Downloading model: ${item.name}")
-                            setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI)
-                            setDestinationUri(item.destination.toUri())
-                        }
-
-                        viewModel.log("Saving ${item.name} to ${item.destination.path}")
-                        Log.i(tag, "Saving ${item.name} to ${item.destination.path}")
-
-                        val id = dm.enqueue(request)
-                        status = Downloading(id)
-                        onClick()
-                    }
-                }
-            }
-
-            Button(onClick = { onClick() }, enabled = status !is Downloading) {
-                when (status) {
-                    is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%")
-                    is Downloaded -> Text("Load ${item.name}")
-                    is Ready -> Text("Download ${item.name}")
-                    is Error -> Text("Download ${item.name}")
-                }
-            }
-        }
-    }
-}
@@ -1,154 +1,257 @@
 package com.example.llama
 
-import android.app.ActivityManager
-import android.app.DownloadManager
-import android.content.ClipData
-import android.content.ClipboardManager
 import android.net.Uri
 import android.os.Bundle
-import android.os.StrictMode
-import android.os.StrictMode.VmPolicy
-import android.text.format.Formatter
-import androidx.activity.ComponentActivity
-import androidx.activity.compose.setContent
-import androidx.activity.viewModels
-import androidx.compose.foundation.layout.Box
-import androidx.compose.foundation.layout.Column
-import androidx.compose.foundation.layout.Row
-import androidx.compose.foundation.layout.fillMaxSize
-import androidx.compose.foundation.layout.padding
-import androidx.compose.foundation.lazy.LazyColumn
-import androidx.compose.foundation.lazy.items
-import androidx.compose.foundation.lazy.rememberLazyListState
-import androidx.compose.material3.Button
-import androidx.compose.material3.LocalContentColor
-import androidx.compose.material3.MaterialTheme
-import androidx.compose.material3.OutlinedTextField
-import androidx.compose.material3.Surface
-import androidx.compose.material3.Text
-import androidx.compose.runtime.Composable
-import androidx.compose.ui.Modifier
-import androidx.compose.ui.unit.dp
-import androidx.core.content.getSystemService
-import com.example.llama.ui.theme.LlamaAndroidTheme
+import android.util.Log
+import android.widget.EditText
+import android.widget.TextView
+import android.widget.Toast
+import androidx.activity.enableEdgeToEdge
+import androidx.activity.result.contract.ActivityResultContracts
+import androidx.appcompat.app.AppCompatActivity
+import androidx.lifecycle.lifecycleScope
+import androidx.recyclerview.widget.LinearLayoutManager
+import androidx.recyclerview.widget.RecyclerView
+import com.arm.aichat.AiChat
+import com.arm.aichat.InferenceEngine
+import com.arm.aichat.gguf.GgufMetadata
+import com.arm.aichat.gguf.GgufMetadataReader
+import com.google.android.material.floatingactionbutton.FloatingActionButton
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.flow.onCompletion
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
 import java.io.File
+import java.io.FileOutputStream
+import java.io.InputStream
+import java.util.UUID
 
-class MainActivity(
-    activityManager: ActivityManager? = null,
-    downloadManager: DownloadManager? = null,
-    clipboardManager: ClipboardManager? = null,
-): ComponentActivity() {
-    private val tag: String? = this::class.simpleName
+class MainActivity : AppCompatActivity() {
 
-    private val activityManager by lazy { activityManager ?: getSystemService<ActivityManager>()!! }
-    private val downloadManager by lazy { downloadManager ?: getSystemService<DownloadManager>()!! }
-    private val clipboardManager by lazy { clipboardManager ?: getSystemService<ClipboardManager>()!! }
-
-    private val viewModel: MainViewModel by viewModels()
-
-    // Get a MemoryInfo object for the device's current memory status.
-    private fun availableMemory(): ActivityManager.MemoryInfo {
-        return ActivityManager.MemoryInfo().also { memoryInfo ->
-            activityManager.getMemoryInfo(memoryInfo)
-        }
-    }
+    // Android views
+    private lateinit var ggufTv: TextView
+    private lateinit var messagesRv: RecyclerView
+    private lateinit var userInputEt: EditText
+    private lateinit var userActionFab: FloatingActionButton
+
+    // Arm AI Chat inference engine
+    private lateinit var engine: InferenceEngine
+
+    // Conversation states
+    private var isModelReady = false
+    private val messages = mutableListOf<Message>()
+    private val lastAssistantMsg = StringBuilder()
+    private val messageAdapter = MessageAdapter(messages)
 
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
+        enableEdgeToEdge()
+        setContentView(R.layout.activity_main)
 
-        StrictMode.setVmPolicy(
-            VmPolicy.Builder(StrictMode.getVmPolicy())
-                .detectLeakedClosableObjects()
-                .build()
-        )
-
-        val free = Formatter.formatFileSize(this, availableMemory().availMem)
-        val total = Formatter.formatFileSize(this, availableMemory().totalMem)
-
-        viewModel.log("Current memory: $free / $total")
-        viewModel.log("Downloads directory: ${getExternalFilesDir(null)}")
-
-        val extFilesDir = getExternalFilesDir(null)
-
-        val models = listOf(
-            Downloadable(
-                "Phi-2 7B (Q4_0, 1.6 GiB)",
-                Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"),
-                File(extFilesDir, "phi-2-q4_0.gguf"),
-            ),
-            Downloadable(
-                "TinyLlama 1.1B (f16, 2.2 GiB)",
-                Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"),
-                File(extFilesDir, "tinyllama-1.1-f16.gguf"),
-            ),
-            Downloadable(
-                "Phi 2 DPO (Q3_K_M, 1.48 GiB)",
-                Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"),
-                File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf")
-            ),
-        )
-
-        setContent {
-            LlamaAndroidTheme {
-                // A surface container using the 'background' color from the theme
-                Surface(
-                    modifier = Modifier.fillMaxSize(),
-                    color = MaterialTheme.colorScheme.background
-                ) {
-                    MainCompose(
-                        viewModel,
-                        clipboardManager,
-                        downloadManager,
-                        models,
-                    )
-                }
+        // Find views
+        ggufTv = findViewById(R.id.gguf)
+        messagesRv = findViewById(R.id.messages)
+        messagesRv.layoutManager = LinearLayoutManager(this)
+        messagesRv.adapter = messageAdapter
+        userInputEt = findViewById(R.id.user_input)
+        userActionFab = findViewById(R.id.fab)
+
+        // Arm AI Chat initialization
+        lifecycleScope.launch(Dispatchers.Default) {
+            engine = AiChat.getInferenceEngine(applicationContext)
+        }
+
+        // Upon CTA button tapped
+        userActionFab.setOnClickListener {
+            if (isModelReady) {
+                // If model is ready, validate input and send to engine
+                handleUserInput()
+            } else {
+                // Otherwise, prompt user to select a GGUF metadata on the device
+                getContent.launch(arrayOf("*/*"))
            }
        }
    }
-}
 
-@Composable
-fun MainCompose(
-    viewModel: MainViewModel,
-    clipboard: ClipboardManager,
-    dm: DownloadManager,
-    models: List<Downloadable>
-) {
-    Column {
-        val scrollState = rememberLazyListState()
+    private val getContent = registerForActivityResult(
+        ActivityResultContracts.OpenDocument()
+    ) { uri ->
+        Log.i(TAG, "Selected file uri:\n $uri")
+        uri?.let { handleSelectedModel(it) }
+    }
 
-        Box(modifier = Modifier.weight(1f)) {
-            LazyColumn(state = scrollState) {
-                items(viewModel.messages) {
-                    Text(
-                        it,
-                        style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current),
-                        modifier = Modifier.padding(16.dp)
-                    )
-                }
-            }
-        }
-        OutlinedTextField(
-            value = viewModel.message,
-            onValueChange = { viewModel.updateMessage(it) },
-            label = { Text("Message") },
-        )
-        Row {
-            Button({ viewModel.send() }) { Text("Send") }
-            Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") }
-            Button({ viewModel.clear() }) { Text("Clear") }
-            Button({
-                viewModel.messages.joinToString("\n").let {
-                    clipboard.setPrimaryClip(ClipData.newPlainText("", it))
-                }
-            }) { Text("Copy") }
-        }
+    /**
+     * Handles the file Uri from [getContent] result
+     */
+    private fun handleSelectedModel(uri: Uri) {
+        // Update UI states
+        userActionFab.isEnabled = false
+        userInputEt.hint = "Parsing GGUF..."
+        ggufTv.text = "Parsing metadata from selected file \n$uri"
 
-        Column {
-            for (model in models) {
-                Downloadable.Button(viewModel, dm, model)
+        lifecycleScope.launch(Dispatchers.IO) {
+            // Parse GGUF metadata
+            Log.i(TAG, "Parsing GGUF metadata...")
+            contentResolver.openInputStream(uri)?.use {
+                GgufMetadataReader.create().readStructuredMetadata(it)
+            }?.let { metadata ->
+                // Update UI to show GGUF metadata to user
+                Log.i(TAG, "GGUF parsed: \n$metadata")
+                withContext(Dispatchers.Main) {
+                    ggufTv.text = metadata.toString()
+                }
+
+                // Ensure the model file is available
+                val modelName = metadata.filename() + FILE_EXTENSION_GGUF
|
contentResolver.openInputStream(uri)?.use { input ->
|
||||||
|
ensureModelFile(modelName, input)
|
||||||
|
}?.let { modelFile ->
|
||||||
|
loadModel(modelName, modelFile)
|
||||||
|
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
isModelReady = true
|
||||||
|
userInputEt.hint = "Type and send a message!"
|
||||||
|
userInputEt.isEnabled = true
|
||||||
|
userActionFab.setImageResource(R.drawable.outline_send_24)
|
||||||
|
userActionFab.isEnabled = true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prepare the model file within app's private storage
|
||||||
|
*/
|
||||||
|
private suspend fun ensureModelFile(modelName: String, input: InputStream) =
|
||||||
|
withContext(Dispatchers.IO) {
|
||||||
|
File(ensureModelsDirectory(), modelName).also { file ->
|
||||||
|
// Copy the file into local storage if not yet done
|
||||||
|
if (!file.exists()) {
|
||||||
|
Log.i(TAG, "Start copying file to $modelName")
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
userInputEt.hint = "Copying file..."
|
||||||
|
}
|
||||||
|
|
||||||
|
FileOutputStream(file).use { input.copyTo(it) }
|
||||||
|
Log.i(TAG, "Finished copying file to $modelName")
|
||||||
|
} else {
|
||||||
|
Log.i(TAG, "File already exists $modelName")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load the model file from the app private storage
|
||||||
|
*/
|
||||||
|
private suspend fun loadModel(modelName: String, modelFile: File) =
|
||||||
|
withContext(Dispatchers.IO) {
|
||||||
|
Log.i(TAG, "Loading model $modelName")
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
userInputEt.hint = "Loading model..."
|
||||||
|
}
|
||||||
|
engine.loadModel(modelFile.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate and send the user message into [InferenceEngine]
|
||||||
|
*/
|
||||||
|
private fun handleUserInput() {
|
||||||
|
userInputEt.text.toString().also { userSsg ->
|
||||||
|
if (userSsg.isEmpty()) {
|
||||||
|
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
|
||||||
|
} else {
|
||||||
|
userInputEt.text = null
|
||||||
|
userActionFab.isEnabled = false
|
||||||
|
|
||||||
|
// Update message states
|
||||||
|
messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
|
||||||
|
lastAssistantMsg.clear()
|
||||||
|
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
|
||||||
|
|
||||||
|
lifecycleScope.launch(Dispatchers.Default) {
|
||||||
|
engine.sendUserPrompt(userSsg)
|
||||||
|
.onCompletion {
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
userActionFab.isEnabled = true
|
||||||
|
}
|
||||||
|
}.collect { token ->
|
||||||
|
val messageCount = messages.size
|
||||||
|
check(messageCount > 0 && !messages[messageCount - 1].isUser)
|
||||||
|
|
||||||
|
messages.removeAt(messageCount - 1).copy(
|
||||||
|
content = lastAssistantMsg.append(token).toString()
|
||||||
|
).let { messages.add(it) }
|
||||||
|
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
messageAdapter.notifyItemChanged(messages.size - 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run a benchmark with the model file
|
||||||
|
*/
|
||||||
|
private suspend fun runBenchmark(modelName: String, modelFile: File) =
|
||||||
|
withContext(Dispatchers.Default) {
|
||||||
|
Log.i(TAG, "Starts benchmarking $modelName")
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
userInputEt.hint = "Running benchmark..."
|
||||||
|
}
|
||||||
|
engine.bench(
|
||||||
|
pp=BENCH_PROMPT_PROCESSING_TOKENS,
|
||||||
|
tg=BENCH_TOKEN_GENERATION_TOKENS,
|
||||||
|
pl=BENCH_SEQUENCE,
|
||||||
|
nr=BENCH_REPETITION
|
||||||
|
).let { result ->
|
||||||
|
messages.add(Message(UUID.randomUUID().toString(), result, false))
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
messageAdapter.notifyItemChanged(messages.size - 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create the `models` directory if not exist.
|
||||||
|
*/
|
||||||
|
private fun ensureModelsDirectory() =
|
||||||
|
File(filesDir, DIRECTORY_MODELS).also {
|
||||||
|
if (it.exists() && !it.isDirectory) { it.delete() }
|
||||||
|
if (!it.exists()) { it.mkdir() }
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private val TAG = MainActivity::class.java.simpleName
|
||||||
|
|
||||||
|
private const val DIRECTORY_MODELS = "models"
|
||||||
|
private const val FILE_EXTENSION_GGUF = ".gguf"
|
||||||
|
|
||||||
|
private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
|
||||||
|
private const val BENCH_TOKEN_GENERATION_TOKENS = 128
|
||||||
|
private const val BENCH_SEQUENCE = 1
|
||||||
|
private const val BENCH_REPETITION = 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun GgufMetadata.filename() = when {
|
||||||
|
basic.name != null -> {
|
||||||
|
basic.name?.let { name ->
|
||||||
|
basic.sizeLabel?.let { size ->
|
||||||
|
"$name-$size"
|
||||||
|
} ?: name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
architecture?.architecture != null -> {
|
||||||
|
architecture?.architecture?.let { arch ->
|
||||||
|
basic.uuid?.let { uuid ->
|
||||||
|
"$arch-$uuid"
|
||||||
|
} ?: "$arch-${System.currentTimeMillis()}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else -> {
|
||||||
|
"model-${System.currentTimeMillis().toHexString()}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
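For orientation, a minimal sketch (not part of the commit) of driving the same APIs outside an Activity, assuming — as the code above suggests — that `AiChat.getInferenceEngine` suspends until the engine is ready and `sendUserPrompt` returns a cold `Flow<String>` of generated tokens; the `chatOnce` helper itself is hypothetical:

import android.content.Context
import com.arm.aichat.AiChat
import java.io.File
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.onCompletion

// Hypothetical helper, illustrating the streaming contract used by MainActivity.
suspend fun chatOnce(context: Context, modelFile: File, prompt: String): String {
    val engine = AiChat.getInferenceEngine(context) // same facade as in onCreate() above
    engine.loadModel(modelFile.path)                // suspends until the GGUF model is loaded
    val reply = StringBuilder()
    engine.sendUserPrompt(prompt)
        .onCompletion { /* generation finished (or failed) */ }
        .collect { token -> reply.append(token) }   // tokens arrive incrementally
    return reply.toString()
}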
@ -1,105 +0,0 @@
package com.example.llama

import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch

class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
    companion object {
        @JvmStatic
        private val NanosPerSecond = 1_000_000_000.0
    }

    private val tag: String? = this::class.simpleName

    var messages by mutableStateOf(listOf("Initializing..."))
        private set

    var message by mutableStateOf("")
        private set

    override fun onCleared() {
        super.onCleared()

        viewModelScope.launch {
            try {
                llamaAndroid.unload()
            } catch (exc: IllegalStateException) {
                messages += exc.message!!
            }
        }
    }

    fun send() {
        val text = message
        message = ""

        // Add to messages console.
        messages += text
        messages += ""

        viewModelScope.launch {
            llamaAndroid.send(text)
                .catch {
                    Log.e(tag, "send() failed", it)
                    messages += it.message!!
                }
                .collect { messages = messages.dropLast(1) + (messages.last() + it) }
        }
    }

    fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
        viewModelScope.launch {
            try {
                val start = System.nanoTime()
                val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
                val end = System.nanoTime()

                messages += warmupResult

                val warmup = (end - start).toDouble() / NanosPerSecond
                messages += "Warm up time: $warmup seconds, please wait..."

                if (warmup > 5.0) {
                    messages += "Warm up took too long, aborting benchmark"
                    return@launch
                }

                messages += llamaAndroid.bench(512, 128, 1, 3)
            } catch (exc: IllegalStateException) {
                Log.e(tag, "bench() failed", exc)
                messages += exc.message!!
            }
        }
    }

    fun load(pathToModel: String) {
        viewModelScope.launch {
            try {
                llamaAndroid.load(pathToModel)
                messages += "Loaded $pathToModel"
            } catch (exc: IllegalStateException) {
                Log.e(tag, "load() failed", exc)
                messages += exc.message!!
            }
        }
    }

    fun updateMessage(newMessage: String) {
        message = newMessage
    }

    fun clear() {
        messages = listOf()
    }

    fun log(message: String) {
        messages += message
    }
}
@ -0,0 +1,51 @@
package com.example.llama

import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import android.widget.TextView
import androidx.recyclerview.widget.RecyclerView

data class Message(
    val id: String,
    val content: String,
    val isUser: Boolean
)

class MessageAdapter(
    private val messages: List<Message>
) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {

    companion object {
        private const val VIEW_TYPE_USER = 1
        private const val VIEW_TYPE_ASSISTANT = 2
    }

    override fun getItemViewType(position: Int): Int {
        return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
    }

    override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
        val layoutInflater = LayoutInflater.from(parent.context)
        return if (viewType == VIEW_TYPE_USER) {
            val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
            UserMessageViewHolder(view)
        } else {
            val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
            AssistantMessageViewHolder(view)
        }
    }

    override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
        val message = messages[position]
        if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
            val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
            textView.text = message.content
        }
    }

    override fun getItemCount(): Int = messages.size

    class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
    class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
}
@ -1,11 +0,0 @@
package com.example.llama.ui.theme

import androidx.compose.ui.graphics.Color

val Purple80 = Color(0xFFD0BCFF)
val PurpleGrey80 = Color(0xFFCCC2DC)
val Pink80 = Color(0xFFEFB8C8)

val Purple40 = Color(0xFF6650a4)
val PurpleGrey40 = Color(0xFF625b71)
val Pink40 = Color(0xFF7D5260)
@ -1,70 +0,0 @@
package com.example.llama.ui.theme

import android.app.Activity
import android.os.Build
import androidx.compose.foundation.isSystemInDarkTheme
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.darkColorScheme
import androidx.compose.material3.dynamicDarkColorScheme
import androidx.compose.material3.dynamicLightColorScheme
import androidx.compose.material3.lightColorScheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.SideEffect
import androidx.compose.ui.graphics.toArgb
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalView
import androidx.core.view.WindowCompat

private val DarkColorScheme = darkColorScheme(
    primary = Purple80,
    secondary = PurpleGrey80,
    tertiary = Pink80
)

private val LightColorScheme = lightColorScheme(
    primary = Purple40,
    secondary = PurpleGrey40,
    tertiary = Pink40

    /* Other default colors to override
    background = Color(0xFFFFFBFE),
    surface = Color(0xFFFFFBFE),
    onPrimary = Color.White,
    onSecondary = Color.White,
    onTertiary = Color.White,
    onBackground = Color(0xFF1C1B1F),
    onSurface = Color(0xFF1C1B1F),
    */
)

@Composable
fun LlamaAndroidTheme(
    darkTheme: Boolean = isSystemInDarkTheme(),
    // Dynamic color is available on Android 12+
    dynamicColor: Boolean = true,
    content: @Composable () -> Unit
) {
    val colorScheme = when {
        dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
            val context = LocalContext.current
            if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
        }

        darkTheme -> DarkColorScheme
        else -> LightColorScheme
    }
    val view = LocalView.current
    if (!view.isInEditMode) {
        SideEffect {
            val window = (view.context as Activity).window
            window.statusBarColor = colorScheme.primary.toArgb()
            WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
        }
    }

    MaterialTheme(
        colorScheme = colorScheme,
        typography = Typography,
        content = content
    )
}
@ -1,34 +0,0 @@
package com.example.llama.ui.theme

import androidx.compose.material3.Typography
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.sp

// Set of Material typography styles to start with
val Typography = Typography(
    bodyLarge = TextStyle(
        fontFamily = FontFamily.Default,
        fontWeight = FontWeight.Normal,
        fontSize = 16.sp,
        lineHeight = 24.sp,
        letterSpacing = 0.5.sp
    )
    /* Other default text styles to override
    titleLarge = TextStyle(
        fontFamily = FontFamily.Default,
        fontWeight = FontWeight.Normal,
        fontSize = 22.sp,
        lineHeight = 28.sp,
        letterSpacing = 0.sp
    ),
    labelSmall = TextStyle(
        fontFamily = FontFamily.Default,
        fontWeight = FontWeight.Medium,
        fontSize = 11.sp,
        lineHeight = 16.sp,
        letterSpacing = 0.5.sp
    )
    */
)
@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
    <solid android:color="#E5E5EA" />
    <corners android:radius="16dp" />
</shape>
@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
    <solid android:color="#4285F4" />
    <corners android:radius="16dp" />
</shape>
@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
    android:width="24dp"
    android:height="24dp"
    android:viewportWidth="24"
    android:viewportHeight="24"
    android:tint="?attr/colorControlNormal">
    <path
        android:fillColor="@android:color/white"
        android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/>
</vector>
@ -0,0 +1,11 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
    android:width="24dp"
    android:height="24dp"
    android:viewportWidth="24"
    android:viewportHeight="24"
    android:tint="?attr/colorControlNormal"
    android:autoMirrored="true">
    <path
        android:fillColor="@android:color/white"
        android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/>
</vector>
@ -0,0 +1,76 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/main"
    android:layout_height="match_parent"
    android:layout_width="match_parent">

    <LinearLayout
        android:fitsSystemWindows="true"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:orientation="vertical"
        tools:context=".MainActivity">

        <FrameLayout
            android:layout_width="match_parent"
            android:layout_height="0dp"
            android:layout_weight="1">

            <ScrollView
                android:layout_width="match_parent"
                android:layout_height="wrap_content"
                android:fadeScrollbars="false">

                <TextView
                    android:id="@+id/gguf"
                    android:layout_width="match_parent"
                    android:layout_height="wrap_content"
                    android:layout_margin="16dp"
                    android:text="Selected GGUF model's metadata will show here."
                    style="@style/TextAppearance.MaterialComponents.Body2"
                    android:maxLines="100" />

            </ScrollView>

        </FrameLayout>

        <androidx.recyclerview.widget.RecyclerView
            android:id="@+id/messages"
            android:layout_width="match_parent"
            android:layout_height="0dp"
            android:layout_weight="4"
            android:padding="16dp"
            android:fadeScrollbars="false"
            app:reverseLayout="true"
            tools:listitem="@layout/item_message_assistant"/>

        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:orientation="horizontal">

            <EditText
                android:id="@+id/user_input"
                android:enabled="false"
                android:layout_width="0dp"
                android:layout_weight="1"
                android:layout_height="match_parent"
                android:padding="8dp"
                style="@style/TextAppearance.MaterialComponents.Body2"
                android:hint="Please first pick a GGUF model file to import." />

            <com.google.android.material.floatingactionbutton.FloatingActionButton
                android:id="@+id/fab"
                android:enabled="true"
                style="@style/Widget.Material3.FloatingActionButton.Primary"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_margin="8dp"
                android:src="@drawable/outline_folder_open_24" />

        </LinearLayout>

    </LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
@ -0,0 +1,15 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:padding="8dp"
    android:gravity="start">

    <TextView
        android:id="@+id/msg_content"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:background="@drawable/bg_assistant_message"
        android:padding="12dp"
        android:textColor="@android:color/black" />
</LinearLayout>
@ -0,0 +1,15 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:padding="8dp"
    android:gravity="end">

    <TextView
        android:id="@+id/msg_content"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:background="@drawable/bg_user_message"
        android:padding="12dp"
        android:textColor="@android:color/white" />
</LinearLayout>
@ -1,3 +1,3 @@
 <resources>
-    <string name="app_name">LlamaAndroid</string>
+    <string name="app_name">AI Chat basic sample</string>
 </resources>
@ -1,5 +1,10 @@
 <?xml version="1.0" encoding="utf-8"?>
 <resources>

-    <style name="Theme.LlamaAndroid" parent="android:Theme.Material.Light.NoActionBar" />
+    <style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar">
+        <!-- Customize your light theme here. -->
+        <!-- <item name="colorPrimary">@color/my_light_primary</item> -->
+    </style>
+
+    <style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" />
 </resources>
@ -1,6 +1,6 @@
 // Top-level build file where you can add configuration options common to all sub-projects/modules.
 plugins {
-    id("com.android.application") version "8.2.0" apply false
-    id("org.jetbrains.kotlin.android") version "1.9.0" apply false
-    id("com.android.library") version "8.2.0" apply false
+    alias(libs.plugins.android.application) apply false
+    alias(libs.plugins.android.library) apply false
+    alias(libs.plugins.jetbrains.kotlin.android) apply false
 }
@ -21,3 +21,4 @@ kotlin.code.style=official
 # resources declared in the library itself and none from the library's dependencies,
 # thereby reducing the size of the R class for that library
 android.nonTransitiveRClass=true
+android.native.buildOutput=verbose
@ -0,0 +1,53 @@
[versions]

# Plugins
agp = "8.13.0"
kotlin = "2.2.20"

# AndroidX
activity = "1.11.0"
appcompat = "1.7.1"
core-ktx = "1.17.0"
constraint-layout = "2.2.1"
datastore-preferences = "1.1.7"

# Material
material = "1.13.0"

# Testing
espresso-core = "3.7.0"
androidx-junit = "1.3.0"
junit = "4.13.2"


[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
android-library = { id = "com.android.library", version.ref = "agp" }
jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }


[libraries]

# AndroidX
androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" }
androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" }
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" }

# Material
material = { group = "com.google.android.material", name = "material", version.ref = "material" }

# Testing
androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" }
androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" }
junit = { group = "junit", name = "junit", version.ref = "junit" }

[bundles]
androidx = [
    "androidx-activity",
    "androidx-appcompat",
    "androidx-constraintlayout",
    "androidx-core-ktx",
    "androidx-datastore-preferences",
]
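As a usage sketch (illustrative only; the consuming module's own build script is not part of this hunk), a module would reference the catalog entries and the `androidx` bundle through the generated type-safe accessors, where dashes in catalog names become dots:

// build.gradle.kts of a hypothetical consuming module
dependencies {
    implementation(libs.bundles.androidx)          // expands to the five AndroidX artifacts above
    implementation(libs.material)

    testImplementation(libs.junit)
    androidTestImplementation(libs.androidx.junit)
    androidTestImplementation(libs.androidx.espresso.core)
}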
@ -1,6 +1,6 @@
-#Thu Dec 21 14:31:09 AEDT 2023
+#Tue Apr 01 11:15:06 PDT 2025
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists
@ -0,0 +1,78 @@
plugins {
    alias(libs.plugins.android.library)
    alias(libs.plugins.jetbrains.kotlin.android)
}

android {
    namespace = "com.arm.aichat"
    compileSdk = 36

    ndkVersion = "29.0.13113456"

    defaultConfig {
        minSdk = 33

        testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
        consumerProguardFiles("consumer-rules.pro")

        ndk {
            abiFilters += listOf("arm64-v8a", "x86_64")
        }
        externalNativeBuild {
            cmake {
                arguments += "-DCMAKE_BUILD_TYPE=Release"
                arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
                arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"

                arguments += "-DBUILD_SHARED_LIBS=ON"
                arguments += "-DLLAMA_BUILD_COMMON=ON"
                arguments += "-DLLAMA_CURL=OFF"

                arguments += "-DGGML_NATIVE=OFF"
                arguments += "-DGGML_BACKEND_DL=ON"
                arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
                arguments += "-DGGML_LLAMAFILE=OFF"
            }
        }
        aarMetadata {
            minCompileSdk = 35
        }
    }
    externalNativeBuild {
        cmake {
            path("src/main/cpp/CMakeLists.txt")
            version = "3.31.6"
        }
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_17
        targetCompatibility = JavaVersion.VERSION_17
    }
    kotlin {
        jvmToolchain(17)

        compileOptions {
            targetCompatibility = JavaVersion.VERSION_17
        }
    }

    packaging {
        resources {
            excludes += "/META-INF/{AL2.0,LGPL2.1}"
        }
    }

    publishing {
        singleVariant("release") {
            withJavadocJar()
        }
    }
}

dependencies {
    implementation(libs.androidx.core.ktx)
    implementation(libs.androidx.datastore.preferences)

    testImplementation(libs.junit)
    androidTestImplementation(libs.androidx.junit)
}
@ -0,0 +1,8 @@
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }

-keepclasseswithmembernames class * {
    native <methods>;
}

-keep class kotlin.Metadata { *; }
@ -0,0 +1,56 @@
cmake_minimum_required(VERSION 3.31.6)

project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)

set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)

# --------------------------------------------------------------------------
# AI Chat library
# --------------------------------------------------------------------------

if(DEFINED ANDROID_ABI)
    message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
    if(ANDROID_ABI STREQUAL "arm64-v8a")
        set(GGML_SYSTEM_ARCH "ARM")
        set(GGML_CPU_KLEIDIAI ON)
        set(GGML_OPENMP ON)
    elseif(ANDROID_ABI STREQUAL "x86_64")
        set(GGML_SYSTEM_ARCH "x86")
        set(GGML_CPU_KLEIDIAI OFF)
        set(GGML_OPENMP OFF)
    else()
        message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
    endif()
endif()

set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
add_subdirectory(${LLAMA_SRC} build-llama)

add_library(${CMAKE_PROJECT_NAME} SHARED
    ai_chat.cpp)

target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
    GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
    GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
    GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
)

target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
    ${LLAMA_SRC}
    ${LLAMA_SRC}/common
    ${LLAMA_SRC}/include
    ${LLAMA_SRC}/ggml/include
    ${LLAMA_SRC}/ggml/src)

target_link_libraries(${CMAKE_PROJECT_NAME}
    llama
    common
    android
    log)
@ -0,0 +1,565 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <cmath>
#include <string>
#include <unistd.h>
#include <sampling.h>

#include "logging.h"
#include "chat.h"
#include "common.h"
#include "llama.h"

template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
    std::ostringstream str;
    for (size_t i = 0; i < values.size(); i++) {
        str << values[i];
        if (i < values.size() - 1) { str << delim; }
    }
    return str.str();
}

/**
 * LLama resources: context, model, batch and sampler
 */
constexpr int N_THREADS_MIN = 2;
constexpr int N_THREADS_MAX = 4;
constexpr int N_THREADS_HEADROOM = 2;

constexpr int DEFAULT_CONTEXT_SIZE = 8192;
constexpr int OVERFLOW_HEADROOM = 4;
constexpr int BATCH_SIZE = 512;
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;

static llama_model * g_model;
static llama_context * g_context;
static llama_batch g_batch;
static common_chat_templates_ptr g_chat_templates;
static common_sampler * g_sampler;

extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
    // Set llama log handler to Android
    llama_log_set(aichat_android_log_callback, nullptr);

    // Loading all CPU backend variants
    const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
    LOGi("Loading backends from %s", path_to_backend);
    ggml_backend_load_all_from_path(path_to_backend);
    env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);

    // Initialize backends
    llama_backend_init();
    LOGi("Backend initiated; Log handler set.");
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
    llama_model_params model_params = llama_model_default_params();

    const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
    LOGd("%s: Loading model from: \n%s\n", __func__, model_path);

    auto *model = llama_model_load_from_file(model_path, model_params);
    env->ReleaseStringUTFChars(jmodel_path, model_path);
    if (!model) {
        return 1;
    }
    g_model = model;
    return 0;
}

static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
    if (!model) {
        LOGe("%s: model cannot be null", __func__);
        return nullptr;
    }

    // Multi-threading setup
    const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
                                                           (int) sysconf(_SC_NPROCESSORS_ONLN) -
                                                           N_THREADS_HEADROOM));
    LOGi("%s: Using %d threads", __func__, n_threads);

    // Context parameters setup
    llama_context_params ctx_params = llama_context_default_params();
    const int trained_context_size = llama_model_n_ctx_train(model);
    if (n_ctx > trained_context_size) {
        LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
             __func__, trained_context_size, n_ctx);
    }
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = BATCH_SIZE;
    ctx_params.n_ubatch = BATCH_SIZE;
    ctx_params.n_threads = n_threads;
    ctx_params.n_threads_batch = n_threads;
    auto *context = llama_init_from_model(g_model, ctx_params);
    if (context == nullptr) {
        LOGe("%s: llama_new_context_with_model() returned null)", __func__);
    }
    return context;
}

static common_sampler *new_sampler(float temp) {
    common_params_sampling sparams;
    sparams.temp = temp;
    return common_sampler_init(g_model, sparams);
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
    auto *context = init_context(g_model);
    if (!context) { return 1; }
    g_context = context;
    g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
    g_chat_templates = common_chat_templates_init(g_model, "");
    g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
    return 0;
}

static std::string get_backend() {
    std::vector<std::string> backends;
    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
        auto *reg = ggml_backend_reg_get(i);
        std::string name = ggml_backend_reg_name(reg);
        if (name != "CPU") {
            backends.push_back(ggml_backend_reg_name(reg));
        }
    }
    return backends.empty() ? "CPU" : join(backends, ",");
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
    return env->NewStringUTF(llama_print_system_info());
}

extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
                                                            jint pl, jint nr) {
    auto *context = init_context(g_model, pp);
    if (!context) {
        const auto *const err_msg = "Fail to init_context! Bench aborted.";
        LOGe(err_msg);
        return env->NewStringUTF(err_msg);
    }

    auto pp_avg = 0.0;
    auto tg_avg = 0.0;
    auto pp_std = 0.0;
    auto tg_std = 0.0;

    const uint32_t n_ctx = llama_n_ctx(context);
    LOGi("n_ctx = %d", n_ctx);

    int i, j;
    int nri;
    for (nri = 0; nri < nr; nri++) {
        LOGi("Benchmark prompt processing (pp = %d)", pp);

        common_batch_clear(g_batch);

        const int n_tokens = pp;
        for (i = 0; i < n_tokens; i++) {
            common_batch_add(g_batch, 0, i, {0}, false);
        }

        g_batch.logits[g_batch.n_tokens - 1] = true;
        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp_start = ggml_time_us();
        if (llama_decode(context, g_batch) != 0) {
            LOGe("llama_decode() failed during prompt processing");
        }
        const auto t_pp_end = ggml_time_us();

        // bench text generation

        LOGi("Benchmark text generation (tg = %d)", tg);

        llama_memory_clear(llama_get_memory(context), false);
        const auto t_tg_start = ggml_time_us();
        for (i = 0; i < tg; i++) {
            common_batch_clear(g_batch);
            for (j = 0; j < pl; j++) {
                common_batch_add(g_batch, 0, i, {j}, true);
            }

            if (llama_decode(context, g_batch) != 0) {
                LOGe("llama_decode() failed during text generation");
            }
        }
        const auto t_tg_end = ggml_time_us();

        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;

        const auto speed_pp = double(pp) / t_pp;
        const auto speed_tg = double(pl * tg) / t_tg;

        pp_avg += speed_pp;
        tg_avg += speed_tg;

        pp_std += speed_pp * speed_pp;
        tg_std += speed_tg * speed_tg;

        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
    }

    llama_free(context);

    pp_avg /= double(nr);
    tg_avg /= double(nr);

    if (nr > 1) {
        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
    } else {
        pp_std = 0;
        tg_std = 0;
    }
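    // The two sqrt() lines above are the sample standard deviation computed from
    // running sums: with S2 = sum of squared per-run speeds accumulated in the loop
    // and mean = (sum of speeds) / nr, the closed form is
    //     s = sqrt( (S2 - nr * mean^2) / (nr - 1) )
    // which expands to exactly S2/(nr-1) - mean^2 * nr/(nr-1) under the square root.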
    char model_desc[128];
    llama_model_desc(g_model, model_desc, sizeof(model_desc));

    const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
    const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;

    const auto backend = get_backend();
    std::stringstream result;
    result << std::setprecision(3);
    result << "| model | size | params | backend | test | t/s |\n";
    result << "| --- | --- | --- | --- | --- | --- |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
           << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
           << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
    return env->NewStringUTF(result.str().c_str());
}


/**
 * Completion loop's long-term states:
 * - chat management
 * - position tracking
 */
constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";

static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;

static void reset_long_term_states(const bool clear_kv_cache = true) {
    chat_msgs.clear();
    system_prompt_position = 0;
    current_position = 0;

    if (clear_kv_cache)
        llama_memory_clear(llama_get_memory(g_context), false);
}

/**
 * TODO-hyin: implement sliding-window version as a better alternative
 *
 * Context shifting by discarding the older half of the tokens appended after the system prompt:
 * - keep the first [system_prompt_position] tokens of the original prompt
 * - discard the older half of the (current_position - system_prompt_position) tokens
 *   that follow, then slide the remaining half down
 * - recompute the logits in batches
 */
static void shift_context() {
    const int n_discard = (current_position - system_prompt_position) / 2;
    LOGi("%s: Discarding %d tokens", __func__, n_discard);
    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
    current_position -= n_discard;
    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
}
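// Worked example (values assumed for illustration): with system_prompt_position = 1000
// and current_position = 8000, n_discard = (8000 - 1000) / 2 = 3500. KV-cache entries
// for positions [1000, 4500) are removed, entries for [4500, 8000) are shifted down by
// 3500 to [1000, 4500), and current_position becomes 4500.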
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
    common_chat_msg new_msg;
    new_msg.role = role;
    new_msg.content = content;
    auto formatted = common_chat_format_single(
            g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
    chat_msgs.push_back(new_msg);
    LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
    return formatted;
}

/**
 * Completion loop's short-term states:
 * - stop generation position
 * - token chars caching
 * - current assistant message being generated
 */
static llama_pos stop_generation_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss;

static void reset_short_term_states() {
    stop_generation_position = 0;
    cached_token_chars.clear();
    assistant_ss.str("");
}

static int decode_tokens_in_batches(
        llama_context *context,
        llama_batch &batch,
        const llama_tokens &tokens,
        const llama_pos start_pos,
        const bool compute_last_logit = false) {
    // Process tokens in batches using the global batch
    LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
    for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
        const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
        common_batch_clear(batch);
        LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);

        // Shift context if current batch cannot fit into the context
        if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
            LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
            shift_context();
        }

        // Add tokens to the batch with proper positions
        for (int j = 0; j < cur_batch_size; j++) {
            const llama_token token_id = tokens[i + j];
            const llama_pos position = start_pos + i + j;
            const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
            common_batch_add(batch, token_id, position, {0}, want_logit);
        }

        // Decode this batch
        const int decode_result = llama_decode(context, batch);
        if (decode_result) {
            LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
            return 1;
        }
    }
    return 0;
}
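// Worked example (numbers assumed for illustration): a 1200-token prompt with
// BATCH_SIZE = 512 is decoded as three llama_decode() calls of 512, 512 and 176
// tokens. Only the very last token of the last batch requests logits when
// compute_last_logit is true, since that is where sampling resumes.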
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
        JNIEnv *env,
        jobject /*unused*/,
        jstring jsystem_prompt
) {
    // Reset long-term & short-term states
    reset_long_term_states();
    reset_short_term_states();

    // Obtain system prompt from JEnv
    const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
    LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
    std::string formatted_system_prompt(system_prompt);
    env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);

    // Format system prompt if applicable
    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
    if (has_chat_template) {
        // use the retained std::string copy; the JNI chars were released above
        formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, formatted_system_prompt);
    }

    // Tokenize system prompt
    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
                                               has_chat_template, has_chat_template);
    for (auto id: system_tokens) {
        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
    }

    // Handle context overflow
    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if ((int) system_tokens.size() > max_batch_size) {
        LOGe("%s: System prompt too long for context! %d tokens, max: %d",
             __func__, (int) system_tokens.size(), max_batch_size);
        return 1;
    }

    // Decode system tokens in batches
    if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
        LOGe("%s: llama_decode() failed!", __func__);
        return 2;
    }

    // Update position
    system_prompt_position = current_position = (int) system_tokens.size();
    return 0;
}

extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
        JNIEnv *env,
        jobject /*unused*/,
        jstring juser_prompt,
        jint n_predict
) {
    // Reset short-term states
    reset_short_term_states();

    // Obtain and tokenize user prompt
    const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
    LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
    std::string formatted_user_prompt(user_prompt);
    env->ReleaseStringUTFChars(juser_prompt, user_prompt);

    // Format user prompt if applicable
    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
    if (has_chat_template) {
        // use the retained std::string copy; the JNI chars were released above
        formatted_user_prompt = chat_add_and_format(ROLE_USER, formatted_user_prompt);
    }

    // Decode formatted user prompts
    auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
    for (auto id: user_tokens) {
        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
    }

    // Ensure user prompt doesn't exceed the context size by truncating if necessary.
    const int user_prompt_size = (int) user_tokens.size();
    const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
    if (user_prompt_size > max_batch_size) {
        const int skipped_tokens = user_prompt_size - max_batch_size;
        user_tokens.resize(max_batch_size);
        LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
    }

    // Decode user tokens in batches
    if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
        LOGe("%s: llama_decode() failed!", __func__);
        return 2;
    }

    // Update position
    current_position += user_prompt_size;
    stop_generation_position = current_position + user_prompt_size + n_predict;
    return 0;
}

static bool is_valid_utf8(const char *string) {
    if (!string) { return true; }

    const auto *bytes = (const unsigned char *) string;
    int num;

    while (*bytes != 0x00) {
        if ((*bytes & 0x80) == 0x00) {
            // U+0000 to U+007F
            num = 1;
        } else if ((*bytes & 0xE0) == 0xC0) {
            // U+0080 to U+07FF
            num = 2;
        } else if ((*bytes & 0xF0) == 0xE0) {
            // U+0800 to U+FFFF
            num = 3;
        } else if ((*bytes & 0xF8) == 0xF0) {
            // U+10000 to U+10FFFF
            num = 4;
        } else {
            return false;
        }

        bytes += 1;
        for (int i = 1; i < num; ++i) {
            if ((*bytes & 0xC0) != 0x80) {
                return false;
            }
            bytes += 1;
        }
    }
    return true;
}
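// Why the UTF-8 check matters at the JNI boundary: a single multi-byte code point
// (e.g. a 4-byte emoji) can be split across two generated tokens. generateNextToken()
// below therefore buffers incomplete byte sequences in cached_token_chars and returns
// an empty string until the buffer becomes valid UTF-8 again.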
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT jstring JNICALL
|
||||||
|
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
|
||||||
|
JNIEnv *env,
|
||||||
|
jobject /*unused*/
|
||||||
|
) {
|
||||||
|
// Infinite text generation via context shifting
|
||||||
|
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
|
||||||
|
LOGw("%s: Context full! Shifting...", __func__);
|
||||||
|
shift_context();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop if reaching the marked position
|
||||||
|
if (current_position >= stop_generation_position) {
|
||||||
|
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sample next token
|
||||||
|
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
|
||||||
|
common_sampler_accept(g_sampler, new_token_id, true);
|
||||||
|
|
||||||
|
// Populate the batch with new token, then decode
|
||||||
|
common_batch_clear(g_batch);
|
||||||
|
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
|
||||||
|
if (llama_decode(g_context, g_batch) != 0) {
|
||||||
|
LOGe("%s: llama_decode() failed for generated token", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update position
|
||||||
|
current_position++;
|
||||||
|
|
||||||
|
// Stop if next token is EOG
|
||||||
|
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
||||||
|
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||||
|
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not EOG, convert to text
|
||||||
|
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
|
||||||
|
cached_token_chars += new_token_chars;
|
||||||
|
|
||||||
|
// Create and return a valid UTF-8 Java string
|
||||||
|
jstring result = nullptr;
|
||||||
|
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||||
|
result = env->NewStringUTF(cached_token_chars.c_str());
|
||||||
|
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
||||||
|
|
||||||
|
assistant_ss << cached_token_chars;
|
||||||
|
cached_token_chars.clear();
|
||||||
|
} else {
|
||||||
|
LOGv("id: %d,\tappend to cache", new_token_id);
|
||||||
|
result = env->NewStringUTF("");
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT void JNICALL
|
||||||
|
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||||
|
// Reset long-term & short-term states
|
||||||
|
reset_long_term_states();
|
||||||
|
reset_short_term_states();
|
||||||
|
|
||||||
|
// Free up resources
|
||||||
|
common_sampler_free(g_sampler);
|
||||||
|
g_chat_templates.reset();
|
||||||
|
llama_batch_free(g_batch);
|
||||||
|
llama_free(g_context);
|
||||||
|
llama_model_free(g_model);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT void JNICALL
|
||||||
|
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) {
|
||||||
|
llama_backend_free();
|
||||||
|
}
|
||||||
|
|
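
The cached_token_chars buffering in generateNextToken exists because a detokenized piece can end partway through a multi-byte UTF-8 character, and NewStringUTF on such a fragment can produce garbage. A hypothetical trace of the euro sign (UTF-8 bytes E2 82 AC) split across two tokens:

    // token A -> piece ends with E2 82 : invalid UTF-8, kept in cached_token_chars, "" returned to Kotlin
    // token B -> piece contributes AC  : cache becomes E2 82 AC ("€"), emitted as one valid string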
@@ -0,0 +1,61 @@
//
// Created by Han Yin on 10/31/25.
//

#ifndef AICHAT_LOGGING_H
#define AICHAT_LOGGING_H

#pragma once

#include <android/log.h>
#include "ggml.h" // for ggml_log_level (assumed reachable via the target's include path)

#ifndef LOG_TAG
#define LOG_TAG "ai-chat"
#endif

#ifndef LOG_MIN_LEVEL
#if defined(NDEBUG)
#define LOG_MIN_LEVEL ANDROID_LOG_INFO
#else
#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
#endif
#endif

static inline int ai_should_log(int prio) {
    return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
}

#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGv(...) ((void)0)
#endif

#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGd(...) ((void)0)
#endif

#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)

static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
    switch (level) {
        case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
        case GGML_LOG_LEVEL_WARN:  return ANDROID_LOG_WARN;
        case GGML_LOG_LEVEL_INFO:  return ANDROID_LOG_INFO;
        case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
        default:                   return ANDROID_LOG_DEFAULT;
    }
}

static inline void aichat_android_log_callback(enum ggml_log_level level,
                                               const char* text,
                                               void* /*user*/) {
    const int prio = android_log_prio_from_ggml(level);
    if (!ai_should_log(prio)) return;
    __android_log_write(prio, LOG_TAG, text);
}

#endif //AICHAT_LOGGING_H
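
A sketch of how this callback would presumably be hooked up during native initialization (llama.cpp exposes llama_log_set for exactly this; the actual wiring lives in ai_chat.cpp):

    // Route all llama.cpp/ggml log output through logcat:
    llama_log_set(aichat_android_log_callback, /*user_data=*/nullptr);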
@@ -0,0 +1,14 @@
package com.arm.aichat

import android.content.Context
import com.arm.aichat.internal.InferenceEngineImpl

/**
 * Main entry point for Arm's AI Chat library.
 */
object AiChat {
    /**
     * Get the single inference engine instance.
     */
    fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
}
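
For reference, obtaining the engine from application code is a one-liner (the call site is hypothetical):

    val engine = AiChat.getInferenceEngine(applicationContext)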
@@ -0,0 +1,89 @@
package com.arm.aichat

import com.arm.aichat.InferenceEngine.State
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow

/**
 * Interface defining the core LLM inference operations.
 */
interface InferenceEngine {
    /**
     * Current state of the inference engine
     */
    val state: StateFlow<State>

    /**
     * Load a model from the given path.
     *
     * @throws UnsupportedArchitectureException if the model architecture is not supported
     */
    suspend fun loadModel(pathToModel: String)

    /**
     * Sends a system prompt to the loaded model.
     */
    suspend fun setSystemPrompt(systemPrompt: String)

    /**
     * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
     */
    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>

    /**
     * Runs a benchmark with the specified parameters.
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String

    /**
     * Unloads the currently loaded model.
     */
    suspend fun cleanUp()

    /**
     * Cleans up resources when the engine is no longer needed.
     */
    fun destroy()

    /**
     * States of the inference engine
     */
    sealed class State {
        object Uninitialized : State()
        object Initializing : State()
        object Initialized : State()

        object LoadingModel : State()
        object UnloadingModel : State()
        object ModelReady : State()

        object Benchmarking : State()
        object ProcessingSystemPrompt : State()
        object ProcessingUserPrompt : State()

        object Generating : State()

        data class Error(val exception: Exception) : State()
    }

    companion object {
        const val DEFAULT_PREDICT_LENGTH = 1024
    }
}

val State.isUninterruptible
    get() = this is State.Initializing ||
            this is State.LoadingModel ||
            this is State.UnloadingModel ||
            this is State.Benchmarking ||
            this is State.ProcessingSystemPrompt ||
            this is State.ProcessingUserPrompt

val State.isModelLoaded: Boolean
    get() = this is State.ModelReady ||
            this is State.Benchmarking ||
            this is State.ProcessingSystemPrompt ||
            this is State.ProcessingUserPrompt ||
            this is State.Generating

class UnsupportedArchitectureException : Exception()
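
Since callers are expected to gate UI actions on these state flags, here is a minimal observer-side sketch (observeEngine and the println are illustrative, not part of the library; assumes kotlinx-coroutines on the classpath):

    import kotlinx.coroutines.flow.collectLatest

    suspend fun observeEngine(engine: InferenceEngine) {
        engine.state.collectLatest { state ->
            val canCancel = !state.isUninterruptible                  // safe to interrupt?
            val canPrompt = state is InferenceEngine.State.ModelReady // ready for sendUserPrompt()
            println("loaded=${state.isModelLoaded} canCancel=$canCancel canPrompt=$canPrompt")
        }
    }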
@@ -0,0 +1,61 @@
package com.arm.aichat.gguf

/**
 * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
 * The `label` matches what llama-cli prints.
 */
enum class FileType(val code: Int, val label: String) {
    ALL_F32(0, "all F32"),
    MOSTLY_F16(1, "F16"),
    MOSTLY_Q4_0(2, "Q4_0"),
    MOSTLY_Q4_1(3, "Q4_1"),
    // 4, 5 and 6 removed
    MOSTLY_Q8_0(7, "Q8_0"),
    MOSTLY_Q5_0(8, "Q5_0"),
    MOSTLY_Q5_1(9, "Q5_1"),

    /* K-quants ------------------------------------------------------------ */
    MOSTLY_Q2_K   (10, "Q2_K - Medium"),
    MOSTLY_Q3_K_S (11, "Q3_K - Small"),
    MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
    MOSTLY_Q3_K_L (13, "Q3_K - Large"),
    MOSTLY_Q4_K_S (14, "Q4_K - Small"),
    MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
    MOSTLY_Q5_K_S (16, "Q5_K - Small"),
    MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
    MOSTLY_Q6_K   (18, "Q6_K"),

    /* IQ quants ----------------------------------------------------------- */
    MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
    MOSTLY_IQ2_XS  (20, "IQ2_XS - 2.31 bpw"),
    MOSTLY_Q2_K_S  (21, "Q2_K - Small"),
    MOSTLY_IQ3_XS  (22, "IQ3_XS - 3.30 bpw"),
    MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
    MOSTLY_IQ1_S   (24, "IQ1_S - 1.56 bpw"),
    MOSTLY_IQ4_NL  (25, "IQ4_NL - 4.5 bpw"),
    MOSTLY_IQ3_S   (26, "IQ3_S - 3.44 bpw"),
    MOSTLY_IQ3_M   (27, "IQ3_M - 3.66 bpw"),
    MOSTLY_IQ2_S   (28, "IQ2_S - 2.50 bpw"),
    MOSTLY_IQ2_M   (29, "IQ2_M - 2.70 bpw"),
    MOSTLY_IQ4_XS  (30, "IQ4_XS - 4.25 bpw"),
    MOSTLY_IQ1_M   (31, "IQ1_M - 1.75 bpw"),

    /* BF16 & Ternary ------------------------------------------------------ */
    MOSTLY_BF16  (32, "BF16"),
    MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
    MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),

    /* Special flag -------------------------------------------------------- */
    GUESSED(1024, "(guessed)"),

    UNKNOWN(-1, "unknown");

    companion object {
        private val map = entries.associateBy(FileType::code)

        fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
    }
}
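
As a quick illustration of the lookup, decoding the `general.file_type` code 15 that a typical Q4_K_M quantization carries (hypothetical snippet):

    val ft = FileType.fromCode(15)
    println(ft.label)                      // "Q4_K - Medium"
    println(FileType.fromCode(999).label)  // "unknown"; unmapped codes fall back to UNKNOWN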
@@ -0,0 +1,132 @@
package com.arm.aichat.gguf

import java.io.IOException

/**
 * Structured metadata of a GGUF file
 */
data class GgufMetadata(
    // Basic file info
    val version: GgufVersion,
    val tensorCount: Long,
    val kvCount: Long,

    // General info
    val basic: BasicInfo,
    val author: AuthorInfo? = null,
    val additional: AdditionalInfo? = null,
    val architecture: ArchitectureInfo? = null,
    val baseModels: List<BaseModelInfo>? = null,
    val tokenizer: TokenizerInfo? = null,

    // Derivative info
    val dimensions: DimensionsInfo? = null,
    val attention: AttentionInfo? = null,
    val rope: RopeInfo? = null,
    val experts: ExpertsInfo? = null
) {
    enum class GgufVersion(val code: Int, val label: String) {
        /** First public draft; little-endian only, no alignment key. */
        LEGACY_V1(1, "Legacy v1"),

        /** Added split-file support and some extra metadata keys. */
        EXTENDED_V2(2, "Extended v2"),

        /** Current spec: endian-aware, mandatory alignment, fully validated. */
        VALIDATED_V3(3, "Validated v3");

        companion object {
            fun fromCode(code: Int): GgufVersion =
                entries.firstOrNull { it.code == code }
                    ?: throw IOException("Unknown GGUF version code $code")
        }

        override fun toString(): String = "$label (code=$code)"
    }

    data class BasicInfo(
        val uuid: String? = null,
        val name: String? = null,
        val nameLabel: String? = null,
        val sizeLabel: String? = null, // Size label like "7B"
    )

    data class AuthorInfo(
        val organization: String? = null,
        val author: String? = null,
        val doi: String? = null,
        val url: String? = null,
        val repoUrl: String? = null,
        val license: String? = null,
        val licenseLink: String? = null,
    )

    data class AdditionalInfo(
        val type: String? = null,
        val description: String? = null,
        val tags: List<String>? = null,
        val languages: List<String>? = null,
    )

    data class ArchitectureInfo(
        val architecture: String? = null,
        val fileType: Int? = null,
        val vocabSize: Int? = null,
        val finetune: String? = null,
        val quantizationVersion: Int? = null,
    )

    data class BaseModelInfo(
        val name: String? = null,
        val author: String? = null,
        val version: String? = null,
        val organization: String? = null,
        val url: String? = null,
        val doi: String? = null,
        val uuid: String? = null,
        val repoUrl: String? = null,
    )

    data class TokenizerInfo(
        val model: String? = null,
        val bosTokenId: Int? = null,
        val eosTokenId: Int? = null,
        val unknownTokenId: Int? = null,
        val paddingTokenId: Int? = null,
        val addBosToken: Boolean? = null,
        val addEosToken: Boolean? = null,
        val chatTemplate: String? = null,
    )

    data class DimensionsInfo(
        val contextLength: Int? = null,
        val embeddingSize: Int? = null,
        val blockCount: Int? = null,
        val feedForwardSize: Int? = null,
    )

    data class AttentionInfo(
        val headCount: Int? = null,
        val headCountKv: Int? = null,
        val keyLength: Int? = null,
        val valueLength: Int? = null,
        val layerNormEpsilon: Float? = null,
        val layerNormRmsEpsilon: Float? = null,
    )

    data class RopeInfo(
        val frequencyBase: Float? = null,
        val dimensionCount: Int? = null,
        val scalingType: String? = null,
        val scalingFactor: Float? = null,
        val attnFactor: Float? = null,
        val originalContextLength: Int? = null,
        val finetuned: Boolean? = null,
    )

    data class ExpertsInfo(
        val count: Int? = null,
        val usedCount: Int? = null,
    )
}
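
As a small illustration of how these fields might be consumed together with FileType above (a sketch; summary() is not part of the library):

    fun GgufMetadata.summary(): String {
        val name  = basic.nameLabel ?: basic.name ?: "unnamed model"
        val quant = FileType.fromCode(architecture?.fileType).label
        val ctx   = dimensions?.contextLength ?: "?"
        return "$name [$quant], ctx=$ctx, $version, $tensorCount tensors"
    }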
@@ -0,0 +1,77 @@
package com.arm.aichat.gguf

import android.content.Context
import android.net.Uri
import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
import java.io.File
import java.io.IOException
import java.io.InputStream

/**
 * Interface for reading GGUF metadata from model files.
 * Use `GgufMetadataReader.create()` to get an instance.
 */
interface GgufMetadataReader {
    /**
     * Reads the magic number from the specified file.
     *
     * @param file Java File pointing at the GGUF file (absolute path)
     * @return true if the file is valid GGUF, otherwise false
     * @throws InvalidFileFormatException if the file format is invalid
     */
    suspend fun ensureSourceFileFormat(file: File): Boolean

    /**
     * Reads the magic number from the content behind the given Uri.
     *
     * @param context Context for obtaining [android.content.ContentProvider]
     * @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
     * @return true if the file is valid GGUF, otherwise false
     * @throws InvalidFileFormatException if the file format is invalid
     */
    suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean

    /**
     * Reads and parses GGUF metadata from the given input stream.
     *
     * @param input the [InputStream] obtained from a readable file or content
     * @return Structured metadata extracted from the file
     * @throws IOException if the file is damaged or cannot be read
     * @throws InvalidFileFormatException if the file format is invalid
     */
    suspend fun readStructuredMetadata(input: InputStream): GgufMetadata

    companion object {
        private val DEFAULT_SKIP_KEYS = setOf(
            "tokenizer.chat_template",
            "tokenizer.ggml.scores",
            "tokenizer.ggml.tokens",
            "tokenizer.ggml.token_type"
        )

        /**
         * Creates a default GgufMetadataReader instance
         */
        fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
            skipKeys = DEFAULT_SKIP_KEYS,
            arraySummariseThreshold = 1_000
        )

        /**
         * Creates a GgufMetadataReader with custom configuration
         *
         * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
         * @param arraySummariseThreshold If ≥0, arrays longer than this get summarised, not materialised;
         *                                if -1, never summarise.
         */
        fun create(
            skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
            arraySummariseThreshold: Int = 1_000
        ): GgufMetadataReader = GgufMetadataReaderImpl(
            skipKeys = skipKeys,
            arraySummariseThreshold = arraySummariseThreshold
        )
    }
}

class InvalidFileFormatException : IOException()
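
A minimal end-to-end sketch of this API (inspect() is illustrative; assumes java.io.File is imported and a coroutine context is available):

    suspend fun inspect(file: File) {
        val reader = GgufMetadataReader.create()
        if (!reader.ensureSourceFileFormat(file)) return  // not a GGUF file
        val meta = file.inputStream().buffered().use { reader.readStructuredMetadata(it) }
        println("${meta.basic.nameLabel}: ${meta.tensorCount} tensors, ${meta.kvCount} kv pairs, ${meta.version}")
    }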
@@ -0,0 +1,309 @@
package com.arm.aichat.internal

import android.content.Context
import android.util.Log
import com.arm.aichat.InferenceEngine
import com.arm.aichat.UnsupportedArchitectureException
import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
import dalvik.annotation.optimization.FastNative
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.IOException

/**
 * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
 *
 * This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
 * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
 * with the underlying C++ native code.
 *
 * The typical usage flow is:
 * 1. Get instance via [getInstance]
 * 2. Load a model with [loadModel]
 * 3. Send prompts with [sendUserPrompt]
 * 4. Generate responses as token streams
 * 5. Perform [cleanUp] when done with a model
 * 6. Properly [destroy] when completely done
 *
 * State transitions are managed automatically and validated at each operation.
 *
 * @see ai_chat.cpp for the native implementation details
 */
internal class InferenceEngineImpl private constructor(
    private val nativeLibDir: String
) : InferenceEngine {

    companion object {
        private val TAG = InferenceEngineImpl::class.java.simpleName

        @Volatile
        private var instance: InferenceEngine? = null

        /**
         * Create or obtain [InferenceEngineImpl]'s single instance.
         *
         * @param context Context used to obtain the native library directory
         * @throws IllegalArgumentException if the native library path is invalid
         * @throws UnsatisfiedLinkError if the library failed to load
         */
        internal fun getInstance(context: Context) =
            instance ?: synchronized(this) {
                val nativeLibDir = context.applicationInfo.nativeLibraryDir
                require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }

                try {
                    Log.i(TAG, "Instantiating InferenceEngineImpl...")
                    InferenceEngineImpl(nativeLibDir).also { instance = it }
                } catch (e: UnsatisfiedLinkError) {
                    Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
                    throw e
                }
            }
    }

    /**
     * JNI methods
     * @see ai_chat.cpp
     */
    @FastNative
    private external fun init(nativeLibDir: String)

    @FastNative
    private external fun load(modelPath: String): Int

    @FastNative
    private external fun prepare(): Int

    @FastNative
    private external fun systemInfo(): String

    @FastNative
    private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String

    @FastNative
    private external fun processSystemPrompt(systemPrompt: String): Int

    @FastNative
    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int

    @FastNative
    private external fun generateNextToken(): String?

    @FastNative
    private external fun unload()

    @FastNative
    private external fun shutdown()

    private val _state =
        MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
    override val state: StateFlow<InferenceEngine.State> = _state

    private var _readyForSystemPrompt = false

    /**
     * Single-threaded coroutine dispatcher & scope for llama.cpp asynchronous operations
     */
    @OptIn(ExperimentalCoroutinesApi::class)
    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())

    init {
        llamaScope.launch {
            try {
                check(_state.value is InferenceEngine.State.Uninitialized) {
                    "Cannot load native library in ${_state.value.javaClass.simpleName}!"
                }
                _state.value = InferenceEngine.State.Initializing
                Log.i(TAG, "Loading native library...")
                System.loadLibrary("ai-chat")
                init(nativeLibDir)
                _state.value = InferenceEngine.State.Initialized
                Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")

            } catch (e: Exception) {
                Log.e(TAG, "Failed to load native library", e)
                throw e
            }
        }
    }

    /**
     * Load the LLM
     */
    override suspend fun loadModel(pathToModel: String) =
        withContext(llamaDispatcher) {
            check(_state.value is InferenceEngine.State.Initialized) {
                "Cannot load model in ${_state.value.javaClass.simpleName}!"
            }

            try {
                Log.i(TAG, "Checking access to model file... \n$pathToModel")
                File(pathToModel).let {
                    require(it.exists()) { "File not found" }
                    require(it.isFile) { "Not a valid file" }
                    require(it.canRead()) { "Cannot read file" }
                }

                Log.i(TAG, "Loading model... \n$pathToModel")
                _readyForSystemPrompt = false
                _state.value = InferenceEngine.State.LoadingModel
                load(pathToModel).let {
                    // TODO-han.yin: find a better way to pass other error codes
                    if (it != 0) throw UnsupportedArchitectureException()
                }
                prepare().let {
                    if (it != 0) throw IOException("Failed to prepare resources")
                }
                Log.i(TAG, "Model loaded!")
                _readyForSystemPrompt = true
                _state.value = InferenceEngine.State.ModelReady
            } catch (e: Exception) {
                Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
                _state.value = InferenceEngine.State.Error(e)
                throw e
            }
        }

    /**
     * Process the plain text system prompt
     *
     * TODO-han.yin: return an error code if the system prompt was not correctly processed?
     */
    override suspend fun setSystemPrompt(prompt: String) =
        withContext(llamaDispatcher) {
            require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
            check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
            check(_state.value is InferenceEngine.State.ModelReady) {
                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
            }

            Log.i(TAG, "Sending system prompt...")
            _readyForSystemPrompt = false
            _state.value = InferenceEngine.State.ProcessingSystemPrompt
            processSystemPrompt(prompt).let { result ->
                if (result != 0) {
                    RuntimeException("Failed to process system prompt: $result").also {
                        _state.value = InferenceEngine.State.Error(it)
                        throw it
                    }
                }
            }
            Log.i(TAG, "System prompt processed! Awaiting user prompt...")
            _state.value = InferenceEngine.State.ModelReady
        }

    /**
     * Send plain text user prompt to the LLM, which starts generating tokens in a [Flow]
     */
    override fun sendUserPrompt(
        message: String,
        predictLength: Int,
    ): Flow<String> = flow {
        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
        check(_state.value is InferenceEngine.State.ModelReady) {
            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
        }

        try {
            Log.i(TAG, "Sending user prompt...")
            _readyForSystemPrompt = false
            _state.value = InferenceEngine.State.ProcessingUserPrompt

            processUserPrompt(message, predictLength).let { result ->
                if (result != 0) {
                    Log.e(TAG, "Failed to process user prompt: $result")
                    return@flow
                }
            }

            Log.i(TAG, "User prompt processed. Generating assistant response...")
            _state.value = InferenceEngine.State.Generating
            while (true) {
                generateNextToken()?.let { utf8token ->
                    if (utf8token.isNotEmpty()) emit(utf8token)
                } ?: break
            }
            Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
            _state.value = InferenceEngine.State.ModelReady
        } catch (e: CancellationException) {
            Log.i(TAG, "Generation cancelled by user.")
            _state.value = InferenceEngine.State.ModelReady
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Error during generation!", e)
            _state.value = InferenceEngine.State.Error(e)
            throw e
        }
    }.flowOn(llamaDispatcher)

    /**
     * Benchmark the model
     */
    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
        withContext(llamaDispatcher) {
            check(_state.value is InferenceEngine.State.ModelReady) {
                "Benchmark request discarded due to: ${_state.value.javaClass.simpleName}"
            }
            Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
            _readyForSystemPrompt = false // Just to be safe
            _state.value = InferenceEngine.State.Benchmarking
            benchModel(pp, tg, pl, nr).also {
                _state.value = InferenceEngine.State.ModelReady
            }
        }

    /**
     * Unloads the model and frees resources, or resets error states
     */
    override suspend fun cleanUp() =
        withContext(llamaDispatcher) {
            when (val state = _state.value) {
                is InferenceEngine.State.ModelReady -> {
                    Log.i(TAG, "Unloading model and freeing resources...")
                    _readyForSystemPrompt = false
                    _state.value = InferenceEngine.State.UnloadingModel

                    unload()

                    _state.value = InferenceEngine.State.Initialized
                    Log.i(TAG, "Model unloaded!")
                    Unit
                }

                is InferenceEngine.State.Error -> {
                    Log.i(TAG, "Resetting error states...")
                    _state.value = InferenceEngine.State.Initialized
                    Log.i(TAG, "States reset!")
                    Unit
                }

                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
            }
        }

    /**
     * Cancel all ongoing coroutines and free GGML backends
     */
    override fun destroy() {
        _readyForSystemPrompt = false
        llamaScope.cancel()
        when (_state.value) {
            is InferenceEngine.State.Uninitialized -> {}
            is InferenceEngine.State.Initialized -> shutdown()
            else -> { unload(); shutdown() }
        }
    }
}
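
The numbered flow from the class KDoc, written out as a hypothetical caller (model path and prompt strings are placeholders; assumes the Flow collect extension from kotlinx.coroutines.flow):

    suspend fun demo(context: Context, modelPath: String) {
        val engine = AiChat.getInferenceEngine(context)         // 1. obtain the singleton
        engine.loadModel(modelPath)                             // 2. load weights (throws on bad arch)
        engine.setSystemPrompt("You are a helpful assistant.")  //    optional, right after loading
        engine.sendUserPrompt("Hello!").collect { print(it) }   // 3.-4. stream generated tokens
        engine.cleanUp()                                        // 5. unload the model
        engine.destroy()                                        // 6. tear down backends
    }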
@ -0,0 +1,590 @@
|
||||||
|
package com.arm.aichat.internal.gguf
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.net.Uri
|
||||||
|
import com.arm.aichat.gguf.GgufMetadata
|
||||||
|
import com.arm.aichat.gguf.GgufMetadataReader
|
||||||
|
import com.arm.aichat.gguf.InvalidFileFormatException
|
||||||
|
import java.io.File
|
||||||
|
import java.io.IOException
|
||||||
|
import java.io.InputStream
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility class to read GGUF model files and extract metadata key-value pairs.
|
||||||
|
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
|
||||||
|
*/
|
||||||
|
internal class GgufMetadataReaderImpl(
|
||||||
|
private val skipKeys: Set<String>,
|
||||||
|
private val arraySummariseThreshold: Int,
|
||||||
|
) : GgufMetadataReader {
|
||||||
|
companion object {
|
||||||
|
private const val ARCH_LLAMA = "llama"
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
|
||||||
|
enum class MetadataType(val code: Int) {
|
||||||
|
UINT8(0), INT8(1), UINT16(2), INT16(3),
|
||||||
|
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
|
||||||
|
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
|
||||||
|
companion object {
|
||||||
|
private val codeMap = entries.associateBy(MetadataType::code)
|
||||||
|
fun fromCode(code: Int): MetadataType = codeMap[code]
|
||||||
|
?: throw IOException("Unknown metadata value type code: $code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
|
||||||
|
sealed class MetadataValue {
|
||||||
|
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
|
||||||
|
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
|
||||||
|
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
|
||||||
|
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
|
||||||
|
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
|
||||||
|
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
|
||||||
|
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
|
||||||
|
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
|
||||||
|
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
|
||||||
|
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
|
||||||
|
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
|
||||||
|
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
|
||||||
|
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
|
||||||
|
private fun MetadataValue.toPrimitive(): Any = when (this) {
|
||||||
|
is MetadataValue.UInt8 -> value
|
||||||
|
is MetadataValue.Int8 -> value
|
||||||
|
is MetadataValue.UInt16 -> value
|
||||||
|
is MetadataValue.Int16 -> value
|
||||||
|
is MetadataValue.UInt32 -> value
|
||||||
|
is MetadataValue.Int32 -> value
|
||||||
|
is MetadataValue.Float32 -> value
|
||||||
|
is MetadataValue.Bool -> value
|
||||||
|
is MetadataValue.StringVal -> value
|
||||||
|
is MetadataValue.UInt64 -> value
|
||||||
|
is MetadataValue.Int64 -> value
|
||||||
|
is MetadataValue.Float64 -> value
|
||||||
|
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reads the magic number from the specified file path.
|
||||||
|
*
|
||||||
|
* @param context Context for obtaining ContentResolver
|
||||||
|
* @param uri Uri to the GGUF file provided by ContentProvider
|
||||||
|
* @return true if file is valid GGUF, otherwise false
|
||||||
|
*/
|
||||||
|
override suspend fun ensureSourceFileFormat(file: File): Boolean =
|
||||||
|
file.inputStream().buffered().use { ensureMagic(it) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reads the magic number from the specified file path.
|
||||||
|
*
|
||||||
|
* @param context Context for obtaining ContentResolver
|
||||||
|
* @param uri Uri to the GGUF file provided by ContentProvider
|
||||||
|
* @return true if file is valid GGUF, otherwise false
|
||||||
|
*/
|
||||||
|
override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
|
||||||
|
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
|
||||||
|
|
||||||
|
/** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */
|
||||||
|
private fun ensureMagic(input: InputStream): Boolean =
|
||||||
|
ByteArray(4).let {
|
||||||
|
if (input.read(it) != 4) throw IOException("Not a valid file!")
|
||||||
|
it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
||||||
|
* populated [GgufMetadata] tree.
|
||||||
|
*
|
||||||
|
* Steps performed internally:
|
||||||
|
* 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version).
|
||||||
|
* 2. Streams through the key‑value section, skipping large blobs if the key
|
||||||
|
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
|
||||||
|
* 3. Converts the resulting raw map into strongly‑typed sub‑structures
|
||||||
|
* (basic info, tokenizer, rope, etc.).
|
||||||
|
*
|
||||||
|
* The method is STREAMING‑ONLY: tensors are never mapped or loaded into
|
||||||
|
* memory, so even multi‑GB model files can be processed in < 50 ms.
|
||||||
|
*
|
||||||
|
* @param path Absolute or relative filesystem path to a `.gguf` file.
|
||||||
|
* @return A [GgufMetadata] instance containing all recognised metadata plus
|
||||||
|
* an `allMetadata` map with any keys that were not given a dedicated
|
||||||
|
* field.
|
||||||
|
* @throws IOException if the file is not GGUF, the version is unsupported,
|
||||||
|
* or the metadata block is truncated / corrupt.
|
||||||
|
*/
|
||||||
|
override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
|
||||||
|
// ── 1. header ──────────────────────────────────────────────────────────
|
||||||
|
// throws on mismatch
|
||||||
|
val version = ensureMagicAndVersion(input)
|
||||||
|
val tensorCount = readLittleLong(input)
|
||||||
|
val kvCount = readLittleLong(input)
|
||||||
|
|
||||||
|
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
|
||||||
|
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
|
||||||
|
|
||||||
|
// ── 3. build structured object ────────────────────────────────────────
|
||||||
|
return buildStructured(meta, version, tensorCount, kvCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
|
||||||
|
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
|
||||||
|
if (!ensureMagic(input)) throw InvalidFileFormatException()
|
||||||
|
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read an unsigned 32‑bit little‑endian integer.
|
||||||
|
*
|
||||||
|
* @throws IOException if fewer than four bytes are available.
|
||||||
|
*/
|
||||||
|
private fun readLEUInt32(input: InputStream): Int {
|
||||||
|
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
|
||||||
|
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
|
||||||
|
return (b3 and 0xFF shl 24) or
|
||||||
|
(b2 and 0xFF shl 16) or
|
||||||
|
(b1 and 0xFF shl 8) or
|
||||||
|
(b0 and 0xFF)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Low‑level helper that reads the entire “key-value” section from the current
|
||||||
|
* stream position.
|
||||||
|
*
|
||||||
|
* @param input Open stream positioned JUST AFTER the header.
|
||||||
|
* @param kvCnt Number of key‑value pairs (taken from the header).
|
||||||
|
* @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
|
||||||
|
*
|
||||||
|
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
|
||||||
|
* [skipValue] or [parseValue] accordingly.
|
||||||
|
*/
|
||||||
|
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
|
||||||
|
mutableMapOf<String, MetadataValue>().apply {
|
||||||
|
repeat(kvCnt.toInt()) {
|
||||||
|
val key = readString(input)
|
||||||
|
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||||
|
if (key in skipKeys) {
|
||||||
|
skipValue(input, valueT)
|
||||||
|
} else {
|
||||||
|
this[key] = parseValue(input, valueT)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed
|
||||||
|
* [GgufMetadata] tree used by the rest of the app.
|
||||||
|
*
|
||||||
|
* Only the keys listed in the spec are copied into dedicated data classes;
|
||||||
|
* everything else is preserved in `GgufMetadata.allMetadata`.
|
||||||
|
*
|
||||||
|
* @param m Raw key/value map.
|
||||||
|
* @param version GGUF file‑format version (enum).
|
||||||
|
* @param tensorCnt Number of tensors (from the header).
|
||||||
|
* @param kvCnt Total metadata pair count (from the header).
|
||||||
|
*/
|
||||||
|
private fun buildStructured(
|
||||||
|
m: Map<String, MetadataValue>,
|
||||||
|
version: GgufMetadata.GgufVersion,
|
||||||
|
tensorCnt: Long,
|
||||||
|
kvCnt: Long
|
||||||
|
): GgufMetadata {
|
||||||
|
// ---------- helpers ----------
|
||||||
|
fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
|
||||||
|
fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
|
||||||
|
fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
|
||||||
|
fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
|
||||||
|
fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
|
||||||
|
fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
|
||||||
|
fun String.strList(): List<String>? =
|
||||||
|
(m[this] as? MetadataValue.ArrayVal)
|
||||||
|
?.elements
|
||||||
|
?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
|
||||||
|
|
||||||
|
val arch = "general.architecture".str() ?: ARCH_LLAMA
|
||||||
|
|
||||||
|
// -------------- populate sections ----------------
|
||||||
|
val basic = GgufMetadata.BasicInfo(
|
||||||
|
uuid = "general.uuid".str(),
|
||||||
|
name = "general.basename".str(),
|
||||||
|
nameLabel = "general.name".str(),
|
||||||
|
sizeLabel = "general.size_label".str()
|
||||||
|
)
|
||||||
|
|
||||||
|
val author = GgufMetadata.AuthorInfo(
|
||||||
|
organization = "general.organization".str(),
|
||||||
|
author = "general.author".str(),
|
||||||
|
doi = "general.doi".str(),
|
||||||
|
url = "general.url".str(),
|
||||||
|
repoUrl = "general.repo_url".str(),
|
||||||
|
license = "general.license".str(),
|
||||||
|
licenseLink = "general.license.link".str()
|
||||||
|
).takeUnless {
|
||||||
|
organization == null && author == null && doi == null &&
|
||||||
|
url == null && repoUrl == null && license == null && licenseLink == null
|
||||||
|
}
|
||||||
|
|
||||||
|
val additional = GgufMetadata.AdditionalInfo(
|
||||||
|
type = "general.type".str(),
|
||||||
|
description = "general.description".str(),
|
||||||
|
tags = "general.tags".strList(),
|
||||||
|
languages = "general.languages".strList()
|
||||||
|
).takeUnless {
|
||||||
|
type == null && description == null && tags == null && languages == null
|
||||||
|
}
|
||||||
|
|
||||||
|
val architectureInfo = GgufMetadata.ArchitectureInfo(
|
||||||
|
architecture = arch,
|
||||||
|
fileType = "general.file_type".u32(),
|
||||||
|
vocabSize = "$arch.vocab_size".u32(),
|
||||||
|
finetune = "general.finetune".str(),
|
||||||
|
quantizationVersion = "general.quantization_version".u32()
|
||||||
|
).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
|
||||||
|
|
||||||
|
val baseModels = buildList {
|
||||||
|
val n = "general.base_model.count".u32() ?: 0
|
||||||
|
for (i in 0 until n) {
|
||||||
|
fun k(s: String) = "general.base_model.$i.$s"
|
||||||
|
add(
|
||||||
|
GgufMetadata.BaseModelInfo(
|
||||||
|
name = k("name").str(),
|
||||||
|
author = k("author").str(),
|
||||||
|
version = k("version").str(),
|
||||||
|
organization = k("organization").str(),
|
||||||
|
url = k("url").str(),
|
||||||
|
doi = k("doi").str(),
|
||||||
|
uuid = k("uuid").str(),
|
||||||
|
repoUrl = k("repo_url").str(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}.takeIf { it.isNotEmpty() }
|
||||||
|
|
||||||
|
val tokenizer = GgufMetadata.TokenizerInfo(
|
||||||
|
model = "tokenizer.ggml.model".str(),
|
||||||
|
bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
|
||||||
|
eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
|
||||||
|
unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
|
||||||
|
paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
|
||||||
|
addBosToken = "tokenizer.ggml.add_bos_token".bool(),
|
||||||
|
addEosToken = "tokenizer.ggml.add_eos_token".bool(),
|
||||||
|
chatTemplate = "tokenizer.chat_template".str()
|
||||||
|
).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
|
||||||
|
unknownTokenId == null && paddingTokenId == null &&
|
||||||
|
addBosToken == null && addEosToken == null && chatTemplate == null
|
||||||
|
}
|
||||||
|
|
||||||
|
val dimensions = GgufMetadata.DimensionsInfo(
|
||||||
|
contextLength = "$arch.context_length".u32(),
|
||||||
|
embeddingSize = "$arch.embedding_length".u32(),
|
||||||
|
blockCount = "$arch.block_count".u32(),
|
||||||
|
feedForwardSize = "$arch.feed_forward_length".u32()
|
||||||
|
).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
|
||||||
|
|
||||||
|
val attention = GgufMetadata.AttentionInfo(
|
||||||
|
headCount = "$arch.attention.head_count".u32(),
|
||||||
|
headCountKv = "$arch.attention.head_count_kv".u32(),
|
||||||
|
keyLength = "$arch.attention.key_length".u32(),
|
||||||
|
valueLength = "$arch.attention.value_length".u32(),
|
||||||
|
layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
|
||||||
|
layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
|
||||||
|
).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
|
||||||
|
layerNormEpsilon == null && layerNormRmsEpsilon == null
|
||||||
|
}
|
||||||
|
|
||||||
|
val rope = GgufMetadata.RopeInfo(
|
||||||
|
frequencyBase = "$arch.rope.freq_base".f32(),
|
||||||
|
dimensionCount = "$arch.rope.dimension_count".u32(),
|
||||||
|
scalingType = "$arch.rope.scaling.type".str(),
|
||||||
|
scalingFactor = "$arch.rope.scaling.factor".f32(),
|
||||||
|
attnFactor = "$arch.rope.scaling.attn_factor".f32(),
|
||||||
|
originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
|
||||||
|
finetuned = "$arch.rope.scaling.finetuned".bool()
|
||||||
|
).takeUnless { frequencyBase == null && dimensionCount == null &&
|
||||||
|
scalingType == null && scalingFactor == null && attnFactor == null &&
|
||||||
|
originalContextLength == null && finetuned == null
|
||||||
|
}
|
||||||
|
|
||||||
|
val experts = GgufMetadata.ExpertsInfo(
|
||||||
|
count = "$arch.expert_count".u32(),
|
||||||
|
usedCount = "$arch.expert_used_count".u32()
|
||||||
|
).takeUnless { count == null && usedCount == null }
|
||||||
|
|
||||||
|
return GgufMetadata(
|
||||||
|
version = version,
|
||||||
|
tensorCount = tensorCnt,
|
||||||
|
kvCount = kvCnt,
|
||||||
|
basic = basic,
|
||||||
|
author = author,
|
||||||
|
additional = additional,
|
||||||
|
architecture = architectureInfo,
|
||||||
|
baseModels = baseModels,
|
||||||
|
tokenizer = tokenizer,
|
||||||
|
dimensions = dimensions,
|
||||||
|
attention = attention,
|
||||||
|
rope = rope,
|
||||||
|
experts = experts
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recursively parses a metadata value of the given type from the input stream.
|
||||||
|
* @param input The input stream positioned at the start of the value.
|
||||||
|
* @param type The metadata value type to parse.
|
||||||
|
*/
|
||||||
|
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
|
||||||
|
MetadataType.UINT8 -> {
|
||||||
|
// 1-byte unsigned integer
|
||||||
|
val byteVal = input.read()
|
||||||
|
if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
|
||||||
|
MetadataValue.UInt8(byteVal.toUByte())
|
||||||
|
}
|
||||||
|
MetadataType.INT8 -> {
|
||||||
|
// 1-byte signed integer
|
||||||
|
val byteVal = input.read()
|
||||||
|
if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
|
||||||
|
MetadataValue.Int8(byteVal.toByte())
|
||||||
|
}
|
||||||
|
MetadataType.UINT16 -> {
|
||||||
|
// 2-byte unsigned integer (little-endian)
|
||||||
|
val bytes = ByteArray(2)
|
||||||
|
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
|
||||||
|
// Combine two bytes (little-endian) into an unsigned 16-bit value
|
||||||
|
val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||||
|
MetadataValue.UInt16(u16.toUShort())
|
||||||
|
}
|
||||||
|
MetadataType.INT16 -> {
|
||||||
|
// 2-byte signed integer (little-endian)
|
||||||
|
val bytes = ByteArray(2)
|
||||||
|
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
|
||||||
|
// Combine to 16-bit and interpret as signed
|
||||||
|
val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||||
|
MetadataValue.Int16(i16.toShort())
|
||||||
|
}
|
||||||
|
MetadataType.UINT32 -> {
|
||||||
|
// 4-byte unsigned integer (little-endian)
|
||||||
|
val bytes = ByteArray(4)
|
||||||
|
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
|
||||||
|
// Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
|
||||||
|
val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
|
||||||
|
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||||
|
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||||
|
(bytes[0].toLong() and 0xFFL)
|
||||||
|
MetadataValue.UInt32(u32.toUInt())
|
||||||
|
}
|
||||||
|
MetadataType.INT32 -> {
|
||||||
|
// 4-byte signed integer (little-endian)
|
||||||
|
val bytes = ByteArray(4)
|
||||||
|
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
|
||||||
|
// Combine four bytes into a 32-bit signed int
|
||||||
|
val i32 = (bytes[3].toInt() and 0xFF shl 24) or
|
||||||
|
(bytes[2].toInt() and 0xFF shl 16) or
|
||||||
|
(bytes[1].toInt() and 0xFF shl 8) or
|
||||||
|
(bytes[0].toInt() and 0xFF)
|
||||||
|
MetadataValue.Int32(i32)
|
||||||
|
}
|
||||||
|
MetadataType.FLOAT32 -> {
|
||||||
|
// 4-byte IEEE 754 float (little-endian)
|
||||||
|
val bytes = ByteArray(4)
|
||||||
|
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
|
||||||
|
// Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
|
||||||
|
val bits = (bytes[3].toInt() and 0xFF shl 24) or
|
||||||
|
(bytes[2].toInt() and 0xFF shl 16) or
|
||||||
|
(bytes[1].toInt() and 0xFF shl 8) or
|
||||||
|
(bytes[0].toInt() and 0xFF)
|
||||||
|
val floatVal = Float.fromBits(bits)
|
||||||
|
MetadataValue.Float32(floatVal)
|
||||||
|
}
|
||||||
|
MetadataType.BOOL -> {
|
||||||
|
// 1-byte boolean (0 = false, 1 = true)
|
||||||
|
val byteVal = input.read()
|
||||||
|
if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
|
||||||
|
if (byteVal != 0 && byteVal != 1) {
|
||||||
|
throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
|
||||||
|
}
|
||||||
|
MetadataValue.Bool(byteVal != 0)
|
||||||
|
}
|
||||||
|
MetadataType.STRING -> {
|
||||||
|
// UTF-8 string (length-prefixed with 8-byte length)
|
||||||
|
val str = readString(input)
|
||||||
|
MetadataValue.StringVal(str)
|
||||||
|
}
|
||||||
|
MetadataType.ARRAY -> {
|
||||||
|
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||||
|
val len = readLittleLong(input)
|
||||||
|
val count = len.toInt()
|
||||||
|
|
||||||
|
if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
|
||||||
|
// fast‑forward without allocation
|
||||||
|
repeat(count) { skipValue(input, elemType) }
|
||||||
|
MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
|
||||||
|
} else {
|
||||||
|
val list = ArrayList<MetadataValue>(count)
|
||||||
|
repeat(count) { list += parseValue(input, elemType) }
|
||||||
|
MetadataValue.ArrayVal(elemType, list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MetadataType.UINT64 -> {
|
||||||
|
// 8-byte unsigned integer (little-endian)
|
||||||
|
val bytes = ByteArray(8)
|
||||||
|
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
|
||||||
|
// Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
|
||||||
|
val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
|
||||||
|
(bytes[6].toULong() and 0xFFuL shl 48) or
|
||||||
|
(bytes[5].toULong() and 0xFFuL shl 40) or
|
||||||
|
(bytes[4].toULong() and 0xFFuL shl 32) or
|
||||||
|
(bytes[3].toULong() and 0xFFuL shl 24) or
|
||||||
|
(bytes[2].toULong() and 0xFFuL shl 16) or
|
||||||
|
(bytes[1].toULong() and 0xFFuL shl 8) or
|
||||||
|
(bytes[0].toULong() and 0xFFuL)
|
||||||
|
MetadataValue.UInt64(u64)
|
||||||
|
}
|
||||||
|
MetadataType.INT64 -> {
|
||||||
|
// 8-byte signed integer (little-endian)
|
||||||
|
val bytes = ByteArray(8)
|
||||||
|
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
|
||||||
|
// Combine 8 bytes into a signed 64-bit value (Long)
|
||||||
|
val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||||
|
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||||
|
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||||
|
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||||
|
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||||
|
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||||
|
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||||
|
(bytes[0].toLong() and 0xFFL)
|
||||||
|
MetadataValue.Int64(i64)
|
||||||
|
}
|
||||||
|
MetadataType.FLOAT64 -> {
|
||||||
|
// 8-byte IEEE 754 double (little-endian)
|
||||||
|
val bytes = ByteArray(8)
|
||||||
|
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
|
||||||
|
// Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
|
||||||
|
val bits = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||||
|
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||||
|
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||||
|
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||||
|
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||||
|
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||||
|
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||||
|
(bytes[0].toLong() and 0xFFL)
|
||||||
|
val doubleVal = Double.fromBits(bits)
|
||||||
|
            MetadataValue.Float64(doubleVal)
        }
    }

    private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
        this?.takeIf { !it.check() }

    /** Helper: Skip a value in the stream without storing it (still maintains the read position). */
    private fun skipValue(input: InputStream, type: MetadataType) {
        when (type) {
            MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
            MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
            MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
            MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
            MetadataType.STRING -> {
                val len = readLittleLong(input); input.skipFully(len)
            }
            MetadataType.ARRAY -> {
                val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
                val len = readLittleLong(input)
                repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
            }
        }
    }

    /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
    private fun readLittleLong(input: InputStream): Long {
        val bytes = ByteArray(8)
        input.readFully(bytes)

        // Combine 8 bytes into a 64-bit value (little endian).
        // Note: if the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
        // In our context (lengths/counts), such extremely large values are not expected.
        return (bytes[7].toLong() and 0xFFL shl 56) or
            (bytes[6].toLong() and 0xFFL shl 48) or
            (bytes[5].toLong() and 0xFFL shl 40) or
            (bytes[4].toLong() and 0xFFL shl 32) or
            (bytes[3].toLong() and 0xFFL shl 24) or
            (bytes[2].toLong() and 0xFFL shl 16) or
            (bytes[1].toLong() and 0xFFL shl 8) or
            (bytes[0].toLong() and 0xFFL)
    }

    /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
    private fun readString(input: InputStream): String =
        // Read 8-byte little-endian length (number of bytes in the string).
        readLittleLong(input).let { len ->
            if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")

            // Read the UTF-8 bytes of the given length.
            ByteArray(len.toInt()).let {
                if (it.isNotEmpty()) input.readFully(it)
                String(it, Charsets.UTF_8)
            }
        }

    /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
    private fun littleEndianBytesToInt(bytes: ByteArray): Int =
        // Note: assumes bytes length is 4.
        (bytes[3].toInt() and 0xFF shl 24) or
            (bytes[2].toInt() and 0xFF shl 16) or
            (bytes[1].toInt() and 0xFF shl 8) or
            (bytes[0].toInt() and 0xFF)

    /**
     * Robust skip that works the same on JDK 11 and Android's desugared runtime.
     * Falls back to `read()` when `skip()` returns 0, which happens on some Android streams.
     *
     * @param n Number of bytes to advance in the stream.
     * @throws IOException on premature EOF.
     */
    private fun InputStream.skipFully(n: Long) {
        var remaining = n
        val scratch = ByteArray(8192) // read-and-toss buffer
        while (remaining > 0) {
            val skipped = skip(remaining)
            when {
                skipped > 0 -> remaining -= skipped // normal fast path
                skipped == 0L -> {
                    // fallback: read and discard
                    val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
                    if (read == -1) throw IOException("EOF while skipping $n bytes")
                    remaining -= read
                }
                else -> throw IOException("Skip returned negative value")
            }
        }
    }

    /**
     * Extension that keeps reading until the requested number of bytes are filled,
     * since a single `read()` call may legally return fewer bytes than asked for.
     *
     * @param buf Destination buffer.
     * @param len Number of bytes to fill (defaults to `buf.size`).
     * @throws IOException on premature EOF.
     */
    private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
        var off = 0
        while (off < len) {
            val n = read(buf, off, len - off)
            if (n == -1) throw IOException("EOF after $off of $len bytes")
            off += n
        }
    }

    /**
     * Read EXACTLY `n` bytes or throw - never returns a partially-filled array.
     * This is used for small fixed-length reads (e.g. 4-byte type codes).
     *
     * @throws IOException on premature EOF.
     */
    private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
        readFully(it) // loop until filled; a single read() may return fewer bytes without hitting EOF
    }
}
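The shift/or chains in `readLittleLong` and `littleEndianBytesToInt` implement standard little-endian decoding; the same result can be obtained with `java.nio.ByteBuffer`. A minimal sketch for cross-checking the manual version (helper name hypothetical, not part of the file above):

```kotlin
import java.io.IOException
import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Equivalent little-endian decoding via ByteBuffer: reads exactly 8 bytes
// and interprets them like the shift/or chain in readLittleLong.
fun readLittleLongViaByteBuffer(input: InputStream): Long {
    val bytes = ByteArray(8)
    var off = 0
    while (off < 8) {
        val n = input.read(bytes, off, 8 - off)
        if (n == -1) throw IOException("EOF after $off of 8 bytes")
        off += n
    }
    return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).long
}
```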
@@ -1,71 +0,0 @@
plugins {
    id("com.android.library")
    id("org.jetbrains.kotlin.android")
}

android {
    namespace = "android.llama.cpp"
    compileSdk = 34

    defaultConfig {
        minSdk = 33

        testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
        consumerProguardFiles("consumer-rules.pro")
        ndk {
            // Add NDK properties if wanted, e.g.
            // abiFilters += listOf("arm64-v8a")
        }
        externalNativeBuild {
            cmake {
                arguments += "-DLLAMA_CURL=OFF"
                arguments += "-DLLAMA_BUILD_COMMON=ON"
                arguments += "-DGGML_LLAMAFILE=OFF"
                arguments += "-DCMAKE_BUILD_TYPE=Release"
                cppFlags += listOf()
                arguments += listOf()

                cppFlags("")
            }
        }
    }

    buildTypes {
        release {
            isMinifyEnabled = false
            proguardFiles(
                getDefaultProguardFile("proguard-android-optimize.txt"),
                "proguard-rules.pro"
            )
        }
    }
    externalNativeBuild {
        cmake {
            path("src/main/cpp/CMakeLists.txt")
            version = "3.22.1"
        }
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_1_8
        targetCompatibility = JavaVersion.VERSION_1_8
    }
    kotlinOptions {
        jvmTarget = "1.8"
    }

    packaging {
        resources {
            excludes += "/META-INF/{AL2.0,LGPL2.1}"
        }
    }
}

dependencies {

    implementation("androidx.core:core-ktx:1.12.0")
    implementation("androidx.appcompat:appcompat:1.6.1")
    implementation("com.google.android.material:material:1.11.0")
    testImplementation("junit:junit:4.13.2")
    androidTestImplementation("androidx.test.ext:junit:1.1.5")
    androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
}
@@ -1,53 +0,0 @@
# For more information about using CMake with Android Studio, read the
# documentation: https://d.android.com/studio/projects/add-native-code.html.
# For more examples on how to use CMake, see https://github.com/android/ndk-samples.

# Sets the minimum CMake version required for this project.
cmake_minimum_required(VERSION 3.22.1)

# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
# Since this is the top level CMakeLists.txt, the project name is also accessible
# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
# build script scope).
project("llama-android")

#include(FetchContent)
#FetchContent_Declare(
#        llama
#        GIT_REPOSITORY https://github.com/ggml-org/llama.cpp
#        GIT_TAG        master
#)

# Also provides "common"
#FetchContent_MakeAvailable(llama)

# Creates and names a library, sets it as either STATIC
# or SHARED, and provides the relative paths to its source code.
# You can define multiple libraries, and CMake builds them for you.
# Gradle automatically packages shared libraries with your APK.
#
# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
# is preferred for the same purpose.
#

#load local llama.cpp
add_subdirectory(../../../../../../ build-llama)

# In order to load a library into your app from Java/Kotlin, you must call
# System.loadLibrary() and pass the name of the library defined here;
# for GameActivity/NativeActivity derived applications, the same library name must be
# used in the AndroidManifest.xml file.
add_library(${CMAKE_PROJECT_NAME} SHARED
    # List C/C++ source files with relative paths to this CMakeLists.txt.
    llama-android.cpp)

# Specifies libraries CMake should link to your target library. You
# can link libraries from various origins, such as libraries defined in this
# build script, prebuilt third-party libraries, or Android system libraries.
target_link_libraries(${CMAKE_PROJECT_NAME}
    # List libraries link to the target library
    llama
    common
    android
    log)
@@ -1,452 +0,0 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <math.h>
#include <string>
#include <unistd.h>
#include "llama.h"
#include "common.h"

// Write C++ code here.
//
// Do not forget to dynamically load the C++ library into your application.
//
// For instance,
//
// In MainActivity.java:
//    static {
//       System.loadLibrary("llama-android");
//    }
//
// Or, in MainActivity.kt:
//    companion object {
//      init {
//         System.loadLibrary("llama-android")
//      }
//    }

#define TAG "llama-android.cpp"
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)

jclass la_int_var;
jmethodID la_int_var_value;
jmethodID la_int_var_inc;

std::string cached_token_chars;

bool is_valid_utf8(const char * string) {
    if (!string) {
        return true;
    }

    const unsigned char * bytes = (const unsigned char *)string;
    int num;

    while (*bytes != 0x00) {
        if ((*bytes & 0x80) == 0x00) {
            // U+0000 to U+007F
            num = 1;
        } else if ((*bytes & 0xE0) == 0xC0) {
            // U+0080 to U+07FF
            num = 2;
        } else if ((*bytes & 0xF0) == 0xE0) {
            // U+0800 to U+FFFF
            num = 3;
        } else if ((*bytes & 0xF8) == 0xF0) {
            // U+10000 to U+10FFFF
            num = 4;
        } else {
            return false;
        }

        bytes += 1;
        for (int i = 1; i < num; ++i) {
            if ((*bytes & 0xC0) != 0x80) {
                return false;
            }
            bytes += 1;
        }
    }

    return true;
}

static void log_callback(ggml_log_level level, const char * fmt, void * data) {
    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
}

extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
    llama_model_params model_params = llama_model_default_params();

    auto path_to_model = env->GetStringUTFChars(filename, 0);
    LOGi("Loading model from %s", path_to_model);

    auto model = llama_model_load_from_file(path_to_model, model_params);
    env->ReleaseStringUTFChars(filename, path_to_model);

    if (!model) {
        LOGe("load_model() failed");
        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
        return 0;
    }

    return reinterpret_cast<jlong>(model);
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
    llama_model_free(reinterpret_cast<llama_model *>(model));
}

extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
    auto model = reinterpret_cast<llama_model *>(jmodel);

    if (!model) {
        LOGe("new_context(): model cannot be null");
        env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
        return 0;
    }

    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
    LOGi("Using %d threads", n_threads);

    llama_context_params ctx_params = llama_context_default_params();

    ctx_params.n_ctx           = 2048;
    ctx_params.n_threads       = n_threads;
    ctx_params.n_threads_batch = n_threads;

    llama_context * context = llama_new_context_with_model(model, ctx_params);

    if (!context) {
        LOGe("llama_new_context_with_model() returned null)");
        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
                      "llama_new_context_with_model() returned null)");
        return 0;
    }

    return reinterpret_cast<jlong>(context);
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
    llama_free(reinterpret_cast<llama_context *>(context));
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
    llama_backend_free();
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
    llama_log_set(log_callback, NULL);
}

extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_bench_1model(
        JNIEnv *env,
        jobject,
        jlong context_pointer,
        jlong model_pointer,
        jlong batch_pointer,
        jint pp,
        jint tg,
        jint pl,
        jint nr
    ) {
    auto pp_avg = 0.0;
    auto tg_avg = 0.0;
    auto pp_std = 0.0;
    auto tg_std = 0.0;

    const auto context = reinterpret_cast<llama_context *>(context_pointer);
    const auto model = reinterpret_cast<llama_model *>(model_pointer);
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);

    const int n_ctx = llama_n_ctx(context);

    LOGi("n_ctx = %d", n_ctx);

    int i, j;
    int nri;
    for (nri = 0; nri < nr; nri++) {
        LOGi("Benchmark prompt processing (pp)");

        common_batch_clear(*batch);

        const int n_tokens = pp;
        for (i = 0; i < n_tokens; i++) {
            common_batch_add(*batch, 0, i, { 0 }, false);
        }

        batch->logits[batch->n_tokens - 1] = true;
        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp_start = ggml_time_us();
        if (llama_decode(context, *batch) != 0) {
            LOGi("llama_decode() failed during prompt processing");
        }
        const auto t_pp_end = ggml_time_us();

        // bench text generation

        LOGi("Benchmark text generation (tg)");

        llama_memory_clear(llama_get_memory(context), false);
        const auto t_tg_start = ggml_time_us();
        for (i = 0; i < tg; i++) {

            common_batch_clear(*batch);
            for (j = 0; j < pl; j++) {
                common_batch_add(*batch, 0, i, { j }, true);
            }

            LOGi("llama_decode() text generation: %d", i);
            if (llama_decode(context, *batch) != 0) {
                LOGi("llama_decode() failed during text generation");
            }
        }

        const auto t_tg_end = ggml_time_us();

        llama_memory_clear(llama_get_memory(context), false);

        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;

        const auto speed_pp = double(pp) / t_pp;
        const auto speed_tg = double(pl * tg) / t_tg;

        pp_avg += speed_pp;
        tg_avg += speed_tg;

        pp_std += speed_pp * speed_pp;
        tg_std += speed_tg * speed_tg;

        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
    }

    pp_avg /= double(nr);
    tg_avg /= double(nr);

    if (nr > 1) {
        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
    } else {
        pp_std = 0;
        tg_std = 0;
    }

    char model_desc[128];
    llama_model_desc(model, model_desc, sizeof(model_desc));

    const auto model_size     = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;

    const auto backend = "(Android)"; // TODO: What should this be?

    std::stringstream result;
    result << std::setprecision(2);
    result << "| model | size | params | backend | test | t/s |\n";
    result << "| --- | --- | --- | --- | --- | --- |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";

    return env->NewStringUTF(result.str().c_str());
}

extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {

    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.

    llama_batch *batch = new llama_batch {
        0,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
    };

    if (embd) {
        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
    } else {
        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
    }

    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);

    return reinterpret_cast<jlong>(batch);
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
    //llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
    delete batch;
}

extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
    auto sparams = llama_sampler_chain_default_params();
    sparams.no_perf = true;
    llama_sampler * smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    return reinterpret_cast<jlong>(smpl);
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
    llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
    llama_backend_init();
}

extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
    return env->NewStringUTF(llama_print_system_info());
}

extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1init(
        JNIEnv *env,
        jobject,
        jlong context_pointer,
        jlong batch_pointer,
        jstring jtext,
        jboolean format_chat,
        jint n_len
    ) {

    cached_token_chars.clear();

    const auto text = env->GetStringUTFChars(jtext, 0);
    const auto context = reinterpret_cast<llama_context *>(context_pointer);
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);

    bool parse_special = (format_chat == JNI_TRUE);
    const auto tokens_list = common_tokenize(context, text, true, parse_special);

    auto n_ctx = llama_n_ctx(context);
    auto n_kv_req = tokens_list.size() + n_len;

    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);

    if (n_kv_req > n_ctx) {
        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
    }

    for (auto id : tokens_list) {
        LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
    }

    common_batch_clear(*batch);

    // evaluate the initial prompt
    for (auto i = 0; i < tokens_list.size(); i++) {
        common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
    }

    // llama_decode will output logits only for the last token of the prompt
    batch->logits[batch->n_tokens - 1] = true;

    if (llama_decode(context, *batch) != 0) {
        LOGe("llama_decode() failed");
    }

    env->ReleaseStringUTFChars(jtext, text);

    return batch->n_tokens;
}

extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1loop(
        JNIEnv * env,
        jobject,
        jlong context_pointer,
        jlong batch_pointer,
        jlong sampler_pointer,
        jint n_len,
        jobject intvar_ncur
    ) {
    const auto context = reinterpret_cast<llama_context *>(context_pointer);
    const auto batch   = reinterpret_cast<llama_batch   *>(batch_pointer);
    const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
    const auto model = llama_get_model(context);
    const auto vocab = llama_model_get_vocab(model);

    if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");

    // sample the most likely token
    const auto new_token_id = llama_sampler_sample(sampler, context, -1);

    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
    if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
        return nullptr;
    }

    auto new_token_chars = common_token_to_piece(context, new_token_id);
    cached_token_chars += new_token_chars;

    jstring new_token = nullptr;
    if (is_valid_utf8(cached_token_chars.c_str())) {
        new_token = env->NewStringUTF(cached_token_chars.c_str());
        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
        cached_token_chars.clear();
    } else {
        new_token = env->NewStringUTF("");
    }

    common_batch_clear(*batch);
    common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);

    env->CallVoidMethod(intvar_ncur, la_int_var_inc);

    if (llama_decode(context, *batch) != 0) {
        LOGe("llama_decode() returned null");
    }

    return new_token;
}

extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
    llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
}
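`completion_loop` buffers token pieces in `cached_token_chars` and only returns them once `is_valid_utf8` accepts the buffer, because a single token may end partway through a multi-byte UTF-8 character. A rough Kotlin analogue of the same buffering idea, assuming token pieces arrive as raw bytes (class name hypothetical, not part of the bindings above):

```kotlin
import java.nio.ByteBuffer
import java.nio.charset.CharacterCodingException
import java.nio.charset.StandardCharsets

// Accumulates raw token bytes and only emits text once the buffer decodes
// cleanly, mirroring the cached_token_chars logic on the native side.
class Utf8Accumulator {
    private var pending = ByteArray(0)

    fun push(tokenBytes: ByteArray): String? {
        pending += tokenBytes
        // A fresh decoder reports (throws on) malformed input by default.
        val decoder = StandardCharsets.UTF_8.newDecoder()
        return try {
            val text = decoder.decode(ByteBuffer.wrap(pending)).toString()
            pending = ByteArray(0)
            text
        } catch (e: CharacterCodingException) {
            null // keep buffering until the multi-byte sequence completes
        }
    }
}
```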
@@ -1,180 +0,0 @@
package android.llama.cpp

import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import kotlin.concurrent.thread

class LLamaAndroid {
    private val tag: String? = this::class.simpleName

    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }

    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
        thread(start = false, name = "Llm-RunLoop") {
            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")

            // No-op if called more than once.
            System.loadLibrary("llama-android")

            // Set llama log handler to Android
            log_to_android()
            backend_init(false)

            Log.d(tag, system_info())

            it.run()
        }.apply {
            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
                Log.e(tag, "Unhandled exception", exception)
            }
        }
    }.asCoroutineDispatcher()

    private val nlen: Int = 64

    private external fun log_to_android()
    private external fun load_model(filename: String): Long
    private external fun free_model(model: Long)
    private external fun new_context(model: Long): Long
    private external fun free_context(context: Long)
    private external fun backend_init(numa: Boolean)
    private external fun backend_free()
    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
    private external fun free_batch(batch: Long)
    private external fun new_sampler(): Long
    private external fun free_sampler(sampler: Long)
    private external fun bench_model(
        context: Long,
        model: Long,
        batch: Long,
        pp: Int,
        tg: Int,
        pl: Int,
        nr: Int
    ): String

    private external fun system_info(): String

    private external fun completion_init(
        context: Long,
        batch: Long,
        text: String,
        formatChat: Boolean,
        nLen: Int
    ): Int

    private external fun completion_loop(
        context: Long,
        batch: Long,
        sampler: Long,
        nLen: Int,
        ncur: IntVar
    ): String?

    private external fun kv_cache_clear(context: Long)

    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
        return withContext(runLoop) {
            when (val state = threadLocalState.get()) {
                is State.Loaded -> {
                    Log.d(tag, "bench(): $state")
                    bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
                }

                else -> throw IllegalStateException("No model loaded")
            }
        }
    }

    suspend fun load(pathToModel: String) {
        withContext(runLoop) {
            when (threadLocalState.get()) {
                is State.Idle -> {
                    val model = load_model(pathToModel)
                    if (model == 0L) throw IllegalStateException("load_model() failed")

                    val context = new_context(model)
                    if (context == 0L) throw IllegalStateException("new_context() failed")

                    val batch = new_batch(512, 0, 1)
                    if (batch == 0L) throw IllegalStateException("new_batch() failed")

                    val sampler = new_sampler()
                    if (sampler == 0L) throw IllegalStateException("new_sampler() failed")

                    Log.i(tag, "Loaded model $pathToModel")
                    threadLocalState.set(State.Loaded(model, context, batch, sampler))
                }
                else -> throw IllegalStateException("Model already loaded")
            }
        }
    }

    fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
        when (val state = threadLocalState.get()) {
            is State.Loaded -> {
                val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
                while (ncur.value <= nlen) {
                    val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
                    if (str == null) {
                        break
                    }
                    emit(str)
                }
                kv_cache_clear(state.context)
            }
            else -> {}
        }
    }.flowOn(runLoop)

    /**
     * Unloads the model and frees resources.
     *
     * This is a no-op if there's no model loaded.
     */
    suspend fun unload() {
        withContext(runLoop) {
            when (val state = threadLocalState.get()) {
                is State.Loaded -> {
                    free_context(state.context)
                    free_model(state.model)
                    free_batch(state.batch)
                    free_sampler(state.sampler);

                    threadLocalState.set(State.Idle)
                }
                else -> {}
            }
        }
    }

    companion object {
        private class IntVar(value: Int) {
            @Volatile
            var value: Int = value
                private set

            fun inc() {
                synchronized(this) {
                    value += 1
                }
            }
        }

        private sealed interface State {
            data object Idle: State
            data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
        }

        // Enforce only one instance of Llm.
        private val _instance: LLamaAndroid = LLamaAndroid()

        fun instance(): LLamaAndroid = _instance
    }
}
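For reference, this wrapper was driven roughly as follows from a coroutine scope: load a model, collect the streaming completion, then free the native resources. A minimal sketch (the model path and driver function are hypothetical):

```kotlin
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.launch

// Minimal driver for the LLamaAndroid singleton: load, stream, unload.
fun runCompletion(scope: CoroutineScope) {
    val llm = LLamaAndroid.instance()
    scope.launch(Dispatchers.Main) {
        llm.load("/data/local/tmp/model.gguf") // hypothetical path
        llm.send("Hello, my name is", formatChat = false)
            .collect { piece -> print(piece) } // tokens arrive as they decode
        llm.unload()
    }
}
```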
@@ -8,11 +8,11 @@ pluginManagement {
 dependencyResolutionManagement {
     repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
     repositories {
-        google()
         mavenCentral()
+        google()
     }
 }
 
-rootProject.name = "LlamaAndroid"
+rootProject.name = "AiChat"
 include(":app")
-include(":llama")
+include(":lib")
@@ -5,7 +5,7 @@ import os
 import importlib
 from pathlib import Path
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np
 
@@ -116,11 +116,11 @@ def debug_hook(name):
     def fn(_m, input, output):
         if isinstance(input, torch.Tensor):
             summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
             summarize(input[0], name + "_in")
         if isinstance(output, torch.Tensor):
             summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
             summarize(output[0], name + "_out")
 
     return fn
@@ -130,6 +130,7 @@ unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
 
 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
+parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
 args = parser.parse_args()
 
 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -142,8 +143,13 @@ if model_path is None:
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+multimodal = False
+full_config = config
 
 print("Model type: ", config.model_type)
+if "vocab_size" not in config and "text_config" in config:
+    config = config.text_config
+    multimodal = True
 print("Vocab size: ", config.vocab_size)
 print("Hidden size: ", config.hidden_size)
 print("Number of layers: ", config.num_hidden_layers)
@@ -169,9 +175,14 @@ if unreleased_model_name:
         print(f"Failed to import or load model: {e}")
         exit(1)
 else:
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
-    )
+    if multimodal:
+        model = AutoModelForImageTextToText.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
+        )
 
 for name, module in model.named_modules():
     if len(list(module.children())) == 0:  # only leaf modules
@@ -185,7 +196,10 @@ model_name = os.path.basename(model_path)
 print(f"Model class: {model.__class__.__name__}")
 
 device = next(model.parameters()).device
-if os.getenv("MODEL_TESTING_PROMPT"):
+if args.prompt_file:
+    with open(args.prompt_file, encoding='utf-8') as f:
+        prompt = f.read()
+elif os.getenv("MODEL_TESTING_PROMPT"):
     prompt = os.getenv("MODEL_TESTING_PROMPT")
 else:
     prompt = "Hello, my name is"
@@ -195,9 +209,18 @@ print(f"Input tokens: {input_ids}")
 print(f"Input text: {repr(prompt)}")
 print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
 
+batch_size = 512
+
 with torch.no_grad():
-    outputs = model(input_ids.to(model.device))
-    logits = outputs.logits
+    past = None
+    outputs = None
+    for i in range(0, input_ids.size(1), batch_size):
+        print(f"Processing chunk with tokens {i} to {i + batch_size}")
+        chunk = input_ids[:, i:i + batch_size]
+        outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
+        past = outputs.past_key_values
+
+    logits = outputs.logits  # type: ignore
 
     # Extract logits for the last token (next token prediction)
     last_logits = logits[0, -1, :].float().cpu().numpy()
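The rewritten evaluation loop walks the prompt in windows of `batch_size` tokens and threads `past_key_values` through each call, so the prompt is evaluated incrementally rather than in one forward pass. The same windowing pattern in a minimal Kotlin sketch (all names hypothetical; `step` stands in for the model call):

```kotlin
// Generic chunked-prefill skeleton: walk the tokens in windows of
// `batchSize`, carrying opaque state (the KV cache) between calls.
fun <S> prefillInChunks(
    tokens: IntArray,
    batchSize: Int,
    initial: S,
    step: (chunk: IntArray, state: S) -> S, // hypothetical model call
): S {
    var state = initial
    for (i in tokens.indices step batchSize) {
        val chunk = tokens.copyOfRange(i, minOf(i + batchSize, tokens.size))
        state = step(chunk, state)
    }
    return state
}
```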
@@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
             bool accept = false;
             if (params.sampling.temp > 0) {
                 // stochastic verification
-                common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
+                common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
                 auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
 
@@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
+            common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
 
             const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
 
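The "stochastic verification" branch above refers to the standard acceptance rule from speculative sampling: a drafted token is accepted with probability min(1, p_target / p_draft), which preserves the target model's distribution. A Kotlin sketch of that rule in isolation (not the code above; function name hypothetical):

```kotlin
import kotlin.random.Random

// Standard speculative-sampling acceptance test: accept the draft token
// with probability min(1, pTarget / pDraft).
fun acceptDraftToken(pTarget: Double, pDraft: Double, rng: Random = Random.Default): Boolean {
    require(pDraft > 0.0) { "draft probability must be positive" }
    return rng.nextDouble() < minOf(1.0, pTarget / pDraft)
}
```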
@@ -386,6 +386,9 @@ if (GGML_CPU_ALL_VARIANTS)
         ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD)
         ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
         ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
+        ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
+        ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
+        ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
     elseif (APPLE)
         ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
         ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
@@ -43,6 +43,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -51,6 +53,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
@@ -67,10 +71,14 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__POWERPC__) || defined(__powerpc__)
 // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
 // quants.c
@@ -91,6 +99,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -99,6 +109,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__loongarch64)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -119,6 +131,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -127,6 +141,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__riscv)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -154,6 +170,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
@@ -161,6 +179,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -187,6 +207,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -195,6 +217,8 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -223,6 +247,8 @@
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -231,4 +257,6 @@
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #endif
@@ -786,6 +786,133 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q8_0_4x4_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret = vdupq_n_s32(0);
+
+            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x8_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
+            int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
+            int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
+            int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
+            int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret0 = vdupq_n_s32(0);
+            int32x4_t ret1 = vdupq_n_s32(0);
+
+            // 0..7
+            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
+            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
+            // 8..15
+            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
+            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
+            // 16..23
+            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
+            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
+            // 24..31
+            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
+            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
+
+            int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -2610,3 +2737,159 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_q8_0_4x4_q8_0(int n,
                             float * GGML_RESTRICT s,
                             size_t bs,
                             const void * GGML_RESTRICT vx,
                             const void * GGML_RESTRICT vy,
                             int nr,
                             int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);

            float32x4_t sumf[4];
            for (int m = 0; m < 4; m++) {
                sumf[m] = vdupq_n_f32(0);
            }

            for (int l = 0; l < nb; l++) {
                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));

                int32x4_t sumi_0 = vdupq_n_s32(0);
                int32x4_t sumi_1 = vdupq_n_s32(0);
                int32x4_t sumi_2 = vdupq_n_s32(0);
                int32x4_t sumi_3 = vdupq_n_s32(0);

                for (int k_group = 0; k_group < 8; k_group += 4) {
                    int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
                    int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);

                    for (int k = 0; k < 4; k++) {
                        sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
                        sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
                        sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
                        sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
                    }
                }

                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
            }

            for (int m = 0; m < 4; m++) {
                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
            }
        }
    }
    return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_q8_0_4x8_q8_0(int n,
                             float * GGML_RESTRICT s,
                             size_t bs,
                             const void * GGML_RESTRICT vx,
                             const void * GGML_RESTRICT vy,
                             int nr,
                             int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 8;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;

    for (int y = 0; y < nr; y += 4) {
        const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;

        for (int x = 0; x < nc; x += ncols_interleaved) {
            const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
            const block_q8_0x4 * a_ptr = a_ptr_base;

            float32x4_t acc_f32[4];
            for (int i = 0; i < 4; i++) {
                acc_f32[i] = vdupq_n_f32(0);
            }

            for (int b = 0; b < nb; b++) {
                int32x4_t acc[4];
                for (int i = 0; i < 4; i++) {
                    acc[i] = vdupq_n_s32(0);
                }

                // Process 4 chunks of 8 positions each
                for (int chunk = 0; chunk < 4; chunk++) {
                    int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
                    int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
                    int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
                    int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);

                    acc[0] = vmmlaq_s32(acc[0], a01, b01);
                    acc[1] = vmmlaq_s32(acc[1], a01, b23);
                    acc[2] = vmmlaq_s32(acc[2], a23, b01);
                    acc[3] = vmmlaq_s32(acc[3], a23, b23);
                }

                // Reorder outputs from 2×2 tiles to row-major
                // acc[0] = [r0c0, r0c1, r1c0, r1c1]
                // acc[1] = [r0c2, r0c3, r1c2, r1c3]
                // acc[2] = [r2c0, r2c1, r3c0, r3c1]
                // acc[3] = [r2c2, r2c3, r3c2, r3c3]
                int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
                int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
                int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
                int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));

                // Scales
                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));

                acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
                acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
                acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
                acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));

                a_ptr++;
                b_ptr++;
            }

            for (int row = 0; row < 4; row++) {
                vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
            }
        }
    }
    return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
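The reorder in the i8mm path follows directly from the SMMLA output layout documented in the comments above. A scalar model of one `vmmlaq_s32` call (an illustrative sketch; `smmla_ref` is a made-up name, not from the patch):

#include <stdint.h>

// Each 16-byte operand holds two rows of eight int8 values; the four int32
// output lanes form a row-major 2x2 tile of dot products, which is why the
// vget_low/vget_high recombination afterwards yields full 4-wide output rows.
static inline void smmla_ref(int32_t r[4], const int8_t a[2][8], const int8_t b[2][8]) {
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 8; k++) {
                r[2 * i + j] += (int32_t) a[i][k] * (int32_t) b[j][k];
            }
        }
    }
}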
@@ -692,6 +692,100 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
    }
}

void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
                                     float * GGML_RESTRICT s,
                                     size_t bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int nr,
                                     int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[4];
    int sumi;

    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / blocklen); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j];
        }
    }
}

void ggml_gemv_q8_0_4x8_q8_0_generic(int n,
                                     float * GGML_RESTRICT s,
                                     size_t bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int nr,
                                     int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 8;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[4];
    int sumi;

    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / blocklen); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j];
        }
    }
}

void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
@@ -1219,8 +1313,129 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
    }
}

void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
                                     float * GGML_RESTRICT s,
                                     size_t bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int nr,
                                     int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    float sumf[4][4];
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / blocklen); k++) {
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
                            }
                            sumf[m][j] +=
                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
                }
            }
        }
    }
}

void ggml_gemm_q8_0_4x8_q8_0_generic(int n,
                                     float * GGML_RESTRICT s,
                                     size_t bs,
                                     const void * GGML_RESTRICT vx,
                                     const void * GGML_RESTRICT vy,
                                     int nr,
                                     int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 8;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    float sumf[4][4];
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / blocklen); k++) {
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
                            }
                            sumf[m][j] +=
                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
                }
            }
        }
    }
}

} // extern "C"

static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
    block_q8_0x4 out;

    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i].d;
    }

    const int end = QK8_0 * 4 / blck_size_interleave;
    for (int i = 0; i < end; ++i) {
        int src_id = i % 4;
        int src_offset = (i / 4) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;
        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
    }
    return out;
}
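To make the index arithmetic concrete: with `blck_size_interleave == 4` the loop runs `end = 32` times, and destination chunk `i` is the `(i / 4)`-th 4-byte group of source row `i % 4`. A standalone sketch of the same mapping (illustrative; `interleave4_ref` is a hypothetical helper operating on plain byte arrays):

#include <stdint.h>
#include <string.h>

// Weaves four 32-byte rows together one `group`-sized chunk at a time:
// dst = r0[0..g) r1[0..g) r2[0..g) r3[0..g) r0[g..2g) r1[g..2g) ...
static void interleave4_ref(uint8_t * dst, const uint8_t src[4][32], int group) {
    const int end = 32 * 4 / group;
    for (int i = 0; i < end; ++i) {
        memcpy(dst + i * group, &src[i % 4][(i / 4) * group], group);
    }
}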

static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
    block_q4_0x4 out;

@@ -1534,6 +1749,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
    GGML_UNUSED(data_size);
}

static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t,
                                    int interleave_block,
                                    const void * GGML_RESTRICT data,
                                    size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
    constexpr int nrows_interleaved = 4;

    block_q8_0x4 * dst = (block_q8_0x4 *) t->data;
    const block_q8_0 * src = (const block_q8_0 *) data;
    block_q8_0 dst_tmp[4];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK8_0;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }

    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
}
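The effect of the repack on memory order, sketched for one group of four rows (illustrative, not from the patch):

// src (row-major, nblocks blocks per row):
//   r0b0 r0b1 ... r0b{n-1} | r1b0 r1b1 ... | r2b0 ... | r3b0 ...
// dst (block_q8_0x4, block-index-major within each 4-row group):
//   [r0b0 r1b0 r2b0 r3b0] [r0b1 r1b1 r2b1 r3b1] ... [r0b{n-1} ... r3b{n-1}]
// so the GEMV/GEMM kernels can walk a 4-column tile with a single pointer increment.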

static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
    block_iq4_nlx4 out;

@@ -1702,6 +1949,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
    return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
}

template <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
}

// gemv
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemv(int, float *, size_t, const void *, const void *, int, int);

@@ -1738,6 +1993,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
    ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

// gemm
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemm(int, float *, size_t, const void *, const void *, int, int);

@@ -1774,6 +2037,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
    ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

class tensor_traits_base : public ggml::cpu::tensor_traits {
  public:
    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;

@@ -2168,6 +2439,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;

    // instance for Q8_0
    static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;

    if (cur->type == GGML_TYPE_Q4_0) {
        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
            || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {

@@ -2218,6 +2493,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                return &iq4_nl_4x4_q8_0;
            }
        }
    } else if (cur->type == GGML_TYPE_Q8_0) {
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 4 == 0) {
                return &q8_0_4x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &q8_0_4x4_q8_0;
            }
        }
    }

    return nullptr;

@@ -98,6 +98,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);

@@ -120,6 +124,10 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

#if defined(__cplusplus)
} // extern "C"

@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
    }

#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
+    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
        const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
        if (val > maxval) {

@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
        argmax = shared_argmax[lane_id];
    }
#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
+    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
        const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
        if (val > maxval) {
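Replacing the hard-coded 16 with `WARP_SIZE/2` generalizes the XOR-butterfly reduction to 64-wide warps on AMD. A minimal sketch of the same pattern (illustrative only; `warp_reduce_max_ref` is a made-up name, and on HIP the shuffle is assumed to go through ggml's compatibility wrappers):

__device__ float warp_reduce_max_ref(float v) {
    // After log2(WARP_SIZE) butterfly steps every lane holds the warp maximum.
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset, WARP_SIZE));
    }
    return v;
}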
@@ -76,15 +76,31 @@ namespace ggml_cuda_mma {
    // For the A/C matrices this means I major == row major, J major == column major.
    // For the B matrix this means I major == column major, J major == row major.
    // MIRRORED == Each data value is held exactly once per thread subgroup.
-        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
-        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
-        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR          = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
+        DATA_LAYOUT_I_MAJOR_DUAL     = 40, // Matrix A&B for RDNA3.
    };
    // Implemented mma combinations are:
    // - (I_MAJOR, I_MAJOR)          -> I_MAJOR
    // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
    // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR

+    constexpr bool is_i_major(const data_layout dl) {
+        return dl == DATA_LAYOUT_I_MAJOR ||
+               dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
+               dl == DATA_LAYOUT_I_MAJOR_DUAL;
+    }
+
+    constexpr data_layout get_input_data_layout() {
+#if defined(RDNA3)
+        return DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+        return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3)
+    }
+
    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
    struct tile {};

@@ -115,9 +131,9 @@ namespace ggml_cuda_mma {
            } else if constexpr (I == 32 && J == 4) {
                return threadIdx.x % 32;
            } else if constexpr (I == 16 && J == 16) {
-                return 4 * (threadIdx.x / 16) + l;
+                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 32) {
-                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+                return threadIdx.x % 32;
            } else {
                NO_DEVICE_CODE;
                return -1;

@@ -132,9 +148,9 @@ namespace ggml_cuda_mma {
            } else if constexpr (I == 32 && J == 4) {
                return 2 * (threadIdx.x / 32) + l;
            } else if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                return 4 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 32) {
-                return threadIdx.x % 32;
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;

@@ -171,28 +187,19 @@ namespace ggml_cuda_mma {
            }
        }
#elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA4)
        static constexpr int ne = I * J / 32;
-#elif defined(RDNA3)
-        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif // defined(RDNA4)
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
-                return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;

@@ -201,7 +208,17 @@ namespace ggml_cuda_mma {

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                // matrix C
+#if defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+            } else if constexpr (I == 16 && J == 8) {
+                // mmq input for RDNA4
+                return ne * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 4) {
+                return ne * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
                return -1;

@@ -293,12 +310,7 @@ namespace ggml_cuda_mma {
            }
        }
#elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
        static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {

@@ -317,14 +329,7 @@ namespace ggml_cuda_mma {

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
                return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
            } else {
                NO_DEVICE_CODE;
                return -1;

@@ -382,42 +387,19 @@ namespace ggml_cuda_mma {
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;

#if defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
        static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
-            if (I == 16 && J == 8) return true;
-            return false;
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
-                return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
        }
#else
        static constexpr int ne = I * J / WARP_SIZE;

@@ -458,6 +440,28 @@ namespace ggml_cuda_mma {
#endif // defined(AMD_WMMA_AVAILABLE)
    };

    template <int I_, int J_, typename T>
    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;

        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
        }

        static __device__ __forceinline__ int get_i(const int l) {
            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
        }

        static __device__ __forceinline__ int get_j(const int l) {
            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
        }
    };
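Since `get_i` and `get_j` are simply swapped, a J-major tile is addressed as the transpose of its I-major counterpart, and any loop written against the accessors stays layout-agnostic. A sketch (illustrative; `store_tile_ref` is a made-up name, not from the patch):

template <typename tile_C>
static __device__ void store_tile_ref(float * dst, const tile_C & C, const int stride) {
#pragma unroll
    for (int l = 0; l < tile_C::ne; ++l) {
        // get_i/get_j resolve per data_layout, so the same loop works for
        // DATA_LAYOUT_I_MAJOR and DATA_LAYOUT_J_MAJOR accumulator tiles.
        dst[tile_C::get_i(l) * stride + tile_C::get_j(l)] = C.x[l];
    }
}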

    template <int I_, int J_>
    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
        static constexpr int I = I_;

@@ -524,6 +528,42 @@ namespace ggml_cuda_mma {
        }
    };

    template <int I_, int J_, typename T>
    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;

        static constexpr int ne = I * J / 32 * 2;

        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 16 && J == 16) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 4) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (supported()) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (supported()) {
                return l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
    };

#if defined(TURING_MMA_AVAILABLE)
    template <int I, int J>
    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {

@@ -569,55 +609,28 @@ namespace ggml_cuda_mma {
                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
            }
        } else {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
        }
#elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-#if defined(RDNA4)
-            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-#elif defined(RDNA3)
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
-#else
-            NO_DEVICE_CODE;
-#endif // defined(RDNA4)
-        } else if constexpr (std::is_same_v<T, int>) {
-            if constexpr (I == 16 && J == 4) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-#endif // defined(RDNA4)
-            } else if constexpr (I == 16 && J == 8) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-
-                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-                xi[1] = xs1[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-                const int64_t * xs1 = xs + 2;
-                xi[2] = xs1[0];
-                xi[3] = xs1[1];
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-            }
-        } else {
-            NO_DEVICE_CODE;
-        }
+        // All wmma layouts have contiguous data when i-major.
+        if constexpr (is_i_major(dl)) {
+            // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
+            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+                for (int i = 0; i < aligned_copy_count; ++i) {
+                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+                }
+            } else {
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+            }
+        } else {
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
+        }
#else
#pragma unroll
@@ -660,9 +673,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
    }

-    template <typename T>
+    template <typename T, data_layout dl>
    static __device__ __forceinline__ void load_ldmatrix(
-        tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+        tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
        int * xi = (int * ) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);

@@ -832,8 +845,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
    }

+    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
-        tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+        tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
#ifdef AMPERE_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;

@@ -887,8 +901,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
    }

+    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
-        tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+        tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;

@@ -940,8 +955,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
    }

+    template <data_layout dl_ab, data_layout dl_d>
    static __device__ __forceinline__ void mma(
-        tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+        tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;

@@ -967,8 +983,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
    }

+    template <data_layout dl_d, data_layout dl_ab>
    static __device__ __forceinline__ void mma(
-        tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+        tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
#if defined(AMD_MFMA_AVAILABLE)
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        int32x4_t * acc = (int32x4_t *) D.x;

@@ -1122,8 +1139,9 @@ namespace ggml_cuda_mma {
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    }

-    static __device__ __forceinline__ void mma(
-        tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+        tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
        int32x8_t * acc = (int32x8_t *) D.x;

@@ -32,11 +32,13 @@ static __global__ void mul_mat_f(
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {

@@ -272,11 +274,13 @@ static __global__ void mul_mat_f_ids(
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
@ -797,9 +797,10 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
typedef tile<16, 8, int> tile_A;
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 8, int> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 16, int> tile_C;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||||
constexpr int rows_per_warp = granularity;
|
constexpr int rows_per_warp = granularity;
|
||||||
|
|
@ -966,9 +967,10 @@ template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
typedef tile<16, 8, int> tile_A;
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 8, int> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 16, int> tile_C;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||||
constexpr int rows_per_warp = granularity;
|
constexpr int rows_per_warp = granularity;
|
||||||
|
|
@ -1130,10 +1132,11 @@ template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
typedef tile<16, 8, int> tile_A;
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 8, int> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 16, int> tile_C;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
typedef tile<64, 2, int> tile_load;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
typedef tile<64, 2, int, input_layout> tile_load;
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||||
constexpr int rows_per_warp = granularity;
|
constexpr int rows_per_warp = granularity;
|
||||||
|
|
@ -1179,9 +1182,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
||||||
typedef tile<16, 4, int> tile_A;
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 4, int> tile_B;
|
typedef tile<16, 4, int, input_layout> tile_A;
|
||||||
typedef tile<16, 16, int> tile_C;
|
typedef tile<16, 4, int, input_layout> tile_B;
|
||||||
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||||
constexpr int rows_per_warp = granularity;
|
constexpr int rows_per_warp = granularity;
|
||||||
|
|
@@ -1435,10 +1439,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64, 2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1501,10 +1506,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2265,10 +2270,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64, 2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2316,9 +2322,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -3015,7 +3022,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int tileC_IJ = mmq_get_granularity_device(0);
-    typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
+    typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
     constexpr int rows_per_warp = granularity;
 #else
     typedef tile<16, 8, int> tile_C;
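
All of the MMQ hunks above apply the same pattern: the `tile` types gain an explicit data-layout template parameter instead of relying on an implicit default. The real `tile`, `data_layout`, and `get_input_data_layout()` definitions live in the CUDA mma header of this PR and are more involved; the sketch below is a simplified stand-in that only illustrates the idea — the layout travels with the type, so I-major vs. J-major indexing is resolved at compile time.

```cpp
// Illustrative sketch only, not the real ggml-cuda definitions.
enum data_layout { DATA_LAYOUT_I_MAJOR, DATA_LAYOUT_J_MAJOR };

template <int I, int J, typename T, data_layout layout = DATA_LAYOUT_I_MAJOR>
struct tile {
    T x[I * J]; // simplified: one full tile per instance, no warp distribution

    // linear offset of element (i, j), selected by the layout at compile time:
    static constexpr int index(int i, int j) {
        return layout == DATA_LAYOUT_I_MAJOR ? i * J + j : j * I + i;
    }
};

// With the extra parameter, an input-dependent layout is chosen once and
// baked into the operand types, mirroring the hunks above:
// constexpr data_layout input_layout = get_input_data_layout();
// typedef tile<16, 8, int, input_layout> tile_A;
```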
@@ -288,7 +288,7 @@ class LocalTensor:
     data_range: LocalTensorRange
 
     def mmap_bytes(self) -> np.ndarray:
-        return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
+        return np.memmap(self.data_range.filename, mode='r', offset=self.data_range.offset, shape=self.data_range.size)
 
 
 class SafetensorsLocal:
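
The `mode='r'` fix above matters because `np.memmap` defaults to `mode='r+'`, which opens the underlying file read-write; for merely reading tensor byte ranges out of existing files, a read-only mapping is both safer and sufficient.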
@@ -2055,7 +2055,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
             LLM_TENSOR_SHORTCONV_INPROJ,
             LLM_TENSOR_SHORTCONV_OUTPROJ,
             LLM_TENSOR_TOKEN_EMBD,
-            LLM_TENSOR_OUTPUT_NORM,
+            LLM_TENSOR_OUTPUT_NORM_LFM2,
             LLM_TENSOR_FFN_GATE_INP,
             LLM_TENSOR_FFN_GATE_EXPS,
             LLM_TENSOR_FFN_DOWN_EXPS,
@@ -362,23 +362,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
 }
 
 void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->accept) {
         smpl->iface->accept(smpl, token);
     }
 }
 
 void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    if (!smpl) {
+        return;
+    }
+
     GGML_ASSERT(smpl->iface->apply);
     smpl->iface->apply(smpl, cur_p);
 }
 
 void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->reset) {
         smpl->iface->reset(smpl);
     }
 }
 
 struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (!smpl) {
+        return nullptr;
+    }
+
     if (smpl->iface->clone) {
         return smpl->iface->clone(smpl);
     }
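
For context, a standalone toy model of the guarded dispatch these hunks introduce (deliberately not the real llama.cpp types): a null sampler now degrades to a no-op — or a `nullptr` result for `clone` — instead of a null-pointer dereference.

```cpp
#include <cstdio>

// Toy stand-ins for the sampler interface; names are invented for illustration.
struct iface_t   { void (*reset)(void *); };
struct sampler_t { const iface_t * iface; };

void sampler_reset(sampler_t * smpl) {
    if (!smpl) {
        return; // new behavior: a null sampler is silently ignored
    }
    if (smpl->iface->reset) {
        smpl->iface->reset(smpl); // optional interface slot, as in the diff
    }
}

int main() {
    sampler_reset(nullptr); // safe with the guard in place
    printf("no crash\n");
    return 0;
}
```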
@@ -71,8 +71,9 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
     }, &ud);
 
     llama_model_params mparams_copy = *mparams;
     mparams_copy.no_alloc = true;
     mparams_copy.use_mmap = false;
+    mparams_copy.use_mlock = false;
 
     llama_model * model = llama_model_load_from_file(path_model, mparams_copy);
     if (model == nullptr) {
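
Presumably the intent here is that the probe load used only for memory projection should have no side effects: with `no_alloc`, `use_mmap`, and now `use_mlock` all neutralized, the dry-run neither maps nor pins any model memory.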
@@ -180,11 +181,12 @@ static void llama_params_fit_impl(
         }
     }
 
     int64_t sum_total = 0;
     int64_t sum_projected_free = 0;
     int64_t min_projected_free = INT64_MAX;
     int64_t sum_projected_used = 0;
+    int64_t sum_projected_model = 0;
     int64_t sum_projected_ctx = 0;
 
     if (nd > 1) {
         LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -195,11 +197,12 @@ static void llama_params_fit_impl(
         const int64_t projected_used = dmd.mb.total();
         const int64_t projected_free = dmd.free - projected_used;
 
         sum_total += dmd.total;
         sum_projected_used += projected_used;
         sum_projected_free += projected_free;
         min_projected_free = std::min(min_projected_free, projected_free);
+        sum_projected_model += dmd.mb.model;
         sum_projected_ctx += dmd.mb.context;
 
         if (nd > 1) {
             LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
@@ -234,10 +237,24 @@ static void llama_params_fit_impl(
     if (cparams->n_ctx == 0) {
         if (hp_nct > n_ctx_min) {
             const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
-            const uint32_t ctx_reduction = std::min(
-                uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
+            int64_t memory_reduction = -global_surplus;
+            if (nd > 1) {
+                // for multiple devices we need to be more conservative in terms of how much context we think can fit:
+                // - for dense models only whole layers can be assigned to devices
+                // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer
+                // - on average we expect a waste of 0.5 layers/tensors per device
+                // - use slightly more than the expected average for nd devices to be safe
+                const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
+                memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
+            }
+
+            uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
             cparams->n_ctx = hp_nct - ctx_reduction;
-            const int64_t memory_reduction = ctx_reduction * bytes_per_ctx;
+            cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
+
+            ctx_reduction = hp_nct - cparams->n_ctx;
+            memory_reduction = ctx_reduction * bytes_per_ctx;
             global_surplus += memory_reduction;
             LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
                 __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
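
A worked walkthrough of the shrinking arithmetic above may help: the ceil-division guarantees at least `memory_reduction` bytes are freed, and the final context is rounded down to a multiple of 256 without dropping below the minimum. All values below are invented for illustration.

```cpp
#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t hp_nct        = 32768;              // model's trained context size
    const uint32_t n_ctx_min     = 4096;               // lower bound on context
    const int64_t  bytes_per_ctx = 98304;              // projected context bytes per token (toy value)
    const int64_t  memory_reduction = 700ll * 1024 * 1024; // bytes we must free (toy value)

    // ceil-divide so at least memory_reduction bytes are actually freed:
    uint32_t ctx_reduction = std::min(
        uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
    uint32_t n_ctx = hp_nct - ctx_reduction;
    // round down to a multiple of 256 (CUDA-friendly), but never below n_ctx_min:
    n_ctx = std::max(n_ctx - n_ctx % 256, n_ctx_min);

    printf("n_ctx: %" PRIu32 " -> %" PRIu32 "\n", hp_nct, n_ctx); // 32768 -> 25088
    return 0;
}
```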
@@ -481,8 +498,13 @@ static void llama_params_fit_impl(
     } else {
         LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
     }
-    uint32_t n_unassigned = hp_ngl;
     for (int id = nd - 1; id >= 0; id--) {
+        uint32_t n_unassigned = hp_ngl;
+        for (size_t jd = id + 1; jd < nd; ++jd) {
+            assert(n_unassigned >= ngl_per_device[jd].n_layer);
+            n_unassigned -= ngl_per_device[jd].n_layer;
+        }
+
         std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
         ngl_per_device_high[id].n_layer = n_unassigned;
         if (hp_nex > 0) {
@@ -491,7 +513,9 @@ static void llama_params_fit_impl(
             if (ngl_per_device_high[id].n_layer > 0) {
                 std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
                 if (mem_high[id] > targets[id]) {
+                    assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
                     uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
+                    LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
                     while (delta > 1) {
                         uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
                         step_size = std::max(step_size, uint32_t(1));
@@ -505,20 +529,19 @@ static void llama_params_fit_impl(
                         const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
 
                         if (mem_test[id] <= targets[id]) {
                             ngl_per_device = ngl_per_device_test;
                             mem = mem_test;
-                            n_unassigned -= ngl_per_device[id].n_layer;
                             LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
                         } else {
                             ngl_per_device_high = ngl_per_device_test;
                             mem_high = mem_test;
-                            LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
+                            LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer);
                         }
                         delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
                     }
                 } else {
-                    ngl_per_device = ngl_per_device_high;
-                    n_unassigned -= ngl_per_device[id].n_layer;
+                    assert(ngl_per_device_high[id].n_layer == n_unassigned);
+                    ngl_per_device = ngl_per_device_high;
                     LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
                 }
             }
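
A self-contained toy of the fill loop these hunks adjust: interpolate how many extra layers should fit under the per-device target, probe the projected memory, and tighten the [known-good, too-big] bracket until it closes — a linear-interpolation bisection. The cost model and all constants are invented for illustration.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t target = 20000; // per-device memory budget (toy units)
    auto mem_for = [](uint32_t n_layer) { return int64_t(1000) + 950 * int64_t(n_layer); };

    uint32_t lo = 0, hi = 24; // layer counts: lo is known to fit, hi assumed too big
    int64_t mem_lo = mem_for(lo);
    int64_t mem_hi = mem_for(hi);
    uint32_t delta = hi - lo;
    while (delta > 1) {
        // same interpolation as step_size in the diff, clamped to stay inside the bracket:
        uint32_t step = uint32_t(int64_t(delta) * (target - mem_lo) / (mem_hi - mem_lo));
        step = std::clamp(step, uint32_t(1), delta - 1);
        const uint32_t n_test   = lo + step;
        const int64_t  mem_test = mem_for(n_test);
        if (mem_test <= target) {
            lo = n_test; mem_lo = mem_test; // fits: raise the known-good bound
        } else {
            hi = n_test; mem_hi = mem_test; // too big: lower the upper bound
        }
        delta = hi - lo;
    }
    printf("layers that fit: %u\n", lo); // prints 20 with this toy model
    return 0;
}
```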
@@ -4,7 +4,11 @@
 #include "common.h"
 #include "log.h"
 
-#include <iostream>
+#include <chrono>
+#include <cinttypes>
+#include <thread>
+
+using namespace std::chrono_literals;
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -22,13 +26,17 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);
     auto mparams = common_model_params_to_llama(params);
     auto cparams = common_context_params_to_llama(params);
-    llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
+    const bool success = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
         params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
         params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
+    if (!success) {
+        LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);
+        exit(1);
+    }
 
-    LOG_INF("Printing fitted CLI arguments to stdout...\n");
-    std::cout << "-c " << cparams.n_ctx;
-    std::cout << " -ngl " << mparams.n_gpu_layers;
+    LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__);
+    std::this_thread::sleep_for(10ms); // to avoid a race between stderr and stdout
+    printf("-c %" PRIu32 " -ngl %" PRIu32, cparams.n_ctx, mparams.n_gpu_layers);
 
     size_t nd = llama_max_devices();
     while (nd > 1 && mparams.tensor_split[nd - 1] == 0.0f) {
@@ -37,26 +45,22 @@ int main(int argc, char ** argv) {
     if (nd > 1) {
         for (size_t id = 0; id < nd; id++) {
             if (id == 0) {
-                std::cout << " -ts ";
+                printf(" -ts ");
             }
-            if (id > 0) {
-                std::cout << ",";
-            }
-            std::cout << mparams.tensor_split[id];
+            printf("%s%" PRIu32, id > 0 ? "," : "", uint32_t(mparams.tensor_split[id]));
         }
     }
 
     const size_t ntbo = llama_max_tensor_buft_overrides();
+    bool any_tbo = false;
     for (size_t itbo = 0; itbo < ntbo && mparams.tensor_buft_overrides[itbo].pattern != nullptr; itbo++) {
         if (itbo == 0) {
-            std::cout << " -ot ";
+            printf(" -ot \"");
        }
-        if (itbo > 0) {
-            std::cout << ",";
-        }
-        std::cout << mparams.tensor_buft_overrides[itbo].pattern << "=" << ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft);
+        printf("%s%s=%s", itbo > 0 ? "," : "", mparams.tensor_buft_overrides[itbo].pattern, ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft));
+        any_tbo = true;
     }
-    std::cout << "\n";
+    printf("%s\n", any_tbo ? "\"" : "");
 
     return 0;
 }
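
Taken together, the tool now emits a single shell-ready line on stdout, for example (values purely illustrative): `-c 25088 -ngl 99 -ts 1,1 -ot "blk\.3[0-9]\.ffn_.*=CPU"`. Note the quoting added around the `-ot` value via `any_tbo`, which the old `std::cout` version lacked.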
@@ -329,6 +329,7 @@ struct mtmd_context {
         case PROJECTOR_TYPE_QWEN25O:
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
+        case PROJECTOR_TYPE_GLMA:
             audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
             break;
         default:
@@ -17,6 +17,7 @@
 #include <chrono>
 #include <queue>
 #include <filesystem>
+#include <cstring>
 
 #ifdef _WIN32
 #include <winsock2.h>
@@ -33,7 +34,8 @@
 #include <limits.h>
 #endif
 
-#define CMD_EXIT "exit"
+#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
+#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"
 
 // address for child process, this is needed because router may run on 0.0.0.0
 // ref: https://github.com/ggml-org/llama.cpp/issues/17862
@@ -534,6 +536,8 @@ void server_models::load(const std::string & name) {
     std::vector<char *> argv = to_char_ptr_array(child_args);
     std::vector<char *> envp = to_char_ptr_array(child_env);
 
+    // TODO @ngxson : maybe separate stdout and stderr in the future
+    // so that we can use stdout for commands and stderr for logging
     int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
     int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
     if (result != 0) {
@@ -547,11 +551,17 @@ void server_models::load(const std::string & name) {
     // captured variables are guaranteed to be destroyed only after the thread is joined
     inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
         // read stdout/stderr and forward to main server log
+        bool state_received = false; // true if child state received
        FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
        if (p_stdout_stderr) {
            char buffer[4096];
            while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
                LOG("[%5d] %s", port, buffer);
+                if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
+                    // child process is ready
+                    this->update_status(name, SERVER_MODEL_STATUS_LOADED);
+                    state_received = true;
+                }
            }
        } else {
            SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
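
A minimal model of the readiness handshake these hunks introduce: the child prints a sentinel line on its (combined) stdout and the router scans the pipe for it. The sketch below is toy code under that assumption; the real implementation goes through subprocess.h and updates the model status registry rather than returning a bool.

```cpp
#include <cstdio>
#include <cstring>

#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"

// Scan a child's output stream line by line until the readiness sentinel appears.
static bool scan_for_ready(FILE * child_out) {
    char buffer[4096];
    while (fgets(buffer, sizeof(buffer), child_out) != nullptr) {
        if (std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
            return true; // seeing the sentinel once is enough
        }
    }
    return false; // pipe closed without the child ever becoming ready
}

int main() {
    FILE * f = tmpfile(); // stand-in for the child's stdout pipe
    fputs("regular log line\n" CMD_CHILD_TO_ROUTER_READY "\n", f);
    rewind(f);
    printf("ready: %s\n", scan_for_ready(f) ? "yes" : "no");
    fclose(f);
    return 0;
}
```

The namespaced sentinels (`cmd_child_to_router:ready`, `cmd_router_to_child:exit`) make false positives in ordinary log lines far less likely than the old bare `"exit"` string, which is presumably why the rename accompanies this change.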
@@ -595,7 +605,7 @@ static void interrupt_subprocess(FILE * stdin_file) {
     // because subprocess.h does not provide a way to send SIGINT,
     // we will send a command to the child process to exit gracefully
     if (stdin_file) {
-        fprintf(stdin_file, "%s\n", CMD_EXIT);
+        fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
         fflush(stdin_file);
     }
 }
@@ -707,32 +717,13 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
     return proxy;
 }
 
-std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
+std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
     // send a notification to the router server that a model instance is ready
-    // TODO @ngxson : use HTTP client from libcommon
-    httplib::Client cli(base_params.hostname, router_port);
-    cli.set_connection_timeout(0, 200000); // 200 milliseconds
-
-    httplib::Request req;
-    req.method = "POST";
-    req.path = "/models/status";
-    req.set_header("Content-Type", "application/json");
-    if (!base_params.api_keys.empty()) {
-        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
-    }
-
-    json body;
-    body["model"] = name;
-    body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED);
-    req.body = body.dump();
-
-    SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str());
-    auto result = cli.send(std::move(req));
-    if (result.error() != httplib::Error::Success) {
-        auto err_str = httplib::to_string(result.error());
-        SRV_ERR("failed to notify router server: %s\n", err_str.c_str());
-        exit(1); // force exit
-    }
+    common_log_pause(common_log_main());
+    fflush(stdout);
+    fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
+    fflush(stdout);
+    common_log_resume(common_log_main());
 
     // setup thread for monitoring stdin
     return std::thread([shutdown_handler]() {
@@ -746,7 +737,7 @@ std::thread server_models::setup_child_server(const common_params & base_params,
             eof = true;
             break;
         }
-        if (line.find(CMD_EXIT) != std::string::npos) {
+        if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
             SRV_INF("%s", "exit command received, exiting...\n");
             shutdown_handler(0);
             break;
@@ -869,18 +860,6 @@ void server_models_routes::init_routes() {
         return res;
     };
 
-    // used by child process to notify the router about status change
-    // TODO @ngxson : maybe implement authentication for this endpoint in the future
-    this->post_router_models_status = [this](const server_http_req & req) {
-        auto res = std::make_unique<server_http_res>();
-        json body = json::parse(req.body);
-        std::string model = json_value(body, "model", std::string());
-        std::string value = json_value(body, "value", std::string());
-        models.update_status(model, server_model_status_from_string(value));
-        res_ok(res, {{"success", true}});
-        return res;
-    };
-
     this->get_router_models = [this](const server_http_req &) {
         auto res = std::make_unique<server_http_res>();
         json models_json = json::array();
@@ -144,7 +144,7 @@ public:
 
     // notify the router server that a model instance is ready
     // return the monitoring thread (to be joined by the caller)
-    static std::thread setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
+    static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
 };
 
 struct server_models_routes {
@@ -162,7 +162,6 @@ struct server_models_routes {
     server_http_context::handler_t proxy_post;
     server_http_context::handler_t get_router_models;
     server_http_context::handler_t post_router_models_load;
-    server_http_context::handler_t post_router_models_status;
     server_http_context::handler_t post_router_models_unload;
 };
@@ -153,7 +153,6 @@ int main(int argc, char ** argv, char ** envp) {
         routes.get_models = models_routes->get_router_models;
         ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
         ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
-        ctx_http.post("/models/status", ex_wrapper(models_routes->post_router_models_status));
     }
 
     ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
@@ -291,7 +290,7 @@ int main(int argc, char ** argv, char ** envp) {
     const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
     std::thread monitor_thread;
     if (router_port != nullptr) {
-        monitor_thread = server_models::setup_child_server(params, std::atoi(router_port), params.model_alias, shutdown_handler);
+        monitor_thread = server_models::setup_child_server(shutdown_handler);
     }
 
     // this call blocks the main thread until queue_tasks.terminate() is called