diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index f57199b1ad..4a6ec1be77 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -4013,7 +4013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
                               sizeof(int) * device->subgroup_size +
                               2 * sizeof(int) +
-                              (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+                              2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
         if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
             nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
             ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
index f794285ee1..0b757f38e1 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
@@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
 shared int sh_min_idx;
 shared uint sh_total;
 shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
+shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
 
 // Map float values to uint such that comparisons still work.
 // Positive values set the high bit, negative values are inverted.
@@ -156,25 +157,66 @@ void topk(const uint row) {
     // We need to compact these values to the start of the dst_row array.
     // Have each subgroup count how many items it'll store, so other
     // subgroups can compute their base offset.
-    bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
-    uvec4 b = subgroupBallot(top);
-    uint bit_count = subgroupBallotBitCount(b);
-    if ((tid % SUBGROUP_SIZE) == 0) {
-        offset_partials[tid / SUBGROUP_SIZE] = bit_count;
-    }
-    barrier();
-
-    uint out_idx = 0;
-    [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
-        if (i < tid / SUBGROUP_SIZE) {
-            out_idx += offset_partials[i];
+    // Values strictly greater than range_min must be stored. For values equal
+    // to range_min, there can be ties and it's possible we'll need to store
+    // an arbitrary subset of them.
+    // If total == p.k, have a fast path where we don't need to handle ties.
+    if (total == p.k) {
+        bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+        uvec4 b = subgroupBallot(top);
+        uint bit_count = subgroupBallotBitCount(b);
+        if ((tid % SUBGROUP_SIZE) == 0) {
+            offset_partials[tid / SUBGROUP_SIZE] = bit_count;
         }
-    }
+        barrier();
 
-    uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
-    if (top) {
-        // TODO: Copy directly to the output?
-        dst_row[out_idx + bit_count_ex] = v;
+        uint out_idx = 0;
+        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+            if (i < tid / SUBGROUP_SIZE) {
+                out_idx += offset_partials[i];
+            }
+        }
+
+        uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+        if (top) {
+            // TODO: Copy directly to the output?
+            dst_row[out_idx + bit_count_ex] = v;
+        }
+    } else {
+        bool top = f2ui(intBitsToFloat(v.y)) > range_min;
+        bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
+        uvec4 b_top = subgroupBallot(top);
+        uvec4 b_eq_min = subgroupBallot(eq_min);
+        uint bit_count_top = subgroupBallotBitCount(b_top);
+        uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
+        if ((tid % SUBGROUP_SIZE) == 0) {
+            offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
+            eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
+        }
+        barrier();
+
+        uint out_idx = 0;
+        uint eq_min_base = 0;
+        uint eq_min_idx = 0;
+        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+            if (i < tid / SUBGROUP_SIZE) {
+                out_idx += offset_partials[i];
+                eq_min_idx += eq_min_partials[i];
+            }
+            eq_min_base += offset_partials[i];
+        }
+        // range_min values are stored at the end
+        eq_min_idx += eq_min_base;
+
+        uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
+        uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
+        if (top) {
+            // TODO: Copy directly to the output?
+            dst_row[out_idx + bit_count_ex_top] = v;
+        }
+        if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
+            dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
+        }
     }
 
     barrier();
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index ca30537749..9d7b0152af 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -286,10 +286,11 @@ static double nmse(const float * a, const float * b, size_t n) {
     return mse_a_b / mse_a_0;
 }
 
-// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
-static double jdst(const int32_t * a, const int32_t * b, size_t n) {
-    std::unordered_map<int32_t, int> set_a;
-    std::unordered_map<int32_t, int> set_b;
+// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)
+template <typename T>
+static double jdst(const T * a, const T * b, size_t n) {
+    std::unordered_map<T, int> set_a;
+    std::unordered_map<T, int> set_b;
 
     for (size_t i = 0; i < n; ++i) {
         set_a[a[i]]++;
@@ -5001,42 +5002,94 @@ struct test_top_k : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int k;
+    const bool ties;
+
+    ggml_tensor * input {};
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, k);
+        return VARS_TO_STR4(type, ne, k, ties);
     }
 
     test_top_k(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {16, 10, 10, 10},
-            int k = 4)
-        : type(type), ne(ne), k(k) {}
+            std::array<int64_t, 4> ne = {16, 10, 10, 10},
+            int k = 4, bool ties = false)
+        : type(type), ne(ne), k(k), ties(ties) {}
 
     double max_err() override {
         return 0.0;
     }
 
+    // When there are ties, only validate the final result.
+    // The logic in err can't handle the sentinel tensors.
+    bool run_whole_graph() override { return ties; }
+
     double err(const float * a, const float * b, size_t n) override {
-        std::vector<int32_t> ia(n);
-        std::vector<int32_t> ib(n);
+        // When there are no ties, we expect the exact same set of indices,
+        // but possibly in a different order. When there are ties, the indices
+        // can be different but the input values they correspond to should be
+        // the same. The logic for ties could work for non-ties, but only for
+        // the output tensor, not for the sentinel tensors.
+        if (ties) {
+            std::vector<float> src(ggml_nelements(input));
 
-        double diff = 0.0f;
+            ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
 
-        for (size_t i = 0; i < n; i++) {
-            ia[i] = (int32_t) a[i];
-            ib[i] = (int32_t) b[i];
+            double diff = 0.0f;
 
-            // penalize the result if the data is not integer valued
-            diff += std::fabs(a[i] - ia[i]);
-            diff += std::fabs(b[i] - ib[i]);
+            GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
+            int64_t cols = input->ne[0];
+            std::vector<int32_t> ia(k);
+            std::vector<int32_t> ib(k);
+            std::vector<float> asrc(k);
+            std::vector<float> bsrc(k);
+            for (int64_t r = 0; r < ggml_nrows(input); r++) {
+                // Convert indices for the row back to integer
+                for (int64_t c = 0; c < k; c++) {
+                    ia[c] = (int32_t)a[r * k + c];
+                    ib[c] = (int32_t)b[r * k + c];
+                }
+                // The src values for each row should match.
+                for (int64_t c = 0; c < k; c++) {
+                    asrc[c] = src[r * cols + ia[c]];
+                    bsrc[c] = src[r * cols + ib[c]];
+                }
+                diff += jdst(asrc.data(), bsrc.data(), k);
+                // There should be no duplicate indices
+                std::sort(ia.begin(), ia.end());
+                std::sort(ib.begin(), ib.end());
+                if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
+                    diff += 1;
+                }
+                if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
+                    diff += 1;
+                }
+            }
+            return diff;
+        } else {
+            std::vector<int32_t> ia(n);
+            std::vector<int32_t> ib(n);
+
+            double diff = 0.0f;
+
+            for (size_t i = 0; i < n; i++) {
+                ia[i] = (int32_t) a[i];
+                ib[i] = (int32_t) b[i];
+
+                // penalize the result if the data is not integer valued
+                diff += std::fabs(a[i] - ia[i]);
+                diff += std::fabs(b[i] - ib[i]);
+            }
+
+            return diff + jdst(ia.data(), ib.data(), n);
         }
-
-        return diff + jdst(ia.data(), ib.data(), n);
     }
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
+        // Save 'a' for err()
+        input = a;
+
         ggml_tensor * out = ggml_top_k(ctx, a, k);
         ggml_set_name(out, "out");
@@ -5047,11 +5100,16 @@ struct test_top_k : public test_case {
         std::random_device rd;
         std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            // initialize with unique values to avoid ties
+            int tie_denom = std::max(1, std::min(10, k / 2));
             for (int64_t r = 0; r < ggml_nrows(t); r++) {
                 std::vector<float> data(t->ne[0]);
                 for (int i = 0; i < t->ne[0]; i++) {
-                    data[i] = i;
+                    if (ties) {
+                        // integer division to introduce duplicates
+                        data[i] = i / tie_denom;
+                    } else {
+                        data[i] = i;
+                    }
                 }
                 std::shuffle(data.begin(), data.end(), rng);
                 ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
@@ -7657,6 +7715,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             if (k <= 1<