mirror of https://github.com/google/gemma.cpp.git
De-templatize Activations, add RowVectorBatch class
Also remove most kBatchSize args. PiperOrigin-RevId: 653185525
This commit is contained in:
parent
ff34370aac
commit
992a2cbbc0
|
|
@ -18,77 +18,130 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <array>
|
#include "gemma/common.h" // kMaxThreads - TODO: remove
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "gemma/common.h" // AllocateSizeof
|
#include "hwy/base.h" // HWY_DASSERT
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Must be aligned.
|
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||||
template <class TConfig, size_t kBatchSize>
|
// This can be seen as a (batch_size x len) matrix.
|
||||||
|
template <typename T>
|
||||||
|
class RowVectorBatch {
|
||||||
|
public:
|
||||||
|
// Default ctor for Activations ctor.
|
||||||
|
RowVectorBatch() : batch_size_(0), len_(0) {}
|
||||||
|
// Main ctor, called from Activations::Allocate.
|
||||||
|
RowVectorBatch(size_t batch_size, size_t len)
|
||||||
|
: batch_size_(batch_size), len_(len) {
|
||||||
|
mem_ = hwy::AllocateAligned<T>(batch_size * len);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move-only
|
||||||
|
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||||
|
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||||
|
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||||
|
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||||
|
|
||||||
|
size_t BatchSize() const { return batch_size_; }
|
||||||
|
size_t Len() const { return len_; }
|
||||||
|
|
||||||
|
// Returns the given row vector of length `Len()`.
|
||||||
|
T* Batch(size_t batch_idx) {
|
||||||
|
HWY_DASSERT(batch_idx < batch_size_);
|
||||||
|
return mem_.get() + batch_idx * len_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For MatMul or other operations that process the entire batch at once.
|
||||||
|
T* All() { return mem_.get(); }
|
||||||
|
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||||
|
size_t batch_size_; // rows in the matrix
|
||||||
|
size_t len_; // columns in the matrix = vector length
|
||||||
|
};
|
||||||
|
|
||||||
struct Activations {
|
struct Activations {
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
RowVectorBatch<float> x; // input
|
||||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
RowVectorBatch<float> q; // query, also KV if MHA.
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
RowVectorBatch<float> logits;
|
||||||
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
|
||||||
static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention
|
|
||||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
|
||||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
|
||||||
static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);
|
|
||||||
|
|
||||||
std::array<float, kBatchSize * kModelDim> x; // input
|
// Attention
|
||||||
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
|
RowVectorBatch<float> pre_att_rms_out;
|
||||||
std::array<float, kBatchSize * kHeads * kQStride> q; // query vector
|
RowVectorBatch<float> att; // attention vector
|
||||||
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
|
RowVectorBatch<float> att_out; // attention output
|
||||||
att; // attention vector
|
// After linear transformation, shared by all heads
|
||||||
std::array<float, kBatchSize * kHeads * kQKVDim> att_out; // attention output
|
RowVectorBatch<float> att_post1;
|
||||||
std::array<float, kHeads * kBatchSize * kModelDim>
|
// Accumulation of attention outputs over heads
|
||||||
att_post1; // attention output after linear transformation, per head
|
RowVectorBatch<float> att_post2;
|
||||||
std::array<float, kBatchSize * kModelDim>
|
|
||||||
att_post2; // accumulation of attention outputs over heads
|
|
||||||
std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
|
|
||||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
|
|
||||||
|
|
||||||
// For FFW MatMul.
|
// Gated FFW
|
||||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
|
RowVectorBatch<hwy::bfloat16_t> bf_pre_ffw_rms_out;
|
||||||
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;
|
RowVectorBatch<float> C1;
|
||||||
|
RowVectorBatch<float> C2;
|
||||||
|
RowVectorBatch<float> ffw_out;
|
||||||
|
|
||||||
std::array<float, kBatchSize * kModelDim> ffw_out;
|
// Griffin
|
||||||
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
|
RowVectorBatch<float> griffin_x;
|
||||||
|
RowVectorBatch<float> griffin_y;
|
||||||
|
RowVectorBatch<float> griffin_gate_x;
|
||||||
|
RowVectorBatch<float> griffin_multiplier;
|
||||||
|
|
||||||
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
|
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
|
||||||
// per-thread storage.
|
// per-thread storage.
|
||||||
// TODO: only used for MatVec, remove once that is gone.
|
// TODO: remove once MatVec is gone.
|
||||||
std::array<float, kModelDim * kMaxThreads> even_odd;
|
RowVectorBatch<float> even_odd;
|
||||||
|
|
||||||
// Griffin layer internal activations
|
// Multi-Head Attention?
|
||||||
static constexpr size_t kGriffinDim =
|
template <class TConfig>
|
||||||
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
|
static constexpr bool IsMHA() {
|
||||||
std::array<float, kBatchSize * kGriffinDim> griffin_x;
|
return TConfig::kHeads == TConfig::kKVHeads;
|
||||||
std::array<float, kBatchSize * kGriffinDim> griffin_y;
|
}
|
||||||
std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
|
|
||||||
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TConfig>
|
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||||
struct AllocateState {
|
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||||
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
|
template <class TConfig>
|
||||||
// When batching queries, the prefill batch size is reduced by a factor
|
static constexpr size_t QStride() {
|
||||||
// of kBatchedQueryBatchSize
|
return TConfig::kQKVDim * (IsMHA<TConfig>() ? 3 : 1);
|
||||||
prefill =
|
}
|
||||||
AllocateSizeof<Activations<TConfig, kMinAdjustedPrefillBatchSize *
|
|
||||||
kBatchedQueryBatchSize>>();
|
template <class TConfig>
|
||||||
decode = AllocateSizeof<
|
void Allocate(size_t batch_size) {
|
||||||
Activations<TConfig, kDecodeBatchSize * kBatchedQueryBatchSize>>();
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
|
constexpr size_t kHeads = TConfig::kHeads;
|
||||||
|
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||||
|
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
|
constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||||
|
constexpr size_t kGriffinLayers = TConfig::kGriffinLayers;
|
||||||
|
|
||||||
|
x = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
q = RowVectorBatch<float>(batch_size, kHeads * QStride<TConfig>());
|
||||||
|
logits = RowVectorBatch<float>(batch_size, kVocabSize);
|
||||||
|
|
||||||
|
pre_att_rms_out = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
att = RowVectorBatch<float>(batch_size, kHeads * kSeqLen);
|
||||||
|
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim);
|
||||||
|
att_post1 = RowVectorBatch<float>(1, kModelDim);
|
||||||
|
att_post2 = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
|
||||||
|
bf_pre_ffw_rms_out = RowVectorBatch<hwy::bfloat16_t>(batch_size, kModelDim);
|
||||||
|
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
|
||||||
|
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
|
||||||
|
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
|
||||||
|
if (kGriffinLayers > 0) {
|
||||||
|
griffin_x = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
griffin_y = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
griffin_gate_x = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
griffin_multiplier = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
}
|
||||||
|
|
||||||
|
even_odd = RowVectorBatch<float>(1, kModelDim * kMaxThreads);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize>
|
|
||||||
Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
|
|
||||||
return *reinterpret_cast<Activations<TConfig, kBatchSize>*>(state_u8.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
|
|
@ -59,33 +60,27 @@ HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void GriffinRecurrent(
|
HWY_NOINLINE void GriffinRecurrent(
|
||||||
size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
|
size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
|
||||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
Activations& activations, const CompressedLayer<TConfig>* layer_weights,
|
||||||
const CompressedLayer<TConfig>* layer_weights,
|
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
||||||
PROFILER_ZONE("Gen.Griffin");
|
PROFILER_ZONE("Gen.Griffin");
|
||||||
static_assert(kQueryBatchSize == 1,
|
|
||||||
"Griffin does not support batched queries.");
|
|
||||||
HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin.
|
HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin.
|
||||||
KVCache& kv_cache = *kv_caches[0];
|
KVCache& kv_cache = *kv_caches[0];
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
HWY_ASSERT(num_tokens <= kBatchSize);
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static constexpr size_t kModelDim =
|
|
||||||
gcpp::Activations<TConfig, kBatchSize * kQueryBatchSize>::kModelDim;
|
|
||||||
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
|
|
||||||
// X / Y linear layers.
|
// X / Y linear layers.
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||||
const size_t batch_offset = batch_idx * kModelDim;
|
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
|
||||||
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
|
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
|
|
||||||
TwoMatVecAdd<kModelDim, kModelDim>(
|
TwoMatVecAdd<kModelDim, kModelDim>(
|
||||||
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
|
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
|
||||||
activations.pre_att_rms_out.data() + batch_offset,
|
activations.pre_att_rms_out.Batch(batch_idx),
|
||||||
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
|
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
|
||||||
/*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
|
/*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
|
||||||
/*out1=*/y, pool);
|
/*out1=*/y, pool);
|
||||||
|
|
@ -94,9 +89,8 @@ HWY_NOINLINE void GriffinRecurrent(
|
||||||
|
|
||||||
// Conv1D.
|
// Conv1D.
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||||
const size_t batch_offset = batch_idx * kModelDim;
|
|
||||||
const size_t pos = batch_start + batch_idx;
|
const size_t pos = batch_start + batch_idx;
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
|
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
|
||||||
HWY_FULL(float) df;
|
HWY_FULL(float) df;
|
||||||
HWY_DASSERT(kModelDim % hn::Lanes(df) == 0);
|
HWY_DASSERT(kModelDim % hn::Lanes(df) == 0);
|
||||||
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
|
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
|
||||||
|
|
@ -130,14 +124,11 @@ HWY_NOINLINE void GriffinRecurrent(
|
||||||
|
|
||||||
// RGLRU
|
// RGLRU
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||||
const size_t batch_offset = batch_idx * kModelDim;
|
|
||||||
const size_t pos = batch_start + batch_idx;
|
const size_t pos = batch_start + batch_idx;
|
||||||
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
|
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
|
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
|
||||||
float* HWY_RESTRICT gate_x =
|
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx);
|
||||||
activations.griffin_gate_x.data() + batch_offset;
|
float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx);
|
||||||
float* HWY_RESTRICT a =
|
|
||||||
activations.griffin_multiplier.data() + batch_offset;
|
|
||||||
float* HWY_RESTRICT rnn_state =
|
float* HWY_RESTRICT rnn_state =
|
||||||
kv_cache.rglru_cache.get() + layer * kModelDim;
|
kv_cache.rglru_cache.get() + layer * kModelDim;
|
||||||
|
|
||||||
|
|
@ -185,13 +176,12 @@ HWY_NOINLINE void GriffinRecurrent(
|
||||||
|
|
||||||
// Final linear layer.
|
// Final linear layer.
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||||
const size_t batch_offset = batch_idx * kModelDim;
|
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
|
float* out_ptr = activations.att_post2.Batch(batch_idx);
|
||||||
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
|
|
||||||
MatVecAdd<kModelDim, kModelDim>(
|
MatVecAdd<kModelDim, kModelDim>(
|
||||||
layer_weights->griffin.linear_out_w, 0, x,
|
layer_weights->griffin.linear_out_w, 0, x,
|
||||||
layer_weights->griffin.linear_out_biases.data(),
|
layer_weights->griffin.linear_out_biases.data(),
|
||||||
activations.even_odd.data(), out_ptr, pool);
|
activations.even_odd.All(), out_ptr, pool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -202,30 +192,26 @@ HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) {
|
||||||
Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void Attention(
|
HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
|
||||||
size_t batch_and_query_start, size_t num_tokens, size_t num_queries,
|
size_t num_queries, size_t layer,
|
||||||
size_t layer,
|
Activations& activations,
|
||||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
const CompressedLayer<TConfig>* layer_weights,
|
||||||
const CompressedLayer<TConfig>* layer_weights,
|
const std::vector<KVCache*>& kv_caches,
|
||||||
const std::vector<KVCache*>& kv_caches,
|
hwy::ThreadPool& pool) {
|
||||||
hwy::ThreadPool& pool) {
|
|
||||||
PROFILER_ZONE("Gen.Attention");
|
PROFILER_ZONE("Gen.Attention");
|
||||||
HWY_DASSERT(num_tokens <= kBatchSize);
|
|
||||||
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
|
||||||
HWY_DASSERT(batch_and_query_start % num_queries == 0);
|
HWY_DASSERT(batch_and_query_start % num_queries == 0);
|
||||||
using TActivations = Activations<TConfig, kBatchSize * kQueryBatchSize>;
|
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
constexpr size_t kQKVDim = TActivations::kQKVDim;
|
constexpr size_t kQStride = Activations::QStride<TConfig>();
|
||||||
constexpr size_t kQStride = TActivations::kQStride;
|
|
||||||
constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
|
constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
|
||||||
constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
|
constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
|
||||||
constexpr size_t kModelDim = TActivations::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
constexpr size_t kHeads = TConfig::kHeads;
|
constexpr size_t kHeads = TConfig::kHeads;
|
||||||
constexpr size_t kKVHeads = TConfig::kKVHeads;
|
constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
constexpr size_t kSeqLen = TConfig::kSeqLen;
|
constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||||
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
||||||
// Multi-Head Attention a.k.a. "use_qkv_einsum".
|
// Multi-Head Attention a.k.a. "use_qkv_einsum".
|
||||||
constexpr bool kIsMHA = TActivations::kIsMHA;
|
constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
|
||||||
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
|
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
|
||||||
const size_t batch_start = batch_and_query_start / num_queries;
|
const size_t batch_start = batch_and_query_start / num_queries;
|
||||||
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
||||||
|
|
@ -237,15 +223,14 @@ HWY_NOINLINE void Attention(
|
||||||
// Compute Q only or QKV (if MHA).
|
// Compute Q only or QKV (if MHA).
|
||||||
// If MHA, this also computes KV, which we copy to the KV cache below.
|
// If MHA, this also computes KV, which we copy to the KV cache below.
|
||||||
MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
|
MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
|
||||||
num_tokens_and_queries, activations.pre_att_rms_out.data(),
|
num_tokens_and_queries, activations.pre_att_rms_out.All(),
|
||||||
layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
|
layer_weights->qkv_einsum_w.data(), activations.q.All(), pool);
|
||||||
|
|
||||||
// Compute KV if not MHA.
|
// Compute KV if not MHA.
|
||||||
if constexpr (!kIsMHA) {
|
if constexpr (!kIsMHA) {
|
||||||
for (size_t batch_and_query_idx = 0;
|
for (size_t batch_and_query_idx = 0;
|
||||||
batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
|
batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
|
||||||
const float* x =
|
const float* x = activations.pre_att_rms_out.Batch(batch_and_query_idx);
|
||||||
activations.pre_att_rms_out.data() + batch_and_query_idx * kModelDim;
|
|
||||||
const size_t query_idx = batch_and_query_idx % num_queries;
|
const size_t query_idx = batch_and_query_idx % num_queries;
|
||||||
const size_t batch_idx = batch_and_query_idx / num_queries;
|
const size_t batch_idx = batch_and_query_idx / num_queries;
|
||||||
KVCache& kv_cache = *kv_caches[query_idx];
|
KVCache& kv_cache = *kv_caches[query_idx];
|
||||||
|
|
@ -258,7 +243,7 @@ HWY_NOINLINE void Attention(
|
||||||
// TODO: requires MatMul support for offsets.
|
// TODO: requires MatMul support for offsets.
|
||||||
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
|
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
|
||||||
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
|
||||||
activations.even_odd.data(), kv, pool);
|
activations.even_odd.All(), kv, pool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -279,8 +264,7 @@ HWY_NOINLINE void Attention(
|
||||||
if constexpr (kIsMHA) {
|
if constexpr (kIsMHA) {
|
||||||
// For MHA, copy KV into the KV cache from scratch space (see above).
|
// For MHA, copy KV into the KV cache from scratch space (see above).
|
||||||
const float* HWY_RESTRICT q =
|
const float* HWY_RESTRICT q =
|
||||||
activations.q.data() + (batch_and_query_idx * kHeads
|
activations.q.Batch(batch_and_query_idx) + head * kQStride;
|
||||||
+ head) * kQStride;
|
|
||||||
// Skip past the Q part of `q`, and copy KV to `kv`.
|
// Skip past the Q part of `q`, and copy KV to `kv`.
|
||||||
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
|
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
@ -300,7 +284,7 @@ HWY_NOINLINE void Attention(
|
||||||
const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
|
const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
|
||||||
KVCache& kv_cache = *kv_caches[query_idx];
|
KVCache& kv_cache = *kv_caches[query_idx];
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride;
|
activations.q.Batch(batch_and_query_idx) + head * kQStride;
|
||||||
|
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
const size_t pos = batch_start + batch_idx;
|
const size_t pos = batch_start + batch_idx;
|
||||||
|
|
@ -309,11 +293,7 @@ HWY_NOINLINE void Attention(
|
||||||
|
|
||||||
// Compute Q.K scores, yielding "logits" (or scores) in head_att.
|
// Compute Q.K scores, yielding "logits" (or scores) in head_att.
|
||||||
float* HWY_RESTRICT head_att =
|
float* HWY_RESTRICT head_att =
|
||||||
activations.att.data() + head * kSeqLen
|
activations.att.Batch(batch_and_query_idx) + head * kSeqLen;
|
||||||
+ batch_and_query_idx * kHeads * kSeqLen;
|
|
||||||
|
|
||||||
|
|
||||||
// Compute Q dot K scores
|
|
||||||
const size_t start_pos =
|
const size_t start_pos =
|
||||||
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
|
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
|
||||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
|
|
@ -335,8 +315,8 @@ HWY_NOINLINE void Attention(
|
||||||
// Summation of v (kv_cache) weighted by probs (head_att)
|
// Summation of v (kv_cache) weighted by probs (head_att)
|
||||||
// into "encoded" (att_out). Compare gemma/modules.py:
|
// into "encoded" (att_out). Compare gemma/modules.py:
|
||||||
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||||
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
|
float* HWY_RESTRICT att_out =
|
||||||
batch_and_query_idx * kHeads * kQKVDim;
|
activations.att_out.Batch(batch_and_query_idx) + head * kQKVDim;
|
||||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
|
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
|
||||||
|
|
@ -355,26 +335,24 @@ HWY_NOINLINE void Attention(
|
||||||
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
|
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
|
||||||
// rearranging the weights.
|
// rearranging the weights.
|
||||||
float* HWY_RESTRICT att_out =
|
float* HWY_RESTRICT att_out =
|
||||||
activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim;
|
activations.att_out.Batch(batch_and_query_idx);
|
||||||
float* HWY_RESTRICT layer_out =
|
float* HWY_RESTRICT layer_out =
|
||||||
activations.att_post2.data() + batch_and_query_idx * kModelDim;
|
activations.att_post2.Batch(batch_and_query_idx);
|
||||||
// Head 0 (and potentially biases) -> layer_out.
|
// Head 0 (and potentially biases) -> layer_out.
|
||||||
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
|
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
|
||||||
MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
|
MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
|
||||||
layer_weights->attn_vec_einsum_w, 0, att_out,
|
layer_weights->attn_vec_einsum_w, 0, att_out,
|
||||||
layer_weights->attention_output_biases.data(),
|
layer_weights->attention_output_biases.data(),
|
||||||
activations.even_odd.data(), layer_out, pool);
|
activations.even_odd.All(), layer_out, pool);
|
||||||
// Head 1 and following are added to layer_out.
|
// Head 1 and following are added to layer_out.
|
||||||
for (size_t head = 1; head < kHeads; ++head) {
|
for (size_t head = 1; head < kHeads; ++head) {
|
||||||
// TODO(patrickms): Check this calculation
|
// NOTE: this is a single kModelDim temp output. If parallelized or using
|
||||||
float* HWY_RESTRICT head_out =
|
// MatMul, add per-thread storage.
|
||||||
activations.att_post1.data() +
|
float* HWY_RESTRICT head_out = activations.att_post1.All();
|
||||||
head * kBatchSize * kQueryBatchSize * kModelDim;
|
|
||||||
// TODO: requires MatMul support for offsets.
|
// TODO: requires MatMul support for offsets.
|
||||||
MatVec<kModelDim, kQKVDim>(
|
MatVec<kModelDim, kQKVDim>(
|
||||||
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
||||||
att_out + head * kQKVDim,
|
att_out + head * kQKVDim, activations.even_odd.All(), head_out, pool);
|
||||||
activations.even_odd.data(), head_out, pool);
|
|
||||||
AddFrom(head_out, layer_out, kModelDim);
|
AddFrom(head_out, layer_out, kModelDim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -392,13 +370,11 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
|
||||||
size_t num_tokens,
|
|
||||||
const CompressedLayer<TConfig>* layer_weights,
|
const CompressedLayer<TConfig>* layer_weights,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
PROFILER_ZONE("Gen.FFW");
|
PROFILER_ZONE("Gen.FFW");
|
||||||
HWY_DASSERT(num_tokens <= kBatchSize);
|
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||||
|
|
||||||
|
|
@ -406,11 +382,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
||||||
// elements in memory, repeated kFFHiddenDim times.
|
// elements in memory, repeated kFFHiddenDim times.
|
||||||
constexpr size_t kColsA = kModelDim;
|
constexpr size_t kColsA = kModelDim;
|
||||||
constexpr size_t kColsB = kFFHiddenDim;
|
constexpr size_t kColsB = kFFHiddenDim;
|
||||||
const auto A = activations.bf_pre_ffw_rms_out.data();
|
HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||||
|
const auto A = activations.bf_pre_ffw_rms_out.All();
|
||||||
const auto B1 = layer_weights->gating_einsum_w.data();
|
const auto B1 = layer_weights->gating_einsum_w.data();
|
||||||
const auto B2 = B1 + kColsA * kColsB;
|
const auto B2 = B1 + kColsA * kColsB;
|
||||||
auto C1 = activations.C1.data();
|
auto C1 = activations.C1.All();
|
||||||
auto C2 = activations.C2.data();
|
auto C2 = activations.C2.All();
|
||||||
constexpr bool kAddBias = TConfig::kFFBiases;
|
constexpr bool kAddBias = TConfig::kFFBiases;
|
||||||
const auto bias = layer_weights->ffw_gating_biases.data();
|
const auto bias = layer_weights->ffw_gating_biases.data();
|
||||||
|
|
||||||
|
|
@ -422,8 +399,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
||||||
bias + kFFHiddenDim, pool);
|
bias + kFFHiddenDim, pool);
|
||||||
|
|
||||||
// Activation (Gelu) and multiply by gate. Store activations in C1.
|
// Activation (Gelu) and multiply by gate. Store activations in C1.
|
||||||
Activation<TConfig>(activations.C1.data(), activations.C2.data(),
|
Activation<TConfig>(C1, C2, kFFHiddenDim * num_tokens);
|
||||||
kFFHiddenDim * num_tokens);
|
|
||||||
|
|
||||||
// linear_w may have a scale value different from 1, apply that here.
|
// linear_w may have a scale value different from 1, apply that here.
|
||||||
// We multiply all activations by the scale value to compensate for the
|
// We multiply all activations by the scale value to compensate for the
|
||||||
|
|
@ -434,123 +410,115 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
||||||
|
|
||||||
// Hidden layer -> output layer.
|
// Hidden layer -> output layer.
|
||||||
MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
|
MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
|
||||||
num_tokens, C1, layer_weights->linear_w.data(),
|
num_tokens, C1, layer_weights->linear_w.data(), activations.ffw_out.All(),
|
||||||
activations.ffw_out.data(), layer_weights->ffw_output_biases.data(),
|
layer_weights->ffw_output_biases.data(), pool);
|
||||||
pool);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize>
|
// TODO: pass Activations.x instead of Activations.
|
||||||
HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
|
template <class TConfig>
|
||||||
|
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights,
|
||||||
Activations<TConfig, kBatchSize>& activations) {
|
Activations& activations) {
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||||
EmbeddingScaling<TConfig>();
|
EmbeddingScaling<TConfig>();
|
||||||
HWY_DASSERT(token >= 0);
|
HWY_DASSERT(token >= 0);
|
||||||
HWY_DASSERT(token < TConfig::kVocabSize);
|
HWY_DASSERT(token < TConfig::kVocabSize);
|
||||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
activations.x.Batch(batch_idx), kModelDim);
|
||||||
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim);
|
||||||
kModelDim);
|
|
||||||
if constexpr (TConfig::kAbsolutePE) {
|
if constexpr (TConfig::kAbsolutePE) {
|
||||||
AddAbsolutePositionalEmbeddings(
|
AddAbsolutePositionalEmbeddings(activations.x.Batch(batch_idx), kModelDim,
|
||||||
activations.x.data() + token_idx * kModelDim, kModelDim,
|
pos + batch_idx);
|
||||||
pos + token_idx);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize, typename T>
|
template <class TConfig, typename T>
|
||||||
HWY_NOINLINE void ResidualConnection(
|
HWY_NOINLINE void ResidualConnection(
|
||||||
size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
|
size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
|
||||||
const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
|
const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
// ResidualType::Add
|
// ResidualType::Add
|
||||||
AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries, other, x,
|
AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
|
||||||
kModelDim);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
template <class TConfig, size_t kQueryBatchSize>
|
||||||
HWY_NOINLINE void TransformerLayer(
|
HWY_NOINLINE void TransformerLayer(
|
||||||
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
||||||
const CompressedLayer<TConfig>* layer_weights,
|
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
||||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
||||||
auto type = TConfig::kLayerConfig[layer];
|
auto type = TConfig::kLayerConfig[layer];
|
||||||
size_t layer_of_type =
|
size_t layer_of_type =
|
||||||
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
|
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
|
||||||
RMSNormBatched<kBatchSize * kQueryBatchSize>(
|
RMSNormBatched(num_tokens_and_queries, activations.x.All(),
|
||||||
num_tokens_and_queries, activations.x.data(),
|
layer_weights->pre_attention_norm_scale.data(),
|
||||||
layer_weights->pre_attention_norm_scale.data(),
|
activations.pre_att_rms_out.All(), kModelDim);
|
||||||
activations.pre_att_rms_out.data(), kModelDim);
|
|
||||||
if (type == LayerAttentionType::kGemma) {
|
if (type == LayerAttentionType::kGemma) {
|
||||||
Attention<TConfig, kBatchSize, kQueryBatchSize>(
|
Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
|
||||||
pos, num_tokens, num_queries, layer_of_type, activations,
|
layer_weights, kv_caches, pool);
|
||||||
layer_weights, kv_caches, pool);
|
|
||||||
} else {
|
} else {
|
||||||
// This Griffin layers should never exist unless the model is a Griffin
|
// This Griffin layers should never exist unless the model is a Griffin
|
||||||
// model. This conditional prevents the compiler from generating code for
|
// model. This conditional prevents the compiler from generating code for
|
||||||
// this branch when building a non-Griffin model, since we have static
|
// this branch when building a non-Griffin model, since we have static
|
||||||
// asserts about the query batch size for Griffin layers.
|
// asserts about the query batch size for Griffin layers.
|
||||||
if constexpr (TConfig::kGriffinLayers > 0) {
|
if constexpr (TConfig::kGriffinLayers > 0) {
|
||||||
GriffinRecurrent<TConfig, kBatchSize, kQueryBatchSize>(
|
static_assert(kQueryBatchSize == 1,
|
||||||
pos, num_tokens, num_queries, layer_of_type, activations,
|
"Griffin does not support batched queries.");
|
||||||
layer_weights, kv_caches, pool);
|
GriffinRecurrent<TConfig>(pos, num_tokens, num_queries, layer_of_type,
|
||||||
|
activations, layer_weights, kv_caches, pool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (TConfig::kPostNorm == PostNormType::Scale) {
|
if (TConfig::kPostNorm == PostNormType::Scale) {
|
||||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
RMSNormInplaceBatched(num_tokens_and_queries,
|
||||||
num_tokens_and_queries,
|
layer_weights->post_attention_norm_scale.data(),
|
||||||
layer_weights->post_attention_norm_scale.data(),
|
activations.att_post2.All(), kModelDim);
|
||||||
activations.att_post2.data(), kModelDim);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ResidualConnection<TConfig, kBatchSize, kQueryBatchSize>(
|
ResidualConnection<TConfig>(num_tokens_and_queries,
|
||||||
num_tokens_and_queries, activations.att_post2.data(),
|
activations.att_post2.All(), activations.x.All(),
|
||||||
activations.x.data(), layer_weights, /*is_attention*/ true);
|
layer_weights, /*is_attention=*/true);
|
||||||
RMSNormBatched<kBatchSize * kQueryBatchSize>(
|
RMSNormBatched(num_tokens_and_queries, activations.x.All(),
|
||||||
num_tokens_and_queries, activations.x.data(),
|
layer_weights->pre_ffw_norm_scale.data(),
|
||||||
layer_weights->pre_ffw_norm_scale.data(),
|
activations.bf_pre_ffw_rms_out.All(), kModelDim);
|
||||||
activations.bf_pre_ffw_rms_out.data(), kModelDim);
|
FFW<TConfig>(activations, num_tokens_and_queries, layer_weights, pool);
|
||||||
FFW<TConfig, kBatchSize * kQueryBatchSize>(
|
|
||||||
activations, num_tokens_and_queries, layer_weights, pool);
|
|
||||||
if (TConfig::kPostNorm == PostNormType::Scale) {
|
if (TConfig::kPostNorm == PostNormType::Scale) {
|
||||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
RMSNormInplaceBatched(num_tokens_and_queries,
|
||||||
num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
|
layer_weights->post_ffw_norm_scale.data(),
|
||||||
activations.ffw_out.data(), kModelDim);
|
activations.ffw_out.All(), kModelDim);
|
||||||
}
|
}
|
||||||
ResidualConnection<TConfig, kBatchSize, kQueryBatchSize>(
|
ResidualConnection<TConfig>(num_tokens_and_queries, activations.ffw_out.All(),
|
||||||
num_tokens_and_queries, activations.ffw_out.data(), activations.x.data(),
|
activations.x.All(), layer_weights,
|
||||||
layer_weights, /*is_attention*/ false);
|
/*is_attention=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||||
HWY_NOINLINE void Prefill(
|
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens,
|
||||||
const int* tokens, size_t num_tokens, size_t num_queries, size_t pos,
|
size_t num_queries, size_t pos,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights,
|
||||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
Activations& activations,
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
|
const std::vector<KVCache*>& kv_caches,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
|
PROFILER_ZONE("Gen.Prefill");
|
||||||
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
||||||
const size_t minibatch_size = std::min(num_tokens, kBatchSize);
|
const size_t minibatch_size = std::min(num_tokens, kBatchSize);
|
||||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
// TODO: hoist pool.Run out of the loop, change the unit of work to batches.
|
||||||
// TODO(patrickms): Try to hoist pool.Run out of the loop.
|
|
||||||
for (size_t i = 0; i < num_tokens; i += minibatch_size) {
|
for (size_t i = 0; i < num_tokens; i += minibatch_size) {
|
||||||
const size_t offset = i * num_queries;
|
const size_t offset = i * num_queries;
|
||||||
const size_t current_token_count = std::min(
|
const size_t current_token_count = std::min(
|
||||||
minibatch_size, num_tokens - i);
|
minibatch_size, num_tokens - i);
|
||||||
pool.Run(0, current_token_count * num_queries,
|
pool.Run(0, current_token_count * num_queries,
|
||||||
[&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
[&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||||
EmbedToken<TConfig, kBatchSize * kQueryBatchSize>(
|
EmbedToken<TConfig>(tokens[token_idx + offset], token_idx,
|
||||||
tokens[token_idx + offset], token_idx, pos + offset,
|
pos + offset, weights, activations);
|
||||||
weights, activations);
|
});
|
||||||
});
|
|
||||||
|
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
const auto* layer_weights = weights.GetLayer(layer);
|
const auto* layer_weights = weights.GetLayer(layer);
|
||||||
TransformerLayer<TConfig, kBatchSize, kQueryBatchSize>(
|
TransformerLayer<TConfig, kQueryBatchSize>(
|
||||||
current_token_count, num_queries, pos + offset , layer, layer_weights,
|
current_token_count, num_queries, pos + offset, layer, layer_weights,
|
||||||
activations, kv_caches, pool);
|
activations, kv_caches, pool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -558,15 +526,14 @@ HWY_NOINLINE void Prefill(
|
||||||
|
|
||||||
// Compute the transformer for a batch of input tokens. During generation,
|
// Compute the transformer for a batch of input tokens. During generation,
|
||||||
// we usually have num_tokens == 1 (and also kBatchSize == 1).
|
// we usually have num_tokens == 1 (and also kBatchSize == 1).
|
||||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
template <class TConfig, size_t kQueryBatchSize>
|
||||||
HWY_NOINLINE void Transformer(
|
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
const int* tokens, size_t num_tokens, size_t num_queries, size_t pos,
|
size_t num_queries, size_t pos,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights,
|
||||||
Activations<TConfig, kBatchSize * kQueryBatchSize>& activations,
|
Activations& activations,
|
||||||
const std::vector<KVCache*>& kv_caches,
|
const std::vector<KVCache*>& kv_caches,
|
||||||
hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool,
|
||||||
const LayersOutputFunc& layers_output) {
|
const LayersOutputFunc& layers_output) {
|
||||||
HWY_ASSERT(num_tokens <= kBatchSize);
|
|
||||||
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
const size_t num_tokens_and_queries = num_tokens * num_queries;
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
||||||
|
|
@ -577,34 +544,33 @@ HWY_NOINLINE void Transformer(
|
||||||
}
|
}
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) {
|
||||||
EmbedToken<TConfig, kBatchSize * kQueryBatchSize>(
|
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
|
||||||
tokens[token_idx], token_idx, pos, weights, activations);
|
activations);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
||||||
TransformerLayer<TConfig, kBatchSize, kQueryBatchSize>(
|
TransformerLayer<TConfig, kQueryBatchSize>(num_tokens, num_queries, pos,
|
||||||
num_tokens, num_queries, pos, layer, layer_weights,
|
layer, layer_weights,
|
||||||
activations, kv_caches, pool);
|
activations, kv_caches, pool);
|
||||||
|
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
const std::string block_name = "blocks." + std::to_string(layer);
|
const std::string block_name = "blocks." + std::to_string(layer);
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
||||||
++token_idx) {
|
++token_idx) {
|
||||||
layers_output(pos + token_idx, block_name,
|
layers_output(pos + token_idx, block_name,
|
||||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
activations.x.Batch(token_idx), kModelDim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
RMSNormInplaceBatched(num_tokens_and_queries, weights.final_norm_scale.data(),
|
||||||
num_tokens * num_queries, weights.final_norm_scale.data(),
|
activations.x.All(), kModelDim);
|
||||||
activations.x.data(), kModelDim);
|
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
|
||||||
++token_idx) {
|
++token_idx) {
|
||||||
layers_output(pos + token_idx, "final_norm",
|
layers_output(pos + token_idx, "final_norm",
|
||||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
activations.x.Batch(token_idx), kModelDim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -644,9 +610,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
||||||
// Placeholder for internal test3, do not remove
|
// Placeholder for internal test3, do not remove
|
||||||
|
|
||||||
template <class TConfig, size_t kQueryBatchSize>
|
template <class TConfig, size_t kQueryBatchSize>
|
||||||
void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
void GenerateT(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
const ByteStorageT& decode_u8,
|
Activations& activations, const RuntimeConfig& runtime_config,
|
||||||
const RuntimeConfig& runtime_config,
|
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||||
const size_t query_index_offset,
|
const size_t query_index_offset,
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
||||||
|
|
@ -659,10 +624,6 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
pos *= num_queries; // position in (num_queries) interleaved token sequence.
|
pos *= num_queries; // position in (num_queries) interleaved token sequence.
|
||||||
const CompressedWeights<TConfig>& weights =
|
const CompressedWeights<TConfig>& weights =
|
||||||
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
||||||
auto& prefill_activations =
|
|
||||||
GetActivations<TConfig,
|
|
||||||
kAdjustedPrefillBatchSize * kQueryBatchSize>(prefill_u8);
|
|
||||||
auto& activations = GetActivations<TConfig, kQueryBatchSize>(decode_u8);
|
|
||||||
|
|
||||||
size_t min_prompt_size = (size_t)-1;
|
size_t min_prompt_size = (size_t)-1;
|
||||||
size_t max_prompt_size = 0;
|
size_t max_prompt_size = 0;
|
||||||
|
|
@ -735,8 +696,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries);
|
HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries);
|
||||||
const int* batch_tokens = prompt.data() + pos_offset;
|
const int* batch_tokens = prompt.data() + pos_offset;
|
||||||
Prefill<TConfig, kAdjustedPrefillBatchSize, kQueryBatchSize>(
|
Prefill<TConfig, kAdjustedPrefillBatchSize, kQueryBatchSize>(
|
||||||
batch_tokens, batch_size, num_queries, pos, weights,
|
batch_tokens, batch_size, num_queries, pos, weights, prefill, kv_caches,
|
||||||
prefill_activations, kv_caches, pool);
|
pool);
|
||||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||||
bool all_tokens_eos = true;
|
bool all_tokens_eos = true;
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
|
|
@ -788,7 +749,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
for (size_t generate_pos = 0;
|
for (size_t generate_pos = 0;
|
||||||
generate_pos < max_tokens && generate_pos < max_generated_tokens;
|
generate_pos < max_tokens && generate_pos < max_generated_tokens;
|
||||||
++single_prompt_pos_offset, ++generate_pos) {
|
++single_prompt_pos_offset, ++generate_pos) {
|
||||||
Transformer<TConfig, kDecodeBatchSize, kQueryBatchSize>(
|
Transformer<TConfig, kQueryBatchSize>(
|
||||||
gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights,
|
gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights,
|
||||||
activations, kv_caches, pool, runtime_config.layers_output);
|
activations, kv_caches, pool, runtime_config.layers_output);
|
||||||
float token_logit = 0.0f;
|
float token_logit = 0.0f;
|
||||||
|
|
@ -796,22 +757,20 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
// We keep it here for clarity so that the code is correct even if Prefill
|
// We keep it here for clarity so that the code is correct even if Prefill
|
||||||
// is disabled.
|
// is disabled.
|
||||||
bool all_tokens_eos = true;
|
bool all_tokens_eos = true;
|
||||||
float* x = activations.x.data();
|
for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset) {
|
||||||
float* logits = activations.logits.data();
|
const float* HWY_RESTRICT x = activations.x.Batch(i);
|
||||||
for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset,
|
float* HWY_RESTRICT logits = activations.logits.Batch(i);
|
||||||
x += TConfig::kModelDim, logits += kVocabSize) {
|
|
||||||
const size_t prompt_size = prompts[i].size();
|
const size_t prompt_size = prompts[i].size();
|
||||||
const bool is_generating_phase =
|
const bool is_generating_phase =
|
||||||
(single_prompt_pos_offset >= prompt_size - 1);
|
(single_prompt_pos_offset >= prompt_size - 1);
|
||||||
if (is_generating_phase) {
|
if (is_generating_phase) {
|
||||||
PROFILER_ZONE("Gen.Embedding");
|
PROFILER_ZONE("Gen.Embedding");
|
||||||
// Compute logits from last layer activations.
|
// Compute logits from last layer activations.
|
||||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
|
||||||
weights.embedder_input_embedding, 0, x, activations.even_odd.data(),
|
0, x, activations.even_odd.All(),
|
||||||
logits, pool);
|
logits, pool);
|
||||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
if constexpr (TConfig::kFinalCap > 0.0f) {
|
||||||
LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(),
|
LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
||||||
kVocabSize);
|
|
||||||
}
|
}
|
||||||
// Barrier: must have all logits so we can subtract max.
|
// Barrier: must have all logits so we can subtract max.
|
||||||
Softmax(logits, kVocabSize);
|
Softmax(logits, kVocabSize);
|
||||||
|
|
@ -850,9 +809,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void GenerateSingleT(const ByteStorageT& weights_u8,
|
void GenerateSingleT(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
const ByteStorageT& prefill_u8,
|
Activations& activations,
|
||||||
const ByteStorageT& decode_u8,
|
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t pos,
|
const std::vector<int>& prompt, size_t pos,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
|
|
@ -865,19 +823,17 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
|
||||||
std::vector<KVCache*> kv_caches = {&kv_cache};
|
std::vector<KVCache*> kv_caches = {&kv_cache};
|
||||||
const size_t query_index_offset = 0;
|
const size_t query_index_offset = 0;
|
||||||
GenerateT<TConfig, /*kQueryBatchSize=*/1>(
|
GenerateT<TConfig, /*kQueryBatchSize=*/1>(
|
||||||
weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos,
|
weights_u8, prefill, activations, runtime_config, prompts, pos,
|
||||||
query_index_offset, kv_caches, pool, timing_info);
|
query_index_offset, kv_caches, pool, timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void GenerateBatchT(const ByteStorageT& weights_u8,
|
void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
const ByteStorageT& prefill_u8,
|
Activations& activations,
|
||||||
const ByteStorageT& decode_u8,
|
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts,
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||||
size_t pos, const std::vector<KVCache*>& kv_caches,
|
const std::vector<KVCache*>& kv_caches,
|
||||||
hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool, TimingInfo& timing_info) {
|
||||||
TimingInfo& timing_info) {
|
|
||||||
// Disable query batching for Griffin models.
|
// Disable query batching for Griffin models.
|
||||||
constexpr size_t kQueryBatchSize =
|
constexpr size_t kQueryBatchSize =
|
||||||
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
|
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
|
||||||
|
|
@ -885,9 +841,9 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
||||||
const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize);
|
const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize);
|
||||||
const hwy::Span<const hwy::Span<int>> current_prompts(
|
const hwy::Span<const hwy::Span<int>> current_prompts(
|
||||||
prompts.data() + i, num_queries);
|
prompts.data() + i, num_queries);
|
||||||
GenerateT<TConfig, kQueryBatchSize>(weights_u8, prefill_u8, decode_u8,
|
GenerateT<TConfig, kQueryBatchSize>(weights_u8, prefill, activations,
|
||||||
runtime_config, current_prompts,
|
runtime_config, current_prompts, pos, i,
|
||||||
pos, i, kv_caches, pool, timing_info);
|
kv_caches, pool, timing_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -898,25 +854,23 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
||||||
// These are extern functions defined by instantiations/*.cc, which include this
|
// These are extern functions defined by instantiations/*.cc, which include this
|
||||||
// 'header' after defining GEMMA_CONFIG, which is for function overloading.
|
// 'header' after defining GEMMA_CONFIG, which is for function overloading.
|
||||||
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
||||||
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
|
Activations& activations, const RuntimeConfig& runtime_config,
|
||||||
const RuntimeConfig& runtime_config, const std::vector<int>& prompt,
|
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||||
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
hwy::ThreadPool& pool, TimingInfo& timing_info) {
|
||||||
TimingInfo& timing_info) {
|
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
|
||||||
(weights_u8, prefill_u8, decode_u8, runtime_config, prompt, pos, kv_cache,
|
(weights_u8, prefill, activations, runtime_config, prompt, pos, kv_cache,
|
||||||
pool, timing_info);
|
pool, timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
||||||
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
|
Activations& activations, const RuntimeConfig& runtime_config,
|
||||||
const RuntimeConfig& runtime_config,
|
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
||||||
TimingInfo& timing_info) {
|
TimingInfo& timing_info) {
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
||||||
(weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos, kv_caches,
|
(weights_u8, prefill, activations, runtime_config, prompts, pos, kv_caches,
|
||||||
pool, timing_info);
|
pool, timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,12 +36,23 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
template <typename TConfig>
|
||||||
|
struct AllocateState {
|
||||||
|
void operator()(Activations& prefill, Activations& decode) const {
|
||||||
|
// When batching queries, the prefill batch size is reduced by a factor
|
||||||
|
// of kBatchedQueryBatchSize
|
||||||
|
prefill.Allocate<TConfig>(kMinAdjustedPrefillBatchSize *
|
||||||
|
kBatchedQueryBatchSize);
|
||||||
|
decode.Allocate<TConfig>(kDecodeBatchSize * kBatchedQueryBatchSize);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
||||||
const ModelInfo& info, hwy::ThreadPool& pool)
|
const ModelInfo& info, hwy::ThreadPool& pool)
|
||||||
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
|
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
|
||||||
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
|
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
|
||||||
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
|
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
|
||||||
decode_u8_);
|
decode_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
|
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
|
||||||
|
|
@ -50,8 +61,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
|
||||||
HWY_ASSERT(info.weight == Type::kF32);
|
HWY_ASSERT(info.weight == Type::kF32);
|
||||||
weights_u8_ =
|
weights_u8_ =
|
||||||
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
|
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
|
||||||
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
|
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
|
||||||
decode_u8_);
|
decode_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::~Gemma() {
|
Gemma::~Gemma() {
|
||||||
|
|
@ -63,19 +74,17 @@ Gemma::~Gemma() {
|
||||||
// we shard them across multiple translation units in instantiations/*.cc.
|
// we shard them across multiple translation units in instantiations/*.cc.
|
||||||
// This declares the functions defined there. We use overloading because
|
// This declares the functions defined there. We use overloading because
|
||||||
// explicit instantiations are still too slow to compile.
|
// explicit instantiations are still too slow to compile.
|
||||||
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
|
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
|
||||||
extern void GenerateSingle( \
|
extern void GenerateSingle( \
|
||||||
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
|
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
|
||||||
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \
|
Activations& decode, const RuntimeConfig& runtime_config, \
|
||||||
const RuntimeConfig& runtime_config, const std::vector<int>& prompt, \
|
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, \
|
||||||
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \
|
hwy::ThreadPool& pool, TimingInfo& timing_info); \
|
||||||
TimingInfo& timing_info); \
|
extern void GenerateBatch( \
|
||||||
extern void GenerateBatch( \
|
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
|
||||||
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
|
Activations& decode, const RuntimeConfig& runtime_config, \
|
||||||
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, \
|
||||||
const RuntimeConfig& runtime_config, \
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, \
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, \
|
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, \
|
|
||||||
TimingInfo& timing_info);
|
TimingInfo& timing_info);
|
||||||
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
|
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
|
||||||
|
|
||||||
|
|
@ -83,25 +92,23 @@ GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
|
||||||
// TODO: gather all ByteStorageT into a type-erased model struct?
|
// TODO: gather all ByteStorageT into a type-erased model struct?
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
struct GenerateSingleT {
|
struct GenerateSingleT {
|
||||||
void operator()(const ByteStorageT& weights_u8,
|
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
|
Activations& decode, const RuntimeConfig& runtime_config,
|
||||||
const RuntimeConfig& runtime_config,
|
|
||||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, TimingInfo& timing_info) const {
|
hwy::ThreadPool& pool, TimingInfo& timing_info) const {
|
||||||
GenerateSingle(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config,
|
GenerateSingle(TConfig(), weights_u8, prefill, decode, runtime_config,
|
||||||
prompt, pos, kv_cache, pool, timing_info);
|
prompt, pos, kv_cache, pool, timing_info);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
struct GenerateBatchT {
|
struct GenerateBatchT {
|
||||||
void operator()(const ByteStorageT& weights_u8,
|
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8,
|
Activations& decode, const RuntimeConfig& runtime_config,
|
||||||
const RuntimeConfig& runtime_config,
|
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
||||||
TimingInfo& timing_info) const {
|
TimingInfo& timing_info) const {
|
||||||
GenerateBatch(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config,
|
GenerateBatch(TConfig(), weights_u8, prefill, decode, runtime_config,
|
||||||
prompts, pos, kv_caches, pool, timing_info);
|
prompts, pos, kv_caches, pool, timing_info);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -112,8 +119,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
|
||||||
CallForModelAndWeight<GenerateSingleT>(
|
CallForModelAndWeight<GenerateSingleT>(
|
||||||
info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_,
|
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
|
||||||
runtime_config, prompt, start_pos, kv_cache, pool_, timing_info);
|
prompt, start_pos, kv_cache, pool_, timing_info);
|
||||||
|
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
@ -126,8 +133,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
|
||||||
CallForModelAndWeight<GenerateBatchT>(
|
CallForModelAndWeight<GenerateBatchT>(
|
||||||
info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_,
|
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
|
||||||
runtime_config, prompts, start_pos, kv_caches, pool_, timing_info);
|
prompts, start_pos, kv_caches, pool_, timing_info);
|
||||||
|
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/kv_cache.h"
|
#include "gemma/kv_cache.h"
|
||||||
#include "gemma/tokenizer.h"
|
#include "gemma/tokenizer.h"
|
||||||
|
|
@ -95,8 +96,8 @@ class Gemma {
|
||||||
const ModelInfo& Info() const { return info_; }
|
const ModelInfo& Info() const { return info_; }
|
||||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||||
const ByteStorageT& Weights() const { return weights_u8_; }
|
const ByteStorageT& Weights() const { return weights_u8_; }
|
||||||
const ByteStorageT& Prefill() const { return prefill_u8_; }
|
const Activations& Prefill() const { return prefill_; }
|
||||||
const ByteStorageT& Decode() const { return decode_u8_; }
|
const Activations& Decode() const { return decode_; }
|
||||||
|
|
||||||
void Generate(const RuntimeConfig& runtime_config,
|
void Generate(const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
|
|
@ -114,8 +115,8 @@ class Gemma {
|
||||||
// Type-erased so that this can be defined in the header, without requiring
|
// Type-erased so that this can be defined in the header, without requiring
|
||||||
// forwarding functions.
|
// forwarding functions.
|
||||||
ByteStorageT weights_u8_;
|
ByteStorageT weights_u8_;
|
||||||
ByteStorageT prefill_u8_;
|
Activations prefill_;
|
||||||
ByteStorageT decode_u8_;
|
Activations decode_;
|
||||||
ModelInfo info_;
|
ModelInfo info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
13
gemma/ops.h
13
gemma/ops.h
|
|
@ -1140,29 +1140,26 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simple loops unless/until batch sizes are large enough to parallelize.
|
// Simple loops unless/until batch sizes are large enough to parallelize.
|
||||||
template <size_t kBatchSize, typename WeightT, typename OutT>
|
template <typename WeightT, typename OutT>
|
||||||
void RMSNormBatched(size_t num_tokens, const float* activations,
|
void RMSNormBatched(size_t num_tokens, const float* activations,
|
||||||
const WeightT* weights, OutT* out, const size_t model_dim) {
|
const WeightT* weights, OutT* out, const size_t model_dim) {
|
||||||
HWY_DASSERT(num_tokens <= kBatchSize);
|
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
RMSNorm(activations + token_idx * model_dim, weights,
|
RMSNorm(activations + token_idx * model_dim, weights,
|
||||||
out + token_idx * model_dim, model_dim);
|
out + token_idx * model_dim, model_dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t kBatchSize, typename WeightT, typename InOutT>
|
// TODO: pass RowVectorBatch argument.
|
||||||
|
template <typename WeightT, typename InOutT>
|
||||||
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
|
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
|
||||||
InOutT* inout, const size_t model_dim) {
|
InOutT* inout, const size_t model_dim) {
|
||||||
HWY_DASSERT(num_tokens <= kBatchSize);
|
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
|
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t kBatchSize>
|
static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
|
||||||
void AddFromBatched(size_t num_tokens, const float* other, float* x,
|
float* x, const size_t model_dim) {
|
||||||
const size_t model_dim) {
|
|
||||||
HWY_DASSERT(num_tokens <= kBatchSize);
|
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
|
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
|
||||||
model_dim);
|
model_dim);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue