feat: add sinks tensor support in flash attention impl
This commit is contained in:
parent
aed9b4f5bb
commit
6c517f151d
|
|
@ -19,6 +19,7 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
const hexagon::tensor * k,
|
||||
const hexagon::tensor * v,
|
||||
const hexagon::tensor * mask,
|
||||
const hexagon::tensor * sinks,
|
||||
hexagon::compute_params * params) {
|
||||
static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
|
||||
|
||||
|
|
@ -92,11 +93,12 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
}
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), flash_attn);
|
||||
const uint8_t * q_ptr = q->get_read_buffer();
|
||||
const uint8_t * k_ptr = k->get_read_buffer();
|
||||
const uint8_t * v_ptr = v->get_read_buffer();
|
||||
const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
|
||||
float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
|
||||
const uint8_t * q_ptr = q->get_read_buffer();
|
||||
const uint8_t * k_ptr = k->get_read_buffer();
|
||||
const uint8_t * v_ptr = v->get_read_buffer();
|
||||
const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
|
||||
const uint8_t * sinks_ptr = sinks ? sinks->get_read_buffer() : nullptr;
|
||||
float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
|
||||
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
|
||||
auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
|
||||
VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
|
||||
|
|
@ -224,6 +226,22 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
}
|
||||
}
|
||||
|
||||
if (sinks_ptr) {
|
||||
const float s = reinterpret_cast<const float *>(sinks_ptr)[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
ms = expf(M - s);
|
||||
hexagon::vec_scale_f32(VKQ32, ms, VKQ32, DV);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
|
||||
// V /= S
|
||||
const float S_inv = 1.0f / S;
|
||||
hexagon::vec_scale_f32(VKQ32, S_inv, VKQ32, DV);
|
||||
|
|
@ -253,20 +271,21 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
|
|||
return false;
|
||||
}
|
||||
|
||||
const auto * q = out->get_src(0);
|
||||
const auto * k = out->get_src(1);
|
||||
const auto * v = out->get_src(2);
|
||||
const auto * mask = out->get_src(3);
|
||||
if (!q || !k || !v || !mask) {
|
||||
const auto * q = out->get_src(0);
|
||||
const auto * k = out->get_src(1);
|
||||
const auto * v = out->get_src(2);
|
||||
if (!q || !k || !v) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto * mask = out->get_src(3);
|
||||
const auto * sinks = out->get_src(4);
|
||||
if (k->get_type() == NPU_DATA_TYPE_F16) {
|
||||
flash_attn_impl<true>(out, q, k, v, mask, params);
|
||||
flash_attn_impl<true>(out, q, k, v, mask, sinks, params);
|
||||
} else {
|
||||
flash_attn_impl<false>(out, q, k, v, mask, params);
|
||||
flash_attn_impl<false>(out, q, k, v, mask, sinks, params);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#include "remote.idl"
|
||||
|
||||
const uint32_t DEVICE_TENSOR_MAX_DIMS = 4;
|
||||
const uint32_t DEVICE_TENSOR_MAX_SRC = 4;
|
||||
const uint32_t DEVICE_TENSOR_MAX_SRC = 5;
|
||||
const uint32_t DEVICE_TENSOR_MAX_OP_PARAMS = 16;
|
||||
const uint32_t QUANT_BLOCK_SIZE = 32;
|
||||
const uint32_t QUANT_K_BLOCK_SIZE = 256;
|
||||
|
|
|
|||
Loading…
Reference in New Issue