mirror of https://github.com/google/gemma.cpp.git
Merge pull request #176 from szabadka:gemma3
PiperOrigin-RevId: 630131001
commit 2a71333c8a
@@ -402,10 +402,11 @@ struct Activations {
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
       TConfig::kGemmaLayers * kCacheLayerSize;
+  static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
 
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
-  std::array<float, kBatchSize * kHeads * kQKVDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * kQDim> q;  // query vector
   std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
       att;  // attention vector
   std::array<float, kBatchSize * kHeads * kQKVDim> att_out;  // attention output
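Note: the new kQDim widens the per-head query slice to kQKVDim * 3 in the multi-head case, so `q` can hold Q, K and V interleaved per head after a single fused projection (see the later hunks). A minimal sketch of the resulting layout, assuming it matches the diff's `qkv + head * kQKVDim * 3` pointer arithmetic and the `memcpy` of K/V below; `QkvView` is an illustrative name, not part of the commit:

  #include <cstddef>

  // Each head owns a contiguous [3 * kQKVDim] slice: Q, then K, then V.
  template <std::size_t kQKVDim>
  struct QkvView {
    float* base;  // = activations.q.data() + batch_idx * kHeads * kQKVDim * 3
    float* Q(std::size_t head) const { return base + head * kQKVDim * 3; }
    float* K(std::size_t head) const { return Q(head) + kQKVDim; }
    float* V(std::size_t head) const { return Q(head) + 2 * kQKVDim; }
  };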
@@ -710,10 +711,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
-  auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+                  size_t thread) HWY_ATTR {
     // Calculate scores
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT head_att = activations.att.data() +
                                    head * TConfig::kSeqLen +
                                    batch_idx * kHeads * kQKVDim;
 
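Note: the Attn lambda now receives the query pointer as an explicit parameter instead of recomputing it from `activations.q`. The two call sites below (multi-head and multi-query) index differently laid-out query buffers, which the old single offset formula could not express.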
@@ -745,34 +745,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   if constexpr (kHeads == kKVHeads) {
     // Multi-Head Attention
     static_assert(TConfig::kInterleaveQKV);
+
+    float* HWY_RESTRICT qkv =
+        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+    MatVec<kHeads * kQKVDim * 3, kModelDim>(
+        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+        pool);
+
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      // linear projections to QKV
-      const size_t head_offset = TConfig::kInterleaveQKV
-                                     ? 3 * kQKVDim * kModelDim
-                                     : kQKVDim * kModelDim;
-      const size_t mat_offset =
-          TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
-      const size_t q_offset = head * head_offset + 0 * mat_offset;
-      const size_t k_offset = head * head_offset + 1 * mat_offset;
-      const size_t v_offset = head * head_offset + 2 * mat_offset;
-
-      // ProjQ
-      float* HWY_RESTRICT q =
-          activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-      MatVecLoop<kQKVDim, kModelDim>(
-          layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
-          activations.even_odd.data() + thread * kModelDim, q);
-
-      // ProjKV
+      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
-      float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
-      float* HWY_RESTRICT v = k + kQKVDim;
-      TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
-                                           k_offset, v_offset, x, k, v);
-      Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-
-      Attn(head, head * kQKVDim * 2, thread);
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      Attn(q, head, head * kQKVDim * 2, thread);
     });
   } else {
     // Multi-Query Attention
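Note: the per-head ProjQ (MatVecLoop) and ProjKV (TwoOfsMatVecLoop) calls are replaced by one fused, pool-threaded MatVec over the whole interleaved qkv_einsum_w matrix; each head then memcpys its K/V slice into the KV cache and applies RoPE there. A scalar reference of what the fused call computes, offered as an illustrative assumption rather than the commit's SIMD code:

  #include <cstddef>

  // w is row-major [heads * qkv_dim * 3, model_dim]; per head the rows are
  // qkv_dim Q rows, then qkv_dim K rows, then qkv_dim V rows, matching the
  // static_assert(TConfig::kInterleaveQKV) layout above.
  void FusedQkvProjection(const float* w, const float* x, float* qkv,
                          std::size_t heads, std::size_t qkv_dim,
                          std::size_t model_dim) {
    const std::size_t rows = heads * qkv_dim * 3;
    for (std::size_t r = 0; r < rows; ++r) {
      float dot = 0.0f;
      for (std::size_t c = 0; c < model_dim; ++c) {
        dot += w[r * model_dim + c] * x[c];
      }
      qkv[r] = dot;  // head h's slice starts at qkv + h * qkv_dim * 3
    }
  }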
@@ -790,7 +779,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(head, 0, thread);
+      Attn(q + head * kQKVDim, head, 0, thread);
     });
   }
 
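Note: in the multi-query path all heads share a single K/V, so only the query pointer differs per head; `q + head * kQKVDim` indexes a plain [kHeads, kQKVDim] query buffer, since kQDim == kQKVDim in this branch.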
51 gemma/ops.h
@@ -92,45 +92,6 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
 
-// Simple version without tiling nor threading.
-// even_odd is precomputed for the current thread.
-template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
-          typename VecT, typename AddT>
-HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              const AddT* HWY_RESTRICT add,
-                              float* HWY_RESTRICT even_odd,
-                              float* HWY_RESTRICT out) {
-  PROFILER_ZONE("MatVecAddLoop");
-  const hn::ScalableTag<float> df;
-
-  // Sanity check: we can write without race conditions.
-  if (HWY_IS_TSAN) {
-    even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
-    even_odd[kInner - 1] = -even_odd[0];
-  }
-
-  for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
-    const size_t row_ofs = mat_ofs + idx_row * kInner;
-    if constexpr (kAdd) {
-      out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
-                     Dot(df, mat, row_ofs, vec_aligned, kInner);
-    } else {
-      out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
-    }
-  }
-}
-
-// even_odd is precomputed for the current thread.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
-                           const VecT* HWY_RESTRICT vec_aligned,
-                           float* HWY_RESTRICT even_odd,
-                           float* HWY_RESTRICT out) {
-  MatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
-}
-
 // Simple version without tiling nor threading, but two offsets/outputs.
 template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
           typename VecT, typename AddT>
@@ -159,18 +120,6 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
   }
 }
 
-// Simple version without tiling nor threading, but two offsets/outputs.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
-                                 const size_t mat_ofs1,
-                                 const VecT* HWY_RESTRICT vec_aligned,
-                                 float* HWY_RESTRICT out0,
-                                 float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
-      mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
-      out0, out1);
-}
-
 namespace detail {
 
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
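Note: with the fused projection in Attention, MatVecAddLoop/MatVecLoop (previous hunk) and TwoOfsMatVecLoop (this hunk) appear to lose their only callers and are deleted, along with the TestMatVecAddLoop test below; TwoOfsMatVecAddLoop itself stays and is still exercised by TestTwoOfsMatVecAddLoop.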
@@ -436,24 +436,6 @@ void TestMatVecAdd() {
   AssertClose<kOuter>(actual_out, expected_out);
 }
 
-void TestMatVecAddLoop() {
-  constexpr size_t kOuter = 128 * 3;
-  constexpr size_t kInner = 128 * 5;
-  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> even_odd =
-      hwy::AllocateAligned<float>(kInner);
-  hwy::AlignedFreeUniquePtr<float[]> expected_out =
-      SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
-  hwy::AlignedFreeUniquePtr<float[]> actual_out =
-      hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
-  MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
-                                      even_odd.get(), actual_out.get());
-  AssertClose<kOuter>(actual_out, expected_out);
-}
-
 void TestTwoMatVecAdd() {
   hwy::ThreadPool pool(0);
   constexpr size_t kOuter = 128 * 3;
@@ -537,7 +519,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);