mirror of https://github.com/google/gemma.cpp.git
Merge pull request #176 from szabadka:gemma3
PiperOrigin-RevId: 630131001
This commit is contained in: commit 2a71333c8a
@@ -402,10 +402,11 @@ struct Activations {
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
       TConfig::kGemmaLayers * kCacheLayerSize;
+  static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;

   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
-  std::array<float, kBatchSize * kHeads * kQKVDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * kQDim> q;  // query vector
   std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
       att;  // attention vector
   std::array<float, kBatchSize * kHeads * kQKVDim> att_out;  // attention output
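Note: kQDim is the key addition here. When kHeads == kKVHeads the q buffer is reused as scratch for a fused Q/K/V projection, so each head needs 3 * kQKVDim floats instead of kQKVDim. A minimal sketch of the resulting indexing under that assumption (OffsetOfHeadQKV is a hypothetical helper, for illustration only):

    // Sketch: layout of Activations::q when kHeads == kKVHeads (MHA).
    // Each batch row holds kHeads blocks of kQDim = 3 * kQKVDim floats,
    // ordered [Q | K | V] per head, matching the interleaved qkv_einsum_w rows.
    constexpr size_t OffsetOfHeadQKV(size_t batch_idx, size_t head,
                                     size_t kHeads, size_t kQKVDim) {
      const size_t kQDim = 3 * kQKVDim;
      // Q at +0, K at +kQKVDim, V at +2 * kQKVDim within the block.
      return (batch_idx * kHeads + head) * kQDim;
    }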
@@ -710,10 +711,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,

   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;

-  auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+                  size_t thread) HWY_ATTR {
     // Calculate scores
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT head_att = activations.att.data() +
                                    head * TConfig::kSeqLen +
                                    batch_idx * kHeads * kQKVDim;
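Note: the Attn lambda now receives the query pointer explicitly instead of recomputing it from activations.q, so each branch below can pass the address matching its own layout. A rough sketch of the two call shapes after this change (both appear verbatim in the hunks that follow):

    // Multi-Head branch: q points into the interleaved [Q|K|V] block for this head.
    Attn(qkv + head * kQKVDim * 3, head, head * kQKVDim * 2, thread);
    // Multi-Query branch: q points into the plain per-head query block.
    Attn(q + head * kQKVDim, head, 0, thread);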
@@ -745,34 +745,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,

   if constexpr (kHeads == kKVHeads) {
     // Multi-Head Attention
+    static_assert(TConfig::kInterleaveQKV);
+
+    float* HWY_RESTRICT qkv =
+        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+    MatVec<kHeads * kQKVDim * 3, kModelDim>(
+        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+        pool);
+
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      // linear projections to QKV
-      const size_t head_offset = TConfig::kInterleaveQKV
-                                     ? 3 * kQKVDim * kModelDim
-                                     : kQKVDim * kModelDim;
-      const size_t mat_offset =
-          TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
-      const size_t q_offset = head * head_offset + 0 * mat_offset;
-      const size_t k_offset = head * head_offset + 1 * mat_offset;
-      const size_t v_offset = head * head_offset + 2 * mat_offset;
-
-      // ProjQ
-      float* HWY_RESTRICT q =
-          activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-      MatVecLoop<kQKVDim, kModelDim>(
-          layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
-          activations.even_odd.data() + thread * kModelDim, q);
-
-      // ProjKV
+      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
-      float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
-      float* HWY_RESTRICT v = k + kQKVDim;
-      TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
-                                           k_offset, v_offset, x, k, v);
-      Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-
-      Attn(head, head * kQKVDim * 2, thread);
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      Attn(q, head, head * kQKVDim * 2, thread);
     });
   } else {
     // Multi-Query Attention
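Note: the per-head MatVecLoop (Q) and TwoOfsMatVecLoop (K, V) calls are replaced by one threaded MatVec over all kHeads * 3 * kQKVDim rows of qkv_einsum_w; K and V are then copied out of that scratch block into the KV cache and RoPE is applied in place. A hedged before/after sketch of the projection work per token and layer (comments only, mirroring the hunk above):

    // Before (inside pool.Run, once per head): three small per-head matvecs.
    //   MatVecLoop<kQKVDim, kModelDim>(...);                 // Q
    //   TwoOfsMatVecLoop<kQKVDim, kModelDim>(..., k, v);     // K and V
    // After (once per token/layer, itself parallelized over rows):
    //   MatVec<kHeads * kQKVDim * 3, kModelDim>(
    //       layer_weights->qkv_einsum_w, 0, x, even_odd, qkv, pool);
    // followed by a per-head memcpy of the contiguous [K|V] pair into the
    // cache slot and an in-place RoPE on K before calling Attn.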
@@ -790,7 +779,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);

     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(head, 0, thread);
+      Attn(q + head * kQKVDim, head, 0, thread);
     });
   }
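Note: in the multi-query branch only the call site changes; the shared K/V (head_offset 0) is untouched and each head passes its own slice of the query block. For reference, a hedged sketch of how kQDim resolves here (kHeads != kKVHeads, so the q buffer keeps its old per-head size):

    // kQDim from the Activations change above:
    //   kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim
    // MQA therefore keeps q at kBatchSize * kHeads * kQKVDim floats, and
    float* head_q = q + head * kQKVDim;  // the same per-head pointer the
                                         // removed lines inside Attn computed.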
51 gemma/ops.h
@@ -92,45 +92,6 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }

-// Simple version without tiling nor threading.
-// even_odd is precomputed for the current thread.
-template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
-          typename VecT, typename AddT>
-HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              const AddT* HWY_RESTRICT add,
-                              float* HWY_RESTRICT even_odd,
-                              float* HWY_RESTRICT out) {
-  PROFILER_ZONE("MatVecAddLoop");
-  const hn::ScalableTag<float> df;
-
-  // Sanity check: we can write without race conditions.
-  if (HWY_IS_TSAN) {
-    even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
-    even_odd[kInner - 1] = -even_odd[0];
-  }
-
-  for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
-    const size_t row_ofs = mat_ofs + idx_row * kInner;
-    if constexpr (kAdd) {
-      out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
-                     Dot(df, mat, row_ofs, vec_aligned, kInner);
-    } else {
-      out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
-    }
-  }
-}
-
-// even_odd is precomputed for the current thread.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
-                           const VecT* HWY_RESTRICT vec_aligned,
-                           float* HWY_RESTRICT even_odd,
-                           float* HWY_RESTRICT out) {
-  MatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
-}
-
 // Simple version without tiling nor threading, but two offsets/outputs.
 template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
           typename VecT, typename AddT>
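Note: with the Attention change above, the single-threaded MatVecAddLoop/MatVecLoop helpers (here) and TwoOfsMatVecLoop (next hunk) have no remaining callers, so they are deleted along with their test. For reference, a scalar sketch of what the removed loop computed; the real code used Highway's Dot over possibly compressed rows, and mat_row is a hypothetical accessor used only for illustration:

    // out[i] = dot(row i of mat starting at mat_ofs, vec) [+ add[i] if kAdd]
    for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
      float sum = 0.0f;
      for (size_t j = 0; j < kInner; ++j) {
        sum += mat_row(idx_row)[j] * vec_aligned[j];  // hypothetical row access
      }
      out[idx_row] = kAdd ? sum + add[idx_row] : sum;
    }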
@@ -159,18 +120,6 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
   }
 }

-// Simple version without tiling nor threading, but two offsets/outputs.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
-                                 const size_t mat_ofs1,
-                                 const VecT* HWY_RESTRICT vec_aligned,
-                                 float* HWY_RESTRICT out0,
-                                 float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
-      mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
-      out0, out1);
-}
-
 namespace detail {

 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
@@ -436,24 +436,6 @@ void TestMatVecAdd() {
   AssertClose<kOuter>(actual_out, expected_out);
 }

-void TestMatVecAddLoop() {
-  constexpr size_t kOuter = 128 * 3;
-  constexpr size_t kInner = 128 * 5;
-  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> even_odd =
-      hwy::AllocateAligned<float>(kInner);
-  hwy::AlignedFreeUniquePtr<float[]> expected_out =
-      SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
-  hwy::AlignedFreeUniquePtr<float[]> actual_out =
-      hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
-  MatVecAddLoop<true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
-                                      even_odd.get(), actual_out.get());
-  AssertClose<kOuter>(actual_out, expected_out);
-}
-
 void TestTwoMatVecAdd() {
   hwy::ThreadPool pool(0);
   constexpr size_t kOuter = 128 * 3;
@@ -537,7 +519,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);