ggml-cpu : use template for argsort (#17222)
This commit is contained in:
parent
97d5117217
commit
879dec341a
|
|
@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
|
||||||
|
|
||||||
// ggml_compute_forward_argsort
|
// ggml_compute_forward_argsort
|
||||||
|
|
||||||
|
// Index comparator for argsort: ranks two indices by the float values they
// refer to in `data`. The sort direction is selected at compile time through
// the `order` template argument, so the comparison can be inlined by the caller.
// NOTE(review): NaN values in `data` compare false both ways - confirm inputs.
template<enum ggml_sort_order order>
struct argsort_cmp {
    const float * data; // values being ranked; indexed by the compared ints

    // Returns true when index `lhs` has to be placed before index `rhs`.
    bool operator()(int32_t lhs, int32_t rhs) const {
        const float x = data[lhs];
        const float y = data[rhs];
        // `order` is a compile-time constant: anything other than
        // GGML_SORT_ORDER_ASC is compared in descending direction
        return order == GGML_SORT_ORDER_ASC ? x < y : x > y;
    }
};
|
||||||
|
|
||||||
static void ggml_compute_forward_argsort_f32(
|
static void ggml_compute_forward_argsort_f32(
|
||||||
const ggml_compute_params * params,
|
const ggml_compute_params * params,
|
||||||
ggml_tensor * dst) {
|
ggml_tensor * dst) {
|
||||||
|
|
@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
|
||||||
dst_data[j] = j;
|
dst_data[j] = j;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<bool(int32_t, int32_t)> cmp;
|
|
||||||
|
|
||||||
// note: this might be causing memory allocations? ideally should be avoided if it's the case
|
|
||||||
switch (order) {
|
switch (order) {
|
||||||
case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
|
case GGML_SORT_ORDER_ASC:
|
||||||
case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
|
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
|
||||||
default: GGML_ABORT("invalid sort order");
|
break;
|
||||||
}
|
|
||||||
|
|
||||||
std::sort(dst_data, dst_data + ne0, cmp);
|
case GGML_SORT_ORDER_DESC:
|
||||||
|
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
GGML_ABORT("invalid sort order");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7631,6 +7631,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
|
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
|
||||||
|
|
||||||
return test_cases;
|
return test_cases;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue