Merge 6c1a8951f5 into 9dcac6cf9f
This commit is contained in:
commit
ee2e1aef54
|
|
@ -236,6 +236,7 @@ if (NOT GGML_BACKEND_DL)
|
||||||
llama_build_and_test(test-quantize-fns.cpp)
|
llama_build_and_test(test-quantize-fns.cpp)
|
||||||
llama_build_and_test(test-quantize-perf.cpp)
|
llama_build_and_test(test-quantize-perf.cpp)
|
||||||
llama_build_and_test(test-rope.cpp)
|
llama_build_and_test(test-rope.cpp)
|
||||||
|
llama_build_and_test(test-llama-batch.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# libmtmd
|
# libmtmd
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,565 @@
|
||||||
|
#include "../src/llama-batch.h"
|
||||||
|
#include "../common/common.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <vector>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* llama_batch/sbatch/ubatch Test Program
|
||||||
|
* Tests the basic principles and functionality of batch processing
|
||||||
|
* Focuses on split_simple operation and state modifications
|
||||||
|
*
|
||||||
|
* Data Flow Diagram:
|
||||||
|
* ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||||
|
* │ llama_batch │───▶│ llama_sbatch │───▶│ llama_ubatch │
|
||||||
|
* │ (raw input) │ │ (sorted/grouped)│ │ (view/subset) │
|
||||||
|
* │ │ │ │ │ │
|
||||||
|
* │ token[]: [A,B,C]│ │ seq[]: groups │ │ token: ptr→data │
|
||||||
|
* │ pos[]: [0,1,2]│ │ ids[]: [0,1,2] │ │ n_tokens: count │
|
||||||
|
* │ seq_id: [0,0,0] │ │ offset: 0 │ │ equal_seqs: T/F │
|
||||||
|
* └─────────────────┘ │ length: 3 │ └─────────────────┘
|
||||||
|
* └─────────────────┘
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct test_scope {
|
||||||
|
const char * name;
|
||||||
|
explicit test_scope(const char * name) : name(name) {
|
||||||
|
std::cout << "\n╔══════════════════════════════════════════════════════════════════════════════════════╗\n";
|
||||||
|
std::cout << "║ " << std::left << std::setw(84) << name << " ║\n";
|
||||||
|
std::cout << "╚══════════════════════════════════════════════════════════════════════════════════════╝\n";
|
||||||
|
}
|
||||||
|
~test_scope() {
|
||||||
|
std::cout << "\n✅ " << name << " Test Completed Successfully\n";
|
||||||
|
std::cout << "═══════════════════════════════════════════════════════════════════════════════════════\n\n";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper function to print batch details
|
||||||
|
static void print_batch_details(const llama_batch& batch, const std::string& title) {
|
||||||
|
std::cout << "\n" << title << " Details:\n";
|
||||||
|
std::cout << "---------------------------------------------\n";
|
||||||
|
std::cout << "Total Tokens: " << batch.n_tokens << "\n";
|
||||||
|
|
||||||
|
if (batch.token) {
|
||||||
|
std::cout << "Tokens: ";
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
std::cout << batch.token[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch.pos) {
|
||||||
|
std::cout << "Positions: ";
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
std::cout << batch.pos[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch.n_seq_id && batch.seq_id) {
|
||||||
|
std::cout << "Sequence Details:\n";
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
std::cout << " Token[" << i << "]: seq_ids=[";
|
||||||
|
for (int j = 0; j < batch.n_seq_id[i]; ++j) {
|
||||||
|
std::cout << batch.seq_id[i][j];
|
||||||
|
if (j < batch.n_seq_id[i] - 1) std::cout << ",";
|
||||||
|
}
|
||||||
|
std::cout << "]\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch.logits) {
|
||||||
|
std::cout << "Output Flags: ";
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
std::cout << (int)batch.logits[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
std::cout << "---------------------------------------------\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to print sbatch details
|
||||||
|
static void print_sbatch_details(const llama_sbatch& sbatch, const std::string& title) {
|
||||||
|
std::cout << "\n" << title << " Details:\n";
|
||||||
|
std::cout << "---------------------------------------------\n";
|
||||||
|
std::cout << "Total Tokens: " << sbatch.n_tokens << "\n";
|
||||||
|
std::cout << "Sequences: " << sbatch.seq.size() << "\n";
|
||||||
|
|
||||||
|
for (size_t i = 0; i < sbatch.seq.size(); ++i) {
|
||||||
|
const auto& s = sbatch.seq[i];
|
||||||
|
std::cout << "Sequence[" << i << "]: "
|
||||||
|
<< "offset=" << s.offset
|
||||||
|
<< ", length=" << s.length << "\n";
|
||||||
|
|
||||||
|
if (s.seq_id && s.n_seq_id > 0) {
|
||||||
|
std::cout << " Sequence IDs: [";
|
||||||
|
for (int j = 0; j < s.n_seq_id; ++j) {
|
||||||
|
std::cout << s.seq_id[j];
|
||||||
|
if (j < s.n_seq_id - 1) std::cout << ",";
|
||||||
|
}
|
||||||
|
std::cout << "]\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "Sorted Token Order: ";
|
||||||
|
for (size_t i = 0; i < sbatch.ids.size(); ++i) {
|
||||||
|
std::cout << sbatch.ids[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
std::cout << "---------------------------------------------\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to print ubatch details
|
||||||
|
static void print_ubatch_details(const llama_ubatch& ubatch, const std::string& title) {
|
||||||
|
std::cout << "\n" << title << " Details:\n";
|
||||||
|
std::cout << "---------------------------------------------\n";
|
||||||
|
std::cout << "Equal Sequences: " << (ubatch.equal_seqs ? "true" : "false") << "\n";
|
||||||
|
std::cout << "Total Tokens: " << ubatch.n_tokens << "\n";
|
||||||
|
std::cout << "Tokens per Sequence: " << ubatch.n_seq_tokens << "\n";
|
||||||
|
std::cout << "Number of Sequences: " << ubatch.n_seqs << "\n";
|
||||||
|
|
||||||
|
if (ubatch.token) {
|
||||||
|
std::cout << "Tokens: ";
|
||||||
|
for (size_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||||
|
std::cout << ubatch.token[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ubatch.pos) {
|
||||||
|
std::cout << "Positions: ";
|
||||||
|
for (size_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||||
|
std::cout << ubatch.pos[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ubatch.n_seq_id) {
|
||||||
|
std::cout << "Sequence ID Details: ";
|
||||||
|
if (ubatch.equal_seqs) {
|
||||||
|
for (size_t i = 0; i < ubatch.n_seqs; ++i) {
|
||||||
|
std::cout << ubatch.n_seq_id[i] << " ";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||||
|
std::cout << ubatch.n_seq_id[i] << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ubatch.output) {
|
||||||
|
std::cout << "Output Flags: ";
|
||||||
|
for (size_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||||
|
std::cout << (int)ubatch.output[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
std::cout << "---------------------------------------------\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 1: Basic Batch Creation and Conversion
|
||||||
|
static void test_basic_batch_conversion() {
|
||||||
|
test_scope scope("Basic Batch Creation and Conversion");
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Basic Conversion Flow:
|
||||||
|
*
|
||||||
|
* llama_batch (raw input):
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 100 │ 101 │ 102 │ 103 │ 104 │ ← tokens
|
||||||
|
* │ 0 │ 1 │ 2 │ 3 │ 4 │ ← positions
|
||||||
|
* │ 0 │ 0 │ 0 │ 0 │ 0 │ ← seq_id
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↓
|
||||||
|
* llama_sbatch (simple_split=true):
|
||||||
|
* ┌─────────────────────────────────┐
|
||||||
|
* │ seq[0]: {n_seq_id=0, offset=0, │
|
||||||
|
* │ length=5} │
|
||||||
|
* │ ids[]: [0,1,2,3,4] │
|
||||||
|
* └─────────────────────────────────┘
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Create a simple batch with 5 tokens in one sequence
|
||||||
|
llama_batch batch = llama_batch_init(10, 0, 2); // max 10 tokens, no embeddings, max 2 seqs
|
||||||
|
|
||||||
|
// Add tokens to sequence 0
|
||||||
|
llama_seq_id seq_0 = 0;
|
||||||
|
common_batch_add(batch, 100, 0, {seq_0}, false); // token 100 at pos 0
|
||||||
|
common_batch_add(batch, 101, 1, {seq_0}, false); // token 101 at pos 1
|
||||||
|
common_batch_add(batch, 102, 2, {seq_0}, false); // token 102 at pos 2
|
||||||
|
common_batch_add(batch, 103, 3, {seq_0}, false); // token 103 at pos 3
|
||||||
|
common_batch_add(batch, 104, 4, {seq_0}, true); // token 104 at pos 4, output=true
|
||||||
|
|
||||||
|
print_batch_details(batch, "Original Batch");
|
||||||
|
|
||||||
|
// Convert to sbatch with simple split mode
|
||||||
|
llama_sbatch sbatch(batch, 64, true, false); // n_embd=64, simple_split=true, logits_all=false
|
||||||
|
|
||||||
|
print_sbatch_details(sbatch, "Simple Split SBatch");
|
||||||
|
|
||||||
|
// Verify that simple split creates one sequence with n_seq_id = 0
|
||||||
|
GGML_ASSERT(sbatch.seq.size() == 1);
|
||||||
|
GGML_ASSERT(sbatch.seq[0].n_seq_id == 0);
|
||||||
|
GGML_ASSERT(sbatch.seq[0].length == 5);
|
||||||
|
GGML_ASSERT(sbatch.seq[0].offset == 0);
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 2: Testing split_simple Operation and State Modification
|
||||||
|
static void test_split_simple_modification() {
|
||||||
|
test_scope scope("Split Simple Operation and State Modification");
|
||||||
|
|
||||||
|
/*
|
||||||
|
* split_simple State Modification Visualization:
|
||||||
|
*
|
||||||
|
* Initial sbatch state:
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │ ← token data
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ▲ ▲
|
||||||
|
* offset=0 offset+length=6
|
||||||
|
*
|
||||||
|
* After split_simple(2):
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑consumed↑ ▲ ▲
|
||||||
|
* offset=2 offset+length=6
|
||||||
|
*
|
||||||
|
* After split_simple(3):
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑─── consumed ────↑ ▲ ▲
|
||||||
|
* offset=5 offset+length=6
|
||||||
|
*
|
||||||
|
* Key insight: split_simple "consumes" tokens from the head by advancing offset!
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Create a batch with 6 tokens
|
||||||
|
llama_batch batch = llama_batch_init(10, 0, 1);
|
||||||
|
|
||||||
|
llama_seq_id seq_0 = 0;
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
// is_logits?
|
||||||
|
common_batch_add(batch, 200 + i, i, {seq_0}, i == 5); // last token outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
print_batch_details(batch, "Original Batch (6 tokens)");
|
||||||
|
|
||||||
|
// Convert to sbatch
|
||||||
|
llama_sbatch sbatch(batch, 64, true, false);
|
||||||
|
|
||||||
|
print_sbatch_details(sbatch, "Initial SBatch State");
|
||||||
|
|
||||||
|
std::cout << "\n=== Testing Multiple split_simple Calls ===\n";
|
||||||
|
|
||||||
|
// First split_simple call - take 2 tokens
|
||||||
|
std::cout << "\n--- First split_simple(2) ---\n";
|
||||||
|
std::cout << "Before split_simple:\n";
|
||||||
|
std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n";
|
||||||
|
std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n";
|
||||||
|
std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n";
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Visual representation of split_simple(2):
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑─ extract these 2 ─↑ ↑─ remaining ─↑
|
||||||
|
* → ubatch1 → sbatch.seq[0]
|
||||||
|
*/
|
||||||
|
|
||||||
|
llama_ubatch ubatch1 = sbatch.split_simple(2);
|
||||||
|
|
||||||
|
std::cout << "After split_simple:\n";
|
||||||
|
std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n";
|
||||||
|
std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n";
|
||||||
|
std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n";
|
||||||
|
|
||||||
|
print_ubatch_details(ubatch1, "First UBatch (2 tokens)");
|
||||||
|
|
||||||
|
// Verify the modifications
|
||||||
|
GGML_ASSERT(sbatch.seq[0].offset == 2); // offset advanced by 2
|
||||||
|
GGML_ASSERT(sbatch.seq[0].length == 4); // length reduced by 2
|
||||||
|
GGML_ASSERT(sbatch.n_tokens == 4); // total tokens reduced by 2
|
||||||
|
GGML_ASSERT(ubatch1.n_tokens == 2); // ubatch contains 2 tokens
|
||||||
|
|
||||||
|
// Second split_simple call - take 3 tokens
|
||||||
|
std::cout << "\n--- Second split_simple(3) ---\n";
|
||||||
|
std::cout << "Before split_simple:\n";
|
||||||
|
std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n";
|
||||||
|
std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n";
|
||||||
|
std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n";
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Visual representation of split_simple(3):
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑─consumed─↑ ↑─extract these 3─↑↑─remaining─↑
|
||||||
|
* → ubatch2 → sbatch.seq[0]
|
||||||
|
*/
|
||||||
|
|
||||||
|
llama_ubatch ubatch2 = sbatch.split_simple(3);
|
||||||
|
|
||||||
|
std::cout << "After split_simple:\n";
|
||||||
|
std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n";
|
||||||
|
std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n";
|
||||||
|
std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n";
|
||||||
|
|
||||||
|
print_ubatch_details(ubatch2, "Second UBatch (3 tokens)");
|
||||||
|
|
||||||
|
// Verify the modifications
|
||||||
|
GGML_ASSERT(sbatch.seq[0].offset == 5); // offset advanced by 3 more
|
||||||
|
GGML_ASSERT(sbatch.seq[0].length == 1); // length reduced by 3 more
|
||||||
|
GGML_ASSERT(sbatch.n_tokens == 1); // total tokens reduced by 3 more
|
||||||
|
GGML_ASSERT(ubatch2.n_tokens == 3); // ubatch contains 3 tokens
|
||||||
|
|
||||||
|
// Third split_simple call - take remaining token
|
||||||
|
std::cout << "\n--- Third split_simple(10) (should only get 1 token) ---\n";
|
||||||
|
std::cout << "Before split_simple:\n";
|
||||||
|
std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n";
|
||||||
|
std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n";
|
||||||
|
std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n";
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Visual representation - requesting more than available:
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑─────consumed──────────────↑ ↑only 1↑
|
||||||
|
* remaining
|
||||||
|
*/
|
||||||
|
|
||||||
|
llama_ubatch ubatch3 = sbatch.split_simple(10); // Request more than available
|
||||||
|
|
||||||
|
std::cout << "After split_simple:\n";
|
||||||
|
std::cout << " seq[0].offset = " << sbatch.seq[0].offset << "\n";
|
||||||
|
std::cout << " seq[0].length = " << sbatch.seq[0].length << "\n";
|
||||||
|
std::cout << " sbatch.n_tokens = " << sbatch.n_tokens << "\n";
|
||||||
|
|
||||||
|
print_ubatch_details(ubatch3, "Third UBatch (1 token)");
|
||||||
|
|
||||||
|
// Verify the modifications
|
||||||
|
GGML_ASSERT(sbatch.seq[0].offset == 6); // offset advanced by 1 more
|
||||||
|
GGML_ASSERT(sbatch.seq[0].length == 0); // length reduced to 0
|
||||||
|
GGML_ASSERT(sbatch.n_tokens == 0); // no more tokens
|
||||||
|
GGML_ASSERT(ubatch3.n_tokens == 1); // ubatch contains 1 token
|
||||||
|
|
||||||
|
// Fourth split_simple call - should return empty ubatch
|
||||||
|
std::cout << "\n--- Fourth split_simple(1) (should be empty) ---\n";
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Visual representation - nothing left:
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 200 │ 201 │ 202 │ 203 │ 204 │ 205 │
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑─────────all consumed────────────↑
|
||||||
|
* offset=6, length=0
|
||||||
|
*/
|
||||||
|
|
||||||
|
llama_ubatch ubatch4 = sbatch.split_simple(1);
|
||||||
|
print_ubatch_details(ubatch4, "Fourth UBatch (empty)");
|
||||||
|
|
||||||
|
GGML_ASSERT(ubatch4.n_tokens == 0); // no tokens available
|
||||||
|
|
||||||
|
std::cout << "\n✓ All state modifications verified correctly!\n";
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 3: Multi-Sequence Batch Processing
|
||||||
|
static void test_multi_sequence_batch() {
|
||||||
|
test_scope scope("Multi-Sequence Batch Processing");
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Multi-Sequence Processing Visualization:
|
||||||
|
*
|
||||||
|
* Original batch (mixed sequences):
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 300 │ 301 │ 302 │ 400 │ 401 │ 500 │ 999 │
|
||||||
|
* │seq:0│seq:0│seq:0│seq:1│seq:1│seq:2│0&1 │
|
||||||
|
* │pos:0│pos:1│pos:2│pos:0│pos:1│pos:0│pos:10│
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
*
|
||||||
|
* After sbatch sorting (complex mode):
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 999 │ 300 │ 301 │ 302 │ 400 │ 401 │ 500 │
|
||||||
|
* │0&1 │seq:0│seq:0│seq:0│seq:1│seq:1│seq:2│
|
||||||
|
* │pos:10│pos:0│pos:1│pos:2│pos:0│pos:1│pos:0│
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑ ↑─────seq 0──────↑ ↑─seq 1─↑ ↑seq2↑
|
||||||
|
* shared (sorted by pos)
|
||||||
|
* prompt
|
||||||
|
*
|
||||||
|
* Simple split mode treats everything as one sequence:
|
||||||
|
* ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┐
|
||||||
|
* │ 300 │ 301 │ 302 │ 400 │ 401 │ 500 │ 999 │
|
||||||
|
* │ │ │ │ │ │ │ │
|
||||||
|
* └─────┴─────┴─────┴─────┴─────┴─────┴─────┘
|
||||||
|
* ↑─────────all treated as seq_id=0──────────↑
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Create a batch with multiple sequences
|
||||||
|
llama_batch batch = llama_batch_init(20, 0, 3);
|
||||||
|
|
||||||
|
llama_seq_id seq_0 = 0;
|
||||||
|
llama_seq_id seq_1 = 1;
|
||||||
|
llama_seq_id seq_2 = 2;
|
||||||
|
|
||||||
|
// Add tokens to different sequences
|
||||||
|
common_batch_add(batch, 300, 0, {seq_0}, false); // seq_0: pos 0
|
||||||
|
common_batch_add(batch, 301, 1, {seq_0}, false); // seq_0: pos 1
|
||||||
|
common_batch_add(batch, 302, 2, {seq_0}, true); // seq_0: pos 2, output
|
||||||
|
|
||||||
|
common_batch_add(batch, 400, 0, {seq_1}, false); // seq_1: pos 0
|
||||||
|
common_batch_add(batch, 401, 1, {seq_1}, true); // seq_1: pos 1, output
|
||||||
|
|
||||||
|
common_batch_add(batch, 500, 0, {seq_2}, true); // seq_2: pos 0, output
|
||||||
|
|
||||||
|
// Add a shared prompt token (belongs to multiple sequences)
|
||||||
|
common_batch_add(batch, 999, 10, {seq_0, seq_1}, false); // shared between seq_0 and seq_1
|
||||||
|
|
||||||
|
print_batch_details(batch, "Multi-Sequence Batch");
|
||||||
|
|
||||||
|
// Convert to sbatch with complex split mode (simple_split=false)
|
||||||
|
llama_sbatch sbatch_complex(batch, 64, false, false);
|
||||||
|
|
||||||
|
print_sbatch_details(sbatch_complex, "Complex SBatch (sorted by seq_id)");
|
||||||
|
|
||||||
|
std::cout << "\n=== Testing split_equal and split_seq ===\n";
|
||||||
|
|
||||||
|
/*
|
||||||
|
* split_equal strategy:
|
||||||
|
* - Processes sequences by equal-length batches
|
||||||
|
* - Shared prompts processed first (highest priority)
|
||||||
|
* - Equal length sequences grouped together
|
||||||
|
*
|
||||||
|
* split_seq strategy:
|
||||||
|
* - Processes one sequence at a time
|
||||||
|
* - Takes from the end of sequence list
|
||||||
|
* - Good for sequential processing
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Test split_equal
|
||||||
|
llama_ubatch ubatch_equal = sbatch_complex.split_equal(10);
|
||||||
|
print_ubatch_details(ubatch_equal, "Split Equal Result");
|
||||||
|
|
||||||
|
// Test split_seq
|
||||||
|
llama_ubatch ubatch_seq = sbatch_complex.split_seq(5);
|
||||||
|
print_ubatch_details(ubatch_seq, "Split Seq Result");
|
||||||
|
|
||||||
|
// Compare with simple split approach
|
||||||
|
llama_sbatch sbatch_simple(batch, 64, true, false);
|
||||||
|
print_sbatch_details(sbatch_simple, "Simple SBatch");
|
||||||
|
|
||||||
|
llama_ubatch ubatch_simple = sbatch_simple.split_simple(10);
|
||||||
|
print_ubatch_details(ubatch_simple, "Simple Split Result");
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 4: Edge Cases and Error Conditions
|
||||||
|
static void test_edge_cases() {
|
||||||
|
test_scope scope("Edge Cases and Error Conditions");
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Edge Case Testing:
|
||||||
|
*
|
||||||
|
* Empty batch:
|
||||||
|
* ┌─┐
|
||||||
|
* │ │ ← no tokens
|
||||||
|
* └─┘
|
||||||
|
*
|
||||||
|
* Single token batch:
|
||||||
|
* ┌─────┐
|
||||||
|
* │ 777 │ ← one token
|
||||||
|
* └─────┘
|
||||||
|
*
|
||||||
|
* After split:
|
||||||
|
* ┌─┐
|
||||||
|
* │ │ ← empty sbatch
|
||||||
|
* └─┘
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Test empty batch
|
||||||
|
llama_batch empty_batch = llama_batch_init(5, 0, 1);
|
||||||
|
// Don't add any tokens
|
||||||
|
|
||||||
|
print_batch_details(empty_batch, "Empty Batch");
|
||||||
|
|
||||||
|
llama_sbatch empty_sbatch(empty_batch, 64, true, false);
|
||||||
|
print_sbatch_details(empty_sbatch, "Empty SBatch");
|
||||||
|
|
||||||
|
llama_ubatch empty_ubatch = empty_sbatch.split_simple(5);
|
||||||
|
print_ubatch_details(empty_ubatch, "Empty UBatch from split_simple");
|
||||||
|
|
||||||
|
GGML_ASSERT(empty_ubatch.n_tokens == 0);
|
||||||
|
GGML_ASSERT(empty_sbatch.seq.empty());
|
||||||
|
|
||||||
|
// Test single token batch
|
||||||
|
llama_batch single_batch = llama_batch_init(5, 0, 1);
|
||||||
|
common_batch_add(single_batch, 777, 0, {0}, true);
|
||||||
|
|
||||||
|
print_batch_details(single_batch, "Single Token Batch");
|
||||||
|
|
||||||
|
llama_sbatch single_sbatch(single_batch, 64, true, false);
|
||||||
|
print_sbatch_details(single_sbatch, "Single Token SBatch");
|
||||||
|
|
||||||
|
llama_ubatch single_ubatch = single_sbatch.split_simple(1);
|
||||||
|
print_ubatch_details(single_ubatch, "Single Token UBatch");
|
||||||
|
|
||||||
|
GGML_ASSERT(single_ubatch.n_tokens == 1);
|
||||||
|
GGML_ASSERT(single_ubatch.token[0] == 777);
|
||||||
|
|
||||||
|
// After split, sbatch should be empty
|
||||||
|
llama_ubatch post_split_ubatch = single_sbatch.split_simple(1);
|
||||||
|
GGML_ASSERT(post_split_ubatch.n_tokens == 0);
|
||||||
|
|
||||||
|
llama_batch_free(empty_batch);
|
||||||
|
llama_batch_free(single_batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
std::cout << "llama_batch/sbatch/ubatch Test Program\n";
|
||||||
|
std::cout << "=====================================\n";
|
||||||
|
std::cout << "Testing batch processing principles and split_simple modifications\n";
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Overall Test Architecture:
|
||||||
|
*
|
||||||
|
* ┌─────────────────────────┐
|
||||||
|
* │ Input Validation │
|
||||||
|
* │ (test_basic_batch_*) │
|
||||||
|
* └───────────┬─────────────┘
|
||||||
|
* ▼
|
||||||
|
* ┌─────────────────────────┐
|
||||||
|
* │ Core Functionality │
|
||||||
|
* │(test_split_simple_*) │ ← Main focus: state modification
|
||||||
|
* └───────────┬─────────────┘
|
||||||
|
* ▼
|
||||||
|
* ┌─────────────────────────┐
|
||||||
|
* │ Complex Scenarios │
|
||||||
|
* │(test_multi_sequence_*) │
|
||||||
|
* └───────────┬─────────────┘
|
||||||
|
* ▼
|
||||||
|
* ┌─────────────────────────┐
|
||||||
|
* │ Edge Cases & │
|
||||||
|
* │ Data Integrity │
|
||||||
|
* └─────────────────────────┘
|
||||||
|
*/
|
||||||
|
|
||||||
|
test_basic_batch_conversion();
|
||||||
|
test_split_simple_modification();
|
||||||
|
test_multi_sequence_batch();
|
||||||
|
test_edge_cases();
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue