vulkan: fix top_k bug when there are ties in the input (#17659)
* vulkan: Reduce temporary memory usage for TOP_K - Compute row size for the temp buffer based on the output of the first pass. - Update shader addressing math to use the output row size - Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k" For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB. * vulkan: fix top_k bug when there are ties in the input I noticed by inspection a bug in the vulkan top_k shader where if the least value in the top_k appears multiple times we could end up writing those extra copies out rather than some larger values (if the larger values are on higher numbered threads). I rewrote the test verification to handle this case, where the final index set is not necessarily the same. * Update tests/test-backend-ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
e15cd06a94
commit
a0f3897d53
|
|
@ -4013,7 +4013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
|
||||
sizeof(int) * device->subgroup_size +
|
||||
2 * sizeof(int) +
|
||||
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
||||
2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
||||
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
|
||||
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
|
|||
shared int sh_min_idx;
|
||||
shared uint sh_total;
|
||||
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||
shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||
|
||||
// Map float values to uint such that comparisons still work.
|
||||
// Positive values set the high bit, negative values are inverted.
|
||||
|
|
@ -156,25 +157,66 @@ void topk(const uint row) {
|
|||
// We need to compact these values to the start of the dst_row array.
|
||||
// Have each subgroup count how many items it'll store, so other
|
||||
// subgroups can compute their base offset.
|
||||
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
||||
uvec4 b = subgroupBallot(top);
|
||||
uint bit_count = subgroupBallotBitCount(b);
|
||||
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint out_idx = 0;
|
||||
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||
if (i < tid / SUBGROUP_SIZE) {
|
||||
out_idx += offset_partials[i];
|
||||
// Values strictly greater than range_min must be stored. For values equal
|
||||
// to range_min, there can be ties and it's possible we'll need to store
|
||||
// an arbitrary subset of them.
|
||||
// If total == p.k, have a fast path where we don't need to handle ties.
|
||||
if (total == p.k) {
|
||||
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
||||
uvec4 b = subgroupBallot(top);
|
||||
uint bit_count = subgroupBallotBitCount(b);
|
||||
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
|
||||
if (top) {
|
||||
// TODO: Copy directly to the output?
|
||||
dst_row[out_idx + bit_count_ex] = v;
|
||||
uint out_idx = 0;
|
||||
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||
if (i < tid / SUBGROUP_SIZE) {
|
||||
out_idx += offset_partials[i];
|
||||
}
|
||||
}
|
||||
|
||||
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
|
||||
if (top) {
|
||||
// TODO: Copy directly to the output?
|
||||
dst_row[out_idx + bit_count_ex] = v;
|
||||
}
|
||||
} else {
|
||||
bool top = f2ui(intBitsToFloat(v.y)) > range_min;
|
||||
bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
|
||||
uvec4 b_top = subgroupBallot(top);
|
||||
uvec4 b_eq_min = subgroupBallot(eq_min);
|
||||
uint bit_count_top = subgroupBallotBitCount(b_top);
|
||||
uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
|
||||
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||
offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
|
||||
eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint out_idx = 0;
|
||||
uint eq_min_base = 0;
|
||||
uint eq_min_idx = 0;
|
||||
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||
if (i < tid / SUBGROUP_SIZE) {
|
||||
out_idx += offset_partials[i];
|
||||
eq_min_idx += eq_min_partials[i];
|
||||
}
|
||||
eq_min_base += offset_partials[i];
|
||||
}
|
||||
// range_min values are stored at the end
|
||||
eq_min_idx += eq_min_base;
|
||||
|
||||
uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
|
||||
uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
|
||||
if (top) {
|
||||
// TODO: Copy directly to the output?
|
||||
dst_row[out_idx + bit_count_ex_top] = v;
|
||||
}
|
||||
if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
|
||||
dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
|
|
|||
|
|
@ -286,10 +286,11 @@ static double nmse(const float * a, const float * b, size_t n) {
|
|||
return mse_a_b / mse_a_0;
|
||||
}
|
||||
|
||||
// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
|
||||
static double jdst(const int32_t * a, const int32_t * b, size_t n) {
|
||||
std::unordered_map<int32_t, size_t> set_a;
|
||||
std::unordered_map<int32_t, size_t> set_b;
|
||||
// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)
|
||||
template <typename T>
|
||||
static double jdst(const T * a, const T * b, size_t n) {
|
||||
std::unordered_map<T, size_t> set_a;
|
||||
std::unordered_map<T, size_t> set_b;
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
set_a[a[i]]++;
|
||||
|
|
@ -5001,42 +5002,94 @@ struct test_top_k : public test_case {
|
|||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int k;
|
||||
const bool ties;
|
||||
ggml_tensor * input {};
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(type, ne, k);
|
||||
return VARS_TO_STR4(type, ne, k, ties);
|
||||
}
|
||||
|
||||
test_top_k(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {16, 10, 10, 10},
|
||||
int k = 4)
|
||||
: type(type), ne(ne), k(k) {}
|
||||
int k = 4, bool ties = false)
|
||||
: type(type), ne(ne), k(k), ties(ties) {}
|
||||
|
||||
double max_err() override {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// When there are ties, only validate the final result.
|
||||
// The logic in err can't handle the sentinel tensors.
|
||||
bool run_whole_graph() override { return ties; }
|
||||
|
||||
double err(const float * a, const float * b, size_t n) override {
|
||||
std::vector<int32_t> ia(n);
|
||||
std::vector<int32_t> ib(n);
|
||||
// When there are no ties, we expect the exact same set of indices,
|
||||
// but possibly in a different order. When there are ties, the indices
|
||||
// can be different but the input values they correspond to should be
|
||||
// the same. The logic for ties could work for non-ties, but only for
|
||||
// the output tensor, not for the sentinel tensors.
|
||||
if (ties) {
|
||||
std::vector<float> src(ggml_nelements(input));
|
||||
|
||||
double diff = 0.0f;
|
||||
ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
|
||||
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ia[i] = (int32_t) a[i];
|
||||
ib[i] = (int32_t) b[i];
|
||||
double diff = 0.0f;
|
||||
|
||||
// penalize the result if the data is not integer valued
|
||||
diff += std::fabs(a[i] - ia[i]);
|
||||
diff += std::fabs(b[i] - ib[i]);
|
||||
GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
|
||||
int64_t cols = input->ne[0];
|
||||
std::vector<int32_t> ia(k);
|
||||
std::vector<int32_t> ib(k);
|
||||
std::vector<float> asrc(k);
|
||||
std::vector<float> bsrc(k);
|
||||
for (int64_t r = 0; r < ggml_nrows(input); r++) {
|
||||
// Convert indices for the row back to integer
|
||||
for (int64_t c = 0; c < k; c++) {
|
||||
ia[c] = (int32_t)a[r * k + c];
|
||||
ib[c] = (int32_t)b[r * k + c];
|
||||
}
|
||||
// The src values for each row should match.
|
||||
for (int64_t c = 0; c < k; c++) {
|
||||
asrc[c] = src[r * cols + ia[c]];
|
||||
bsrc[c] = src[r * cols + ib[c]];
|
||||
}
|
||||
diff += jdst(asrc.data(), bsrc.data(), k);
|
||||
// There should be no duplicate indices
|
||||
std::sort(ia.begin(), ia.end());
|
||||
std::sort(ib.begin(), ib.end());
|
||||
if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
|
||||
diff += 1;
|
||||
}
|
||||
if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
|
||||
diff += 1;
|
||||
}
|
||||
}
|
||||
return diff;
|
||||
} else {
|
||||
std::vector<int32_t> ia(n);
|
||||
std::vector<int32_t> ib(n);
|
||||
|
||||
double diff = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ia[i] = (int32_t) a[i];
|
||||
ib[i] = (int32_t) b[i];
|
||||
|
||||
// penalize the result if the data is not integer valued
|
||||
diff += std::fabs(a[i] - ia[i]);
|
||||
diff += std::fabs(b[i] - ib[i]);
|
||||
}
|
||||
|
||||
return diff + jdst(ia.data(), ib.data(), n);
|
||||
}
|
||||
|
||||
return diff + jdst(ia.data(), ib.data(), n);
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
// Save 'a' for err()
|
||||
input = a;
|
||||
|
||||
ggml_tensor * out = ggml_top_k(ctx, a, k);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
|
|
@ -5047,11 +5100,16 @@ struct test_top_k : public test_case {
|
|||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
// initialize with unique values to avoid ties
|
||||
int tie_denom = std::max(1, std::min(10, k / 2));
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<float> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i;
|
||||
if (ties) {
|
||||
// integer division to introduce duplicates
|
||||
data[i] = i / tie_denom;
|
||||
} else {
|
||||
data[i] = i;
|
||||
}
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
|
||||
|
|
@ -7657,6 +7715,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
if (k <= 1<<i) {
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k, true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue