commit 2a71333c8a
Author: Copybara-Service
Date:   2024-05-02 11:41:05 -07:00

    Merge pull request #176 from szabadka:gemma3

    PiperOrigin-RevId: 630131001

3 changed files with 18 additions and 99 deletions

--- a/gemma.cc
+++ b/gemma.cc

@@ -402,10 +402,11 @@ struct Activations {
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
       TConfig::kGemmaLayers * kCacheLayerSize;
+  static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
-  std::array<float, kBatchSize * kHeads * kQKVDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * kQDim> q;  // query vector
   std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
       att;  // attention vector
   std::array<float, kBatchSize * kHeads * kQKVDim> att_out;  // attention output
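
The new kQDim constant triples per-head query storage in the multi-head case: the fused projection introduced below writes each head's Q, K, and V into one contiguous [Q|K|V] block of the q array. A minimal sketch of the resulting indexing, with made-up constants (not the real Activations type):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    constexpr size_t kHeads = 4, kKVHeads = 4, kQKVDim = 8, kBatchSize = 1;
    // MHA (kHeads == kKVHeads) stores interleaved [Q|K|V] per head; MQA only Q.
    constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;

    std::array<float, kBatchSize * kHeads * kQDim> q;  // query (+ K/V when MHA)

    int main() {
      const size_t batch_idx = 0, head = 2;
      // Per-head base: Q at +0, K at +kQKVDim, V at +2*kQKVDim.
      float* base = q.data() + batch_idx * kHeads * kQDim + head * kQDim;
      std::printf("Q=%p K=%p V=%p\n", static_cast<void*>(base),
                  static_cast<void*>(base + kQKVDim),
                  static_cast<void*>(base + 2 * kQKVDim));
      return 0;
    }
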
@@ -710,10 +711,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-  auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+                  size_t thread) HWY_ATTR {
     // Calculate scores
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT head_att = activations.att.data() +
                                    head * TConfig::kSeqLen +
                                    batch_idx * kHeads * kQKVDim;
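
Attn now receives the query pointer as a parameter rather than recomputing it from activations.q, because the two branches below address queries differently: the MHA path passes a pointer into an interleaved [Q|K|V] block (stride 3 * kQKVDim per head), while the MQA path uses densely packed queries (stride kQKVDim). A toy illustration of the two addressing schemes (sizes are made up):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    constexpr size_t kHeads = 4, kQKVDim = 8;

    int main() {
      std::vector<float> mha(kHeads * kQKVDim * 3);  // interleaved [Q|K|V]
      std::vector<float> mqa(kHeads * kQKVDim);      // packed Q only
      const size_t head = 1;
      const float* q_mha = mha.data() + head * kQKVDim * 3;  // stride 3*kQKVDim
      const float* q_mqa = mqa.data() + head * kQKVDim;      // stride kQKVDim
      std::printf("MHA q offset=%td, MQA q offset=%td\n",
                  q_mha - mha.data(), q_mqa - mqa.data());
      return 0;
    }
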
@@ -745,34 +745,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   if constexpr (kHeads == kKVHeads) {
     // Multi-Head Attention
+    static_assert(TConfig::kInterleaveQKV);
+    float* HWY_RESTRICT qkv =
+        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+    MatVec<kHeads * kQKVDim * 3, kModelDim>(
+        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+        pool);
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      // linear projections to QKV
-      const size_t head_offset = TConfig::kInterleaveQKV
-                                     ? 3 * kQKVDim * kModelDim
-                                     : kQKVDim * kModelDim;
-      const size_t mat_offset =
-          TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
-      const size_t q_offset = head * head_offset + 0 * mat_offset;
-      const size_t k_offset = head * head_offset + 1 * mat_offset;
-      const size_t v_offset = head * head_offset + 2 * mat_offset;
-      // ProjQ
-      float* HWY_RESTRICT q =
-          activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-      MatVecLoop<kQKVDim, kModelDim>(
-          layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
-          activations.even_odd.data() + thread * kModelDim, q);
-      // ProjKV
+      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
-      float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
-      float* HWY_RESTRICT v = k + kQKVDim;
-      TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
-                                           k_offset, v_offset, x, k, v);
-      Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-      Attn(head, head * kQKVDim * 2, thread);
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      Attn(q, head, head * kQKVDim * 2, thread);
     });
   } else {
     // Multi-Query Attention
@@ -790,7 +779,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(head, 0, thread);
+      Attn(q + head * kQKVDim, head, 0, thread);
     });
   }
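
Net effect of the MHA changes: one threaded MatVec computes all heads' interleaved QKV rows in a single pass, and because each head's K and V land adjacent to its Q, a single memcpy per head moves both into the KV cache, replacing the per-head MatVecLoop/TwoOfsMatVecLoop calls. A self-contained sketch of that idea, with plain loops standing in for MatVec and the thread pool (all names and sizes here are illustrative, not the real API):

    #include <cstddef>
    #include <cstdio>
    #include <cstring>
    #include <vector>

    constexpr size_t kHeads = 2, kQKVDim = 4, kModelDim = 8;

    // One mat-vec over interleaved rows yields [Q|K|V] per head in `qkv`.
    void FusedQKV(const std::vector<float>& w, const std::vector<float>& x,
                  std::vector<float>& qkv) {
      for (size_t r = 0; r < kHeads * kQKVDim * 3; ++r) {
        float dot = 0.0f;
        for (size_t c = 0; c < kModelDim; ++c) dot += w[r * kModelDim + c] * x[c];
        qkv[r] = dot;
      }
    }

    int main() {
      std::vector<float> w(kHeads * kQKVDim * 3 * kModelDim, 0.01f);
      std::vector<float> x(kModelDim, 1.0f);
      std::vector<float> qkv(kHeads * kQKVDim * 3);
      std::vector<float> kv_cache(kHeads * kQKVDim * 2);
      FusedQKV(w, x, qkv);
      for (size_t head = 0; head < kHeads; ++head) {
        const float* q = qkv.data() + head * kQKVDim * 3;
        // K and V are contiguous after Q, so one memcpy moves both into the cache.
        std::memcpy(kv_cache.data() + head * kQKVDim * 2, q + kQKVDim,
                    2 * kQKVDim * sizeof(float));
      }
      std::printf("k0=%g v0=%g\n", kv_cache[0], kv_cache[kQKVDim]);
      return 0;
    }
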

--- a/ops.h
+++ b/ops.h

@@ -92,45 +92,6 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
-
-// Simple version without tiling nor threading.
-// even_odd is precomputed for the current thread.
-template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
-          typename VecT, typename AddT>
-HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              const AddT* HWY_RESTRICT add,
-                              float* HWY_RESTRICT even_odd,
-                              float* HWY_RESTRICT out) {
-  PROFILER_ZONE("MatVecAddLoop");
-  const hn::ScalableTag<float> df;
-
-  // Sanity check: we can write without race conditions.
-  if (HWY_IS_TSAN) {
-    even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
-    even_odd[kInner - 1] = -even_odd[0];
-  }
-
-  for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
-    const size_t row_ofs = mat_ofs + idx_row * kInner;
-    if constexpr (kAdd) {
-      out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
-                     Dot(df, mat, row_ofs, vec_aligned, kInner);
-    } else {
-      out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
-    }
-  }
-}
-
-// even_odd is precomputed for the current thread.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
-                           const VecT* HWY_RESTRICT vec_aligned,
-                           float* HWY_RESTRICT even_odd,
-                           float* HWY_RESTRICT out) {
-  MatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
-}
-
 // Simple version without tiling nor threading, but two offsets/outputs.
 template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
           typename VecT, typename AddT>
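
For reference, the deleted MatVecAddLoop was a single-threaded, row-by-row mat-vec with an optional bias, and MatVecLoop was its bias-free wrapper; their last caller disappeared with the fused projection above. A scalar sketch of the semantics (the real helper used Highway's vectorized Dot and an even_odd scratch buffer):

    #include <cstddef>
    #include <cstdio>

    // out[r] = (kAdd ? add[r] : 0) + dot(mat row r, vec), for r in [0, kOuter).
    template <bool kAdd, size_t kOuter, size_t kInner>
    void MatVecAddLoopRef(const float* mat, const float* vec, const float* add,
                          float* out) {
      for (size_t r = 0; r < kOuter; ++r) {
        float dot = 0.0f;
        for (size_t c = 0; c < kInner; ++c) dot += mat[r * kInner + c] * vec[c];
        out[r] = kAdd ? add[r] + dot : dot;
      }
    }

    int main() {
      const float mat[2 * 3] = {1, 2, 3, 4, 5, 6};
      const float vec[3] = {1, 1, 1};
      const float add[2] = {10, 20};
      float out[2];
      MatVecAddLoopRef<true, 2, 3>(mat, vec, add, out);
      std::printf("%g %g\n", out[0], out[1]);  // 16 35
      return 0;
    }
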
@@ -159,18 +120,6 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
   }
 }
-
-// Simple version without tiling nor threading, but two offsets/outputs.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
-                                 const size_t mat_ofs1,
-                                 const VecT* HWY_RESTRICT vec_aligned,
-                                 float* HWY_RESTRICT out0,
-                                 float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
-      mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
-      out0, out1);
-}
-
 namespace detail {

 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
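
Likewise, the deleted TwoOfsMatVecLoop ran two products over the same weight array at two row offsets, which is how the old per-head K and V projections shared qkv_einsum_w; it simply forwarded to TwoOfsMatVecAddLoop with null biases. A scalar sketch (made-up sizes):

    #include <cstddef>
    #include <cstdio>

    constexpr size_t kOuter = 2, kInner = 3;

    // Two mat-vecs over one array: rows at ofs0 -> out0, rows at ofs1 -> out1.
    void TwoOfsMatVecRef(const float* mat, size_t ofs0, size_t ofs1,
                         const float* vec, float* out0, float* out1) {
      for (size_t r = 0; r < kOuter; ++r) {
        float d0 = 0.0f, d1 = 0.0f;
        for (size_t c = 0; c < kInner; ++c) {
          d0 += mat[ofs0 + r * kInner + c] * vec[c];
          d1 += mat[ofs1 + r * kInner + c] * vec[c];
        }
        out0[r] = d0;  // e.g. K rows
        out1[r] = d1;  // e.g. V rows
      }
    }

    int main() {
      const float mat[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
      const float vec[3] = {1, 1, 1};
      float k[2], v[2];
      TwoOfsMatVecRef(mat, /*ofs0=*/0, /*ofs1=*/6, vec, k, v);
      std::printf("k={%g,%g} v={%g,%g}\n", k[0], k[1], v[0], v[1]);
      return 0;
    }
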

--- a/ops_test.cc
+++ b/ops_test.cc

@@ -436,24 +436,6 @@ void TestMatVecAdd() {
   AssertClose<kOuter>(actual_out, expected_out);
 }
-
-void TestMatVecAddLoop() {
-  constexpr size_t kOuter = 128 * 3;
-  constexpr size_t kInner = 128 * 5;
-  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> even_odd =
-      hwy::AllocateAligned<float>(kInner);
-  hwy::AlignedFreeUniquePtr<float[]> expected_out =
-      SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
-  hwy::AlignedFreeUniquePtr<float[]> actual_out =
-      hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
-  MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
-                                      even_odd.get(), actual_out.get());
-  AssertClose<kOuter>(actual_out, expected_out);
-}
-
 void TestTwoMatVecAdd() {
   hwy::ThreadPool pool(0);
   constexpr size_t kOuter = 128 * 3;
@@ -537,7 +519,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);