mirror of https://github.com/google/gemma.cpp.git
Add MQA support
This commit is contained in:
parent
130e1f678f
commit
6923aec853
|
|
@ -54,7 +54,7 @@ struct ConfigGemma2B {
|
||||||
static constexpr int kModelDim = 2048;
|
static constexpr int kModelDim = 2048;
|
||||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||||
static constexpr int kHeads = 8;
|
static constexpr int kHeads = 8;
|
||||||
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
|
static constexpr int kKVHeads = 1;
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
52
gemma.cc
52
gemma.cc
|
|
@ -70,12 +70,13 @@ template <class TConfig>
|
||||||
struct Layer {
|
struct Layer {
|
||||||
Layer() = default;
|
Layer() = default;
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
|
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||||
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
|
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
|
||||||
// 3x for (query, key, value)
|
static constexpr size_t kQKVEinsumWSize =
|
||||||
static constexpr size_t kQKVEinsumWSize = 3 * kHeads * kQKVDim * kModelDim;
|
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
|
||||||
// 2x for (gelu gating vector, gated vector)
|
// 2x for (gelu gating vector, gated vector)
|
||||||
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
|
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
|
||||||
|
|
||||||
|
|
@ -313,26 +314,28 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
static constexpr size_t kModelDim =
|
static constexpr size_t kModelDim =
|
||||||
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
|
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
static const float kQueryScale =
|
static const float kQueryScale =
|
||||||
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||||
|
|
||||||
|
const size_t batch_offset = batch_idx * kModelDim;
|
||||||
|
|
||||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
// linear projections to QKV
|
// linear projections to QKV
|
||||||
const size_t head_offset =
|
constexpr const size_t head_offset =
|
||||||
3 * kQKVDim * kModelDim; // 3x for QKV dimensions
|
kHeads == kKVHeads ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim;
|
||||||
const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
|
const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
|
||||||
const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
|
|
||||||
const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
|
|
||||||
|
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
||||||
|
|
||||||
const size_t batch_offset = batch_idx * kModelDim;
|
|
||||||
|
|
||||||
MatVecLoop<kQKVDim, kModelDim>(
|
MatVecLoop<kQKVDim, kModelDim>(
|
||||||
c_layer->c_qkv_einsum_w, q_offset,
|
c_layer->c_qkv_einsum_w, q_offset,
|
||||||
activations.pre_att_rms_out.data() + batch_offset, q);
|
activations.pre_att_rms_out.data() + batch_offset, q);
|
||||||
|
|
||||||
|
if constexpr (kHeads == kKVHeads) {
|
||||||
|
const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
|
||||||
|
const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
|
||||||
const size_t kv_offset =
|
const size_t kv_offset =
|
||||||
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
|
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
|
||||||
|
|
||||||
|
|
@ -342,18 +345,40 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
kv_cache.key_cache.get() + kv_offset,
|
kv_cache.key_cache.get() + kv_offset,
|
||||||
kv_cache.value_cache.get() + kv_offset);
|
kv_cache.value_cache.get() + kv_offset);
|
||||||
|
|
||||||
|
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if constexpr (kHeads != kKVHeads) {
|
||||||
|
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
|
||||||
|
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
|
||||||
|
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
|
||||||
|
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
|
||||||
|
|
||||||
|
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
||||||
|
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||||
|
activations.pre_att_rms_out.data() + batch_offset,
|
||||||
|
kv_cache.key_cache.get() + kv_offset,
|
||||||
|
kv_cache.value_cache.get() + kv_offset);
|
||||||
|
|
||||||
|
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
// Calculate scores
|
// Calculate scores
|
||||||
|
float* HWY_RESTRICT q =
|
||||||
|
activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
|
||||||
float* HWY_RESTRICT head_att = activations.att.data() +
|
float* HWY_RESTRICT head_att = activations.att.data() +
|
||||||
head * TConfig::kSeqLen +
|
head * TConfig::kSeqLen +
|
||||||
batch_idx * kHeads * kQKVDim;
|
batch_idx * kHeads * kQKVDim;
|
||||||
|
|
||||||
Rope(q, kQKVDim, pos);
|
Rope(q, kQKVDim, pos);
|
||||||
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
|
||||||
MulByConst(kQueryScale, q, kQKVDim);
|
MulByConst(kQueryScale, q, kQKVDim);
|
||||||
// Compute Q dot K scores
|
// Compute Q dot K scores
|
||||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||||
const size_t cache_offset =
|
const size_t cache_offset = kHeads == kKVHeads
|
||||||
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
|
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
||||||
|
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
||||||
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
||||||
const float score = Dot(q, k2, kQKVDim);
|
const float score = Dot(q, k2, kQKVDim);
|
||||||
head_att[pos2] = score;
|
head_att[pos2] = score;
|
||||||
|
|
@ -365,8 +390,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
batch_idx * kHeads * kQKVDim;
|
batch_idx * kHeads * kQKVDim;
|
||||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||||
const size_t cache_offset =
|
const size_t cache_offset = kHeads == kKVHeads
|
||||||
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
|
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
||||||
|
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
||||||
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
||||||
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
|
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,26 +72,12 @@ parser.add_argument(
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def expand_qkv(qkv_proj: np.array) -> np.array:
|
|
||||||
"""This won't be needed anymore when MQA is implemented"""
|
|
||||||
assert qkv_proj.shape == (2560, 2048)
|
|
||||||
qkv = qkv_proj.reshape((10, 256, 2048))
|
|
||||||
|
|
||||||
q_proj = qkv[:8].reshape((1,8,256,2048))
|
|
||||||
kv_proj = qkv[8:]
|
|
||||||
kv_proj = kv_proj[:, np.newaxis, :, :]
|
|
||||||
kv_proj = np.repeat(kv_proj, 8, axis=1)
|
|
||||||
|
|
||||||
qkv = np.concatenate([q_proj, kv_proj])
|
|
||||||
qkv = np.transpose(qkv, axes=[1,0,2,3])
|
|
||||||
return qkv
|
|
||||||
|
|
||||||
TRANSFORMATIONS = {
|
TRANSFORMATIONS = {
|
||||||
"2b":defaultdict(
|
"2b":defaultdict(
|
||||||
lambda: lambda x: x,
|
lambda: lambda x: x,
|
||||||
{
|
{
|
||||||
"embedder.weight": lambda x: x,
|
"embedder.weight": lambda x: x,
|
||||||
"self_attn.qkv_proj.weight": expand_qkv,
|
"self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
|
||||||
"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
|
"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
|
||||||
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||||
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||||
|
|
@ -115,7 +101,7 @@ VALIDATIONS = {
|
||||||
"2b": {
|
"2b": {
|
||||||
"embedder.weight": lambda x: x.shape == (256000, 2048),
|
"embedder.weight": lambda x: x.shape == (256000, 2048),
|
||||||
"model.norm.weight": lambda x: x.shape == (2048,),
|
"model.norm.weight": lambda x: x.shape == (2048,),
|
||||||
"self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048),
|
"self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
|
||||||
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
|
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
|
||||||
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
||||||
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue