vulkan: fix top_k bug when there are ties in the input (#17659)

* vulkan: Reduce temporary memory usage for TOP_K

- Compute row size for the temp buffer based on the output of the first pass.
- Update shader addressing math to use the output row size
- Pass the output row size as "ncols_output"; what used to be "ncols_output" is now "k"

For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer
from about 3.2MB to 500KB.
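A rough sanity check on those figures (the 8-byte pair size and the double-buffering factor are assumptions to make the arithmetic concrete; the actual layout is defined by the shader):

#include <cstddef>
#include <cstdio>

int main() {
    // Assumed layout: one 8-byte (value bits, index) pair per input element,
    // double-buffered across the two passes.
    const size_t n_input   = 200000;  // src0 = (200000,1,1,1)
    const size_t pair_size = 8;
    printf("before: %.1f MB\n", 2.0 * n_input * pair_size / 1e6);  // ~3.2 MB
    // Sizing the temp buffer by the first pass's output instead of the input
    // gives ~0.5 MB, i.e. roughly 500000 / 8 = 62500 surviving pairs.
    printf("after:  ~%.1f MB\n", 500000 / 1e6);
    return 0;
}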

* vulkan: fix top_k bug when there are ties in the input

I noticed by inspection a bug in the Vulkan top_k shader: if the least
value in the top k appears multiple times, we could end up writing out those
extra copies rather than some larger values (if the larger values are on
higher-numbered threads).

I rewrote the test verification to handle this case, where the final index set
is not necessarily the same.
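A scalar model of the failure mode (not the shader itself, which does this with subgroup ballots; array order stands in for thread order):

#include <cstdio>
#include <vector>

// "buggy" keeps the first k values >= range_min, so duplicate minima on
// low-numbered threads can crowd out larger values on higher threads.
// "fixed" keeps everything strictly greater, then pads with ties.
int main() {
    const int k = 3;
    const float range_min = 3.0f;             // the k-th largest value
    std::vector<float> vals = {3, 3, 5, 4};   // ties first, larger value later

    std::vector<float> buggy, fixed;
    for (float v : vals) {
        if (v >= range_min && (int) buggy.size() < k) {
            buggy.push_back(v);
        }
    }
    for (float v : vals) {
        if (v > range_min) {
            fixed.push_back(v);
        }
    }
    for (float v : vals) {
        if (v == range_min && (int) fixed.size() < k) {
            fixed.push_back(v);
        }
    }

    for (float v : buggy) { printf("%g ", v); }  // 3 3 5 -- the 4 is lost
    printf("\n");
    for (float v : fixed) { printf("%g ", v); }  // 5 4 3
    printf("\n");
}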

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Jeff Bolz
Date:   2025-12-05 15:03:19 -06:00
Commit: a0f3897d53 (parent e15cd06a94)
3 changed files with 139 additions and 38 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -4013,7 +4013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
                               sizeof(int) * device->subgroup_size +
                               2 * sizeof(int) +
-                              (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+                              2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
         if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
             nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
             ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
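The doubled final term accounts for the new eq_min_partials array that now sits alongside offset_partials (see the shader hunk below). A sketch of the accounting, with illustrative sizes (512 and 32 stand in for the real device/pipeline values):

#include <cstddef>
#include <cstdio>

// Mirrors the nary_shmem formula above.
constexpr size_t nary_shmem(size_t block_size, size_t subgroup_size) {
    return 2 * sizeof(int) * block_size                      // 2 ints per thread
         + sizeof(int) * subgroup_size                       // counts[SUBGROUP_SIZE]
         + 2 * sizeof(int)                                   // sh_min_idx, sh_total
         + 2 * (block_size / subgroup_size) * sizeof(int);   // offset_partials + eq_min_partials
}

int main() {
    printf("%zu bytes\n", nary_shmem(512, 32));  // must fit maxComputeSharedMemorySize
}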

ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp

@@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
 shared int sh_min_idx;
 shared uint sh_total;
 shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
+shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
 
 // Map float values to uint such that comparisons still work.
 // Positive values set the high bit, negative values are inverted.
@@ -156,25 +157,66 @@ void topk(const uint row) {
         // We need to compact these values to the start of the dst_row array.
         // Have each subgroup count how many items it'll store, so other
         // subgroups can compute their base offset.
-        bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
-        uvec4 b = subgroupBallot(top);
-        uint bit_count = subgroupBallotBitCount(b);
-        if ((tid % SUBGROUP_SIZE) == 0) {
-            offset_partials[tid / SUBGROUP_SIZE] = bit_count;
-        }
-        barrier();
-        uint out_idx = 0;
-        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
-            if (i < tid / SUBGROUP_SIZE) {
-                out_idx += offset_partials[i];
-            }
-        }
-        uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
-        if (top) {
-            // TODO: Copy directly to the output?
-            dst_row[out_idx + bit_count_ex] = v;
-        }
+        // Values strictly greater than range_min must be stored. For values equal
+        // to range_min, there can be ties and it's possible we'll need to store
+        // an arbitrary subset of them.
+        // If total == p.k, have a fast path where we don't need to handle ties.
+        if (total == p.k) {
+            bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+            uvec4 b = subgroupBallot(top);
+            uint bit_count = subgroupBallotBitCount(b);
+            if ((tid % SUBGROUP_SIZE) == 0) {
+                offset_partials[tid / SUBGROUP_SIZE] = bit_count;
+            }
+            barrier();
+            uint out_idx = 0;
+            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+                if (i < tid / SUBGROUP_SIZE) {
+                    out_idx += offset_partials[i];
+                }
+            }
+            uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+            if (top) {
+                // TODO: Copy directly to the output?
+                dst_row[out_idx + bit_count_ex] = v;
+            }
+        } else {
+            bool top = f2ui(intBitsToFloat(v.y)) > range_min;
+            bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
+            uvec4 b_top = subgroupBallot(top);
+            uvec4 b_eq_min = subgroupBallot(eq_min);
+            uint bit_count_top = subgroupBallotBitCount(b_top);
+            uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
+            if ((tid % SUBGROUP_SIZE) == 0) {
+                offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
+                eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
+            }
+            barrier();
+            uint out_idx = 0;
+            uint eq_min_base = 0;
+            uint eq_min_idx = 0;
+            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+                if (i < tid / SUBGROUP_SIZE) {
+                    out_idx += offset_partials[i];
+                    eq_min_idx += eq_min_partials[i];
+                }
+                eq_min_base += offset_partials[i];
+            }
+            // range_min values are stored at the end
+            eq_min_idx += eq_min_base;
+            uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
+            uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
+            if (top) {
+                // TODO: Copy directly to the output?
+                dst_row[out_idx + bit_count_ex_top] = v;
+            }
+            if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
+                dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
+            }
+        }
         barrier();
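To see why the else branch's addressing works, here is a scalar sketch (a CPU model, not the shader): strictly-greater values are compacted to the front using per-subgroup bases, and values equal to range_min are appended after all of them, clamped to k. Subgroups are modeled as contiguous chunks of SG elements, and the ballot exclusive bit count becomes a running rank within the chunk:

#include <cstdio>
#include <vector>

int main() {
    const int SG = 4, k = 4;
    std::vector<float> v = {2, 5, 3, 3, 3, 1, 7, 0};  // two "subgroups"
    const float range_min = 3;                        // the k-th largest value
    const int n_sg = (int) v.size() / SG;

    // Per-subgroup counts (offset_partials / eq_min_partials).
    std::vector<int> top_partials(n_sg, 0), eq_partials(n_sg, 0);
    for (int s = 0; s < n_sg; s++) {
        for (int t = 0; t < SG; t++) {
            top_partials[s] += v[s * SG + t] >  range_min;
            eq_partials[s]  += v[s * SG + t] == range_min;
        }
    }

    // All strictly-greater values come first; ties start at eq_min_base.
    int eq_min_base = 0;
    for (int s = 0; s < n_sg; s++) { eq_min_base += top_partials[s]; }

    std::vector<float> dst(k, -1);
    for (int s = 0; s < n_sg; s++) {
        int out_idx = 0, eq_min_idx = eq_min_base;
        for (int i = 0; i < s; i++) {
            out_idx    += top_partials[i];
            eq_min_idx += eq_partials[i];
        }
        int rank_top = 0, rank_eq = 0;  // exclusive counts within the subgroup
        for (int t = 0; t < SG; t++) {
            float x = v[s * SG + t];
            if (x > range_min) {
                dst[out_idx + rank_top++] = x;
            }
            if (x == range_min) {
                if (eq_min_idx + rank_eq < k) {  // clamp ties to k outputs
                    dst[eq_min_idx + rank_eq] = x;
                }
                rank_eq++;
            }
        }
    }
    for (float x : dst) { printf("%g ", x); }  // 5 7 3 3
    printf("\n");
}

Note that the third copy of 3 (the first element of the second "subgroup") is correctly dropped by the clamp, while 7 on a later thread still lands in the output.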

tests/test-backend-ops.cpp

@@ -286,10 +286,11 @@ static double nmse(const float * a, const float * b, size_t n) {
     return mse_a_b / mse_a_0;
 }
 
-// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
-static double jdst(const int32_t * a, const int32_t * b, size_t n) {
-    std::unordered_map<int32_t, size_t> set_a;
-    std::unordered_map<int32_t, size_t> set_b;
+// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)
+template <typename T>
+static double jdst(const T * a, const T * b, size_t n) {
+    std::unordered_map<T, size_t> set_a;
+    std::unordered_map<T, size_t> set_b;
 
     for (size_t i = 0; i < n; ++i) {
         set_a[a[i]]++;
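The hunk cuts jdst off at the counting loop. For reference, a plausible completion of the multiset Jaccard distance its comment describes (jdst_sketch is a hypothetical name and the real body in test-backend-ops.cpp may differ in detail), used on values rather than indices as the tie-aware test below does:

#include <cstddef>
#include <cstdio>
#include <unordered_map>

template <typename T>
static double jdst_sketch(const T * a, const T * b, size_t n) {
    std::unordered_map<T, size_t> set_a;
    std::unordered_map<T, size_t> set_b;
    for (size_t i = 0; i < n; ++i) {
        set_a[a[i]]++;
        set_b[b[i]]++;
    }
    // Multiset intersection over union: 0 = identical, 1 = disjoint.
    size_t inter = 0, uni = 0;
    for (const auto & kv : set_a) {
        const auto it = set_b.find(kv.first);
        const size_t cb = it == set_b.end() ? 0 : it->second;
        inter += kv.second < cb ? kv.second : cb;
        uni   += kv.second > cb ? kv.second : cb;
    }
    for (const auto & kv : set_b) {
        if (set_a.find(kv.first) == set_a.end()) {
            uni += kv.second;  // values present only in b
        }
    }
    return uni == 0 ? 0.0 : 1.0 - (double) inter / (double) uni;
}

int main() {
    const float a[] = {5, 4, 3, 3};
    const float b[] = {3, 5, 3, 4};  // same values via different indices
    printf("%g\n", jdst_sketch(a, b, 4));  // 0: the multisets match
}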
@@ -5001,42 +5002,94 @@ struct test_top_k : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int k;
+    const bool ties;
+
+    ggml_tensor * input {};
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, k);
+        return VARS_TO_STR4(type, ne, k, ties);
     }
 
     test_top_k(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {16, 10, 10, 10},
-            int k = 4)
-        : type(type), ne(ne), k(k) {}
+            int k = 4, bool ties = false)
+        : type(type), ne(ne), k(k), ties(ties) {}
 
     double max_err() override {
         return 0.0;
     }
 
+    // When there are ties, only validate the final result.
+    // The logic in err can't handle the sentinel tensors.
+    bool run_whole_graph() override { return ties; }
+
     double err(const float * a, const float * b, size_t n) override {
-        std::vector<int32_t> ia(n);
-        std::vector<int32_t> ib(n);
-        double diff = 0.0f;
-        for (size_t i = 0; i < n; i++) {
-            ia[i] = (int32_t) a[i];
-            ib[i] = (int32_t) b[i];
-            // penalize the result if the data is not integer valued
-            diff += std::fabs(a[i] - ia[i]);
-            diff += std::fabs(b[i] - ib[i]);
-        }
-        return diff + jdst(ia.data(), ib.data(), n);
+        // When there are no ties, we expect the exact same set of indices,
+        // but possibly in a different order. When there are ties, the indices
+        // can be different but the input values they correspond to should be
+        // the same. The logic for ties could work for non-ties, but only for
+        // the output tensor, not for the sentinel tensors.
+        if (ties) {
+            std::vector<float> src(ggml_nelements(input));
+            double diff = 0.0f;
+            ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
+
+            GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
+            int64_t cols = input->ne[0];
+            std::vector<int32_t> ia(k);
+            std::vector<int32_t> ib(k);
+            std::vector<float> asrc(k);
+            std::vector<float> bsrc(k);
+            for (int64_t r = 0; r < ggml_nrows(input); r++) {
+                // Convert indices for the row back to integer
+                for (int64_t c = 0; c < k; c++) {
+                    ia[c] = (int32_t)a[r * k + c];
+                    ib[c] = (int32_t)b[r * k + c];
+                }
+                // The src values for each row should match.
+                for (int64_t c = 0; c < k; c++) {
+                    asrc[c] = src[r * cols + ia[c]];
+                    bsrc[c] = src[r * cols + ib[c]];
+                }
+                diff += jdst(asrc.data(), bsrc.data(), k);
+                // There should be no duplicate indices
+                std::sort(ia.begin(), ia.end());
+                std::sort(ib.begin(), ib.end());
+                if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
+                    diff += 1;
+                }
+                if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
+                    diff += 1;
+                }
+            }
+            return diff;
+        } else {
+            std::vector<int32_t> ia(n);
+            std::vector<int32_t> ib(n);
+            double diff = 0.0f;
+            for (size_t i = 0; i < n; i++) {
+                ia[i] = (int32_t) a[i];
+                ib[i] = (int32_t) b[i];
+                // penalize the result if the data is not integer valued
+                diff += std::fabs(a[i] - ia[i]);
+                diff += std::fabs(b[i] - ib[i]);
+            }
+            return diff + jdst(ia.data(), ib.data(), n);
+        }
     }
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
+        // Save 'a' for err()
+        input = a;
 
         ggml_tensor * out = ggml_top_k(ctx, a, k);
         ggml_set_name(out, "out");
@@ -5047,11 +5100,16 @@ struct test_top_k : public test_case {
         std::random_device rd;
         std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            // initialize with unique values to avoid ties
+            int tie_denom = std::max(1, std::min(10, k / 2));
             for (int64_t r = 0; r < ggml_nrows(t); r++) {
                 std::vector<float> data(t->ne[0]);
                 for (int i = 0; i < t->ne[0]; i++) {
-                    data[i] = i;
+                    if (ties) {
+                        // integer division to introduce duplicates
+                        data[i] = i / tie_denom;
+                    } else {
+                        data[i] = i;
+                    }
                 }
                 std::shuffle(data.begin(), data.end(), rng);
                 ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
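With ties enabled, the integer division makes every value appear tie_denom times before the shuffle, which can force a tie at the top-k cutoff. For instance with k = 5 (the 12-element row length here is just for illustration):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // ties = true, k = 5: tie_denom = max(1, min(10, 5 / 2)) = 2, so every
    // value appears twice before shuffling.
    const int k = 5, n = 12;
    const int tie_denom = std::max(1, std::min(10, k / 2));
    std::vector<float> data(n);
    for (int i = 0; i < n; i++) {
        data[i] = (float)(i / tie_denom);  // 0 0 1 1 2 2 3 3 4 4 5 5
    }
    // The top-5 must take both 5s and both 4s plus ONE of the two 3s, so the
    // winning index set is not unique -- exactly the case the tie-aware err()
    // above validates by comparing values rather than indices.
    for (float v : data) { printf("%g ", v); }
    printf("\n");
}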
@@ -7657,6 +7715,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             if (k <= 1<<i) {
                 test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
                 test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k, true));
             }
         }
     }