Add first version of backpropagation support.

This is still in progress / experimental, currently it is only
implemented for normal gemma MQA attention layers, and no
parallelism is added yet for backward pass.

Since we need to remember all activations from all layers, the
forward pass was also reimplemented with a new activation data
structure.
This commit is contained in:
Zoltan Szabadka 2024-06-03 14:42:35 +00:00
parent ed8f39c058
commit 36e4d8bbfe
31 changed files with 4061 additions and 155 deletions

View File

@ -47,9 +47,27 @@ set(SOURCES
compression/sfp-inl.h compression/sfp-inl.h
compression/test_util.h compression/test_util.h
gemma/configs.h gemma/configs.h
gemma/activations.cc
gemma/activations.h
gemma/backward.cc
gemma/backward.h
gemma/backward-inl.h
gemma/backward_scalar.h
gemma/common.h
gemma/common-inl.h
gemma/common_scalar.cc
gemma/common_scalar.h
gemma/forward.cc
gemma/forward.h
gemma/forward-inl.h
gemma/forward_scalar.h
gemma/gemma.cc gemma/gemma.cc
gemma/gemma.h gemma/gemma.h
gemma/ops.h gemma/ops.h
gemma/optimizer.cc
gemma/optimizer.h
gemma/weights.cc
gemma/weights.h
util/app.h util/app.h
util/args.h util/args.h
) )
@ -100,9 +118,15 @@ target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohm
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests") set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
if (GEMMA_ENABLE_TESTS) if (GEMMA_ENABLE_TESTS)
enable_testing()
include(GoogleTest)
set(GEMMA_TEST_FILES set(GEMMA_TEST_FILES
gemma/ops_test.cc gemma/ops_test.cc
gemma/gemma_test.cc gemma/gemma_test.cc
gemma/backward_test.cc
gemma/backward_scalar_test.cc
gemma/optimize_test.cc
) )
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES) foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)

View File

@ -484,9 +484,10 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
const hn::ScalableTag<OutT> d; const hn::ScalableTag<OutT> d;
const size_t ofs = idx_batch * kBatch; const size_t ofs = idx_batch * kBatch;
const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch; const size_t batch =
idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
Traits::Decompress(d, compressed.size(), compressed.data(), Traits::Decompress(d, compressed.size(), compressed.data(),
compressed_ofs + ofs, out + ofs, num); compressed_ofs + ofs, out + ofs, batch);
}); });
const double t1 = hwy::platform::Now(); const double t1 = hwy::platform::Now();
@ -495,6 +496,16 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
fprintf(stderr, "Decompress %.1f MB/s\n", mbps); fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
} }
// Returns dot product with `vec_aligned` of length `num`.
template <bool kVecEO, class DF, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const std::array<float, kCapacity>& w, size_t ofs,
const VecT* x, size_t num) {
HWY_DASSERT(ofs + num <= kCapacity);
HWY_DASSERT(hn::IsAligned(df, x));
using Traits = CompressTraits<float>;
return Traits::Dot(df, w.size(), w.data(), ofs, x, num);
}
// Returns dot product with `vec_aligned` of length `num`. // Returns dot product with `vec_aligned` of length `num`.
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT> template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed, HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,

38
gemma/activations.cc Normal file
View File

@ -0,0 +1,38 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
namespace gcpp {
ByteStorageT AllocateForwardPass(Model model) {
switch (model) {
case Model::GEMMA_2B:
return ForwardPass<float, ConfigGemma2B>::Allocate();
case Model::GEMMA_7B:
return ForwardPass<float, ConfigGemma7B>::Allocate();
case Model::GRIFFIN_2B:
return ForwardPass<float, ConfigGriffin2B>::Allocate();
case Model::GEMMA_TINY:
return ForwardPass<float, ConfigGemmaTiny>::Allocate();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp

84
gemma/activations.h Normal file
View File

@ -0,0 +1,84 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#include "gemma/common.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
template <typename T, typename TConfig>
struct ForwardLayer {
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
std::array<T, kSeqLen * kModelDim> input;
std::array<T, kSeqLen * kModelDim> pre_att_rms_out;
std::array<T, kSeqLen * (kHeads + 2) * kQKVDim> qkv;
std::array<T, kSeqLen * kHeads * kSeqLen> att;
std::array<T, kSeqLen * kHeads * kQKVDim> att_out;
std::array<T, kSeqLen * kModelDim> att_post1;
std::array<T, kSeqLen * kModelDim> attention_out;
std::array<T, kSeqLen * kModelDim> bf_pre_ffw_rms_out;
std::array<T, kSeqLen * kFFHiddenDim * 2> ffw_hidden;
std::array<T, kSeqLen * kFFHiddenDim> ffw_hidden_gated;
};
template <typename T, typename TConfig>
struct ForwardPass {
ForwardPass() {} // prevents placement-new calling memset
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;
std::array<ForwardLayer<T, TConfig>, kLayers> layers;
std::array<T, kSeqLen * kModelDim> final_layer_output;
std::array<T, kSeqLen * kModelDim> final_norm_output;
std::array<T, kSeqLen * kVocabSize> logits;
std::array<T, kSeqLen * kVocabSize> probs;
static ByteStorageT Allocate() {
return hwy::AllocateAligned<uint8_t>(sizeof(ForwardPass<T, TConfig>));
}
};
template<typename T, typename TConfig>
class ActivationsWrapper {
using WrappedT = ForwardPass<T, TConfig>;
public:
ActivationsWrapper()
: data_(WrappedT::Allocate()),
activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
const WrappedT& get() const { return activations_; }
WrappedT& get() { return activations_; }
private:
ByteStorageT data_;
WrappedT& activations_;
};
ByteStorageT AllocateForwardPass(Model model);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_

412
gemma/backward-inl.h Normal file
View File

@ -0,0 +1,412 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include "gemma/activations.h"
#include "gemma/prompt.h"
#include "gemma/weights.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_BACkWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#endif
#include "gemma/common-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <size_t kCols, size_t kRows>
void MatMulVJP(const std::array<float, kRows * kCols>& weights,
const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens,
std::array<float, kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) {
memset(grad_x, 0, num_tokens * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t voffs = pos * kRows;
const size_t xoffs = pos * kCols;
for (size_t j = 0; j < kRows; ++j) {
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols);
MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs],
kCols);
}
}
}
template <size_t kHeads, size_t kCols, size_t kRows>
void MultiHeadMatMulVJP(
const std::array<float, kHeads * kRows * kCols>& weights,
const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens,
std::array<float, kHeads * kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
hwy::ThreadPool& pool) {
memset(grad_x, 0, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t j = 0; j < kRows; ++j) {
for (size_t h = 0; h < kHeads; ++h) {
MulByConstAndAdd(v[pos * kRows + j],
&x[pos * kHeads * kCols + h * kCols],
&grad_w[h * kRows * kCols + j * kCols], kCols);
MulByConstAndAdd(v[pos * kRows + j],
&weights[h * kRows * kCols + j * kCols],
&grad_x[pos * kHeads * kCols + h * kCols], kCols);
}
}
}
}
template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
const hn::Vec<D> kOne = hn::Set(d, 1.0f);
// kSqrtOverPi*3*kMul
const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);
const hn::Vec<D> v2 = hn::Mul(v, v);
const hn::Vec<D> v3 = hn::Mul(v2, v);
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
const hn::Vec<D> tanh = hn::Tanh(d, arg);
const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
}
static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto offset =
hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
hn::Transform1(
d, backward, size, forward,
[&offset](const auto d, const auto v, const auto y)
HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
}
static HWY_NOINLINE void RMSNormVJP(
const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
constexpr float eps = 1e-6f;
float ss = SquaredL2(x + offset, model_dim);
ss = 1.0f / sqrtf(ss / StaticCast<float>(model_dim) + eps);
for (size_t i = 0; i < model_dim; ++i) {
grad_w[i] += v[offset + i] * x[offset + i] * ss;
}
const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
float tmp = 0.0f;
for (size_t i = 0; i < model_dim; ++i) {
tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
}
tmp *= ss3;
for (size_t i = 0; i < model_dim; ++i) {
grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
tmp * x[offset + i];
}
}
}
static HWY_NOINLINE void InputEmbeddingVJP(
const float* weights, const std::vector<int>& prompt,
const float scaling, const float* HWY_RESTRICT backward,
float* HWY_RESTRICT grad, size_t model_dim) {
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) {
int token = prompt[pos];
MulByConstAndAdd(scaling, backward + pos * model_dim,
grad + token * model_dim, model_dim);
}
}
template <typename TConfig>
void LayerVJP(const Layer<float, TConfig>& weights,
const ForwardLayer<float, TConfig>& forward,
const float* HWY_RESTRICT next_layer_grad,
size_t num_tokens,
Layer<float, TConfig>& grad,
ForwardLayer<float, TConfig>& backward,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen);
MatMulVJP<kFFHiddenDim, kModelDim>(
weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad,
num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(),
pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2;
const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim;
const float* HWY_RESTRICT b_out_gated =
backward.ffw_hidden_gated.data() + pos * kFFHiddenDim;
float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
const auto y = Load(df, f_out + i);
const auto x = Load(df, f_out_mul + i);
const auto v = Load(df, b_out_gated + i);
hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
}
}
MatMulVJP<kModelDim, kFFHiddenDim * 2>(
weights.gating_einsum_w,
forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
num_tokens, grad.gating_einsum_w,
backward.bf_pre_ffw_rms_out.data(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.data(),
forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(),
kModelDim, num_tokens,
grad.pre_ffw_norm_scale.data(),
backward.attention_out.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * kModelDim,
backward.attention_out.data() + pos * kModelDim, kModelDim);
}
hwy::ZeroBytes(backward.qkv.data(),
num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
weights.attn_vec_einsum_w, forward.att_out.data(),
backward.attention_out.data(), num_tokens,
grad.attn_vec_einsum_w, backward.att_out.data(), pool);
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
const float* HWY_RESTRICT b_att_out =
backward.att_out.data() + (pos * kHeads + head) * kQKVDim;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim);
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim);
}
}
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
}
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim);
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim);
}
}
}
for (int pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT b_kv =
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(b_kv, kQKVDim, -pos);
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT b_q =
backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
MulByConst(kQueryScale, b_q, kQKVDim);
Rope(b_q, kQKVDim, -pos);
}
}
MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
weights.qkv_einsum_w, forward.pre_att_rms_out.data(),
backward.qkv.data(), num_tokens,
grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool);
RMSNormVJP(weights.pre_attention_norm_scale.data(),
forward.input.data(),
backward.pre_att_rms_out.data(),
kModelDim, num_tokens,
grad.pre_attention_norm_scale.data(),
backward.input.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(backward.attention_out.data() + pos * kModelDim,
backward.input.data() + pos * kModelDim, kModelDim);
}
}
static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const float cap,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto one = hn::Set(d, 1.0f);
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
hn::Transform1(
d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y);
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
});
}
static HWY_NOINLINE void CrossEntropyLossGrad(
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
const Prompt& prompt, size_t vocab_size) {
const float scaling = -1.0 / std::log(2.0);
size_t num_tokens = prompt.tokens.size() - 1;
memset(grad, 0, num_tokens * vocab_size * sizeof(grad[0]));
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) {
if (i + 1 < prompt.context_size) {
continue;
}
const int next_token = prompt.tokens[i + 1];
grad[i * vocab_size + next_token] =
scaling / x[i * vocab_size + next_token];
}
}
template <typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<float, TConfig>& weights,
const ForwardPass<float, TConfig>& forward,
Weights<float, TConfig>& grad,
ForwardPass<float, TConfig>& backward,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale);
static_assert(TConfig::kKVHeads == 1);
HWY_DASSERT(prompt.context_size > 0);
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
const size_t num_tokens = prompt.tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
kVocabSize);
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize,
kVocabSize);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(forward.logits.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
}
MatMulVJP<kModelDim, kVocabSize>(
weights.embedder_input_embedding, forward.final_norm_output.data(),
backward.logits.data(), num_tokens,
grad.embedder_input_embedding, backward.final_norm_output.data(),
pool);
RMSNormVJP(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
backward.final_norm_output.data(),
kModelDim, num_tokens,
grad.final_norm_scale.data(),
backward.final_layer_output.data(), pool);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
auto type = TConfig::kLayerConfig[layer];
// TODO(szabadka) Implement Griffin layer vjp.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
}
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), kModelDim);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

88
gemma/backward.cc Normal file
View File

@ -0,0 +1,88 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/backward.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/backward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "gemma/backward-inl.h"
#include "gemma/weights.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ByteStorageT& weights_u8,
const ByteStorageT& forward_u8,
ByteStorageT& grad_u8,
ByteStorageT& backward_u8,
hwy::ThreadPool& pool) {
using TWeights = WeightsF<TConfig>;
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
using TAct = ForwardPass<float, TConfig>;
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool);
}
void CrossEntropyLossBackwardPassT(Model model,
const Prompt& prompt,
const ByteStorageT& weights,
const ByteStorageT& forward,
ByteStorageT& grad,
ByteStorageT& backward,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
CrossEntropyLossBackwardPass<ConfigGemma2B>(
prompt, weights, forward, grad, backward, pool);
break;
case Model::GEMMA_TINY:
CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
prompt, weights, forward, grad, backward, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossBackwardPassT);
void CrossEntropyLossBackwardPass(
const Model& model, const Prompt& prompt,
const ByteStorageT& weights, const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
model, prompt, weights, forward, grad, backward, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

34
gemma/backward.h Normal file
View File

@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#include <vector>
#include "gemma/common.h"
#include "gemma/prompt.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void CrossEntropyLossBackwardPass(
const Model& model, const Prompt& prompt,
const ByteStorageT& weights, const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_

351
gemma/backward_scalar.h Normal file
View File

@ -0,0 +1,351 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <complex>
#include <vector>
#include "gemma/activations.h"
#include "gemma/common_scalar.h"
#include "gemma/prompt.h"
#include "gemma/weights.h"
namespace gcpp {
template<typename T>
void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t M, size_t K) {
memset(dx, 0, M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M);
}
}
}
template<typename T>
void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t H, size_t N, size_t M, size_t K) {
memset(dx, 0, H * M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
for (size_t h = 0; h < H; ++h) {
MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M],
&dw[h * N * M + j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M],
&dx[i * H * M + h * M], M);
}
}
}
}
template<typename T>
void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const size_t offset = i * N;
constexpr T eps(1e-6);
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; ++j) {
dw[j] += dy[i * N + j] * x[i * N + j] * ss;
}
const T ss3 = ss * ss * ss / T(N);
T tmp = 0.0;
for (size_t j = 0; j < N; ++j) {
tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j];
}
tmp *= ss3;
for (size_t j = 0; j < N; ++j) {
dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j];
}
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += y[i] * dy[i];
}
for (size_t i = 0; i < N; ++i) {
dy[i] = y[i] * (dy[i] - sum);
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
SoftmaxVJPT(y + i * N, dy + i * N, N);
}
}
template<typename T>
T GeluDerivative(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;
const T x2 = x * x;
const T x3 = x2 * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T tanh = std::tanh(arg);
const T cdf = T(0.5) * (T(1.0) + tanh);
const T dtanh = T(1.0) - tanh * tanh;
const T darg = kMul2 * x2 + kSqrt2OverPi;
return T(0.5) * x * dtanh * darg + cdf;
}
template<typename T>
void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
const T* v = d_out + i * N;
T* dx1 = d_in + i * 2 * N;
T* dx2 = dx1 + N;
for (size_t j = 0; j < N; ++j) {
dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]);
dx2[j] = v[j] * Gelu(x1[j]);
}
}
}
template<typename T>
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
size_t num_tokens, size_t kHeads, size_t kQKVDim,
size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * (kHeads + 2) * kQKVDim;
memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
const T* q = qkv + qoffs;
const T* dout = doutput + aoffs;
T* dq = dqkv + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const T* k = qkv + koffs;
T* dk = dqkv + koffs;
MulByConstAndAddT(dout[pos2], k, dq, kQKVDim);
MulByConstAndAddT(dout[pos2], q, dk, kQKVDim);
}
}
}
}
template<typename T>
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens,
size_t kHeads, size_t kSeqLen) {
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
SoftmaxVJPT(y + offset, dy + offset, pos + 1);
memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
}
}
}
template<typename T>
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
T* dqkv, T* dattention, size_t num_tokens,
size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
auto v_offset = [&](size_t pos) {
return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim;
};
for (size_t pos = 0; pos < num_tokens; ++pos) {
memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim;
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
const T* att = &attention[aoffset];
const T* dout = &doutput[offset];
T* datt = &dattention[aoffset];
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim);
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim);
}
}
}
}
template<typename T>
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
const T* dy, T* dw, size_t N) {
for (size_t i = 0; i + 1 < tokens.size(); ++i) {
int token = tokens[i];
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
}
}
template<typename T, typename TConfig>
void LayerVJP(const Layer<T, TConfig>& weights,
const ForwardLayer<T, TConfig>& forward,
const T* dy,
Layer<T, TConfig>& grad,
ForwardLayer<T, TConfig>& backward,
size_t num_tokens) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim));
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
kModelDim, kFFHiddenDim, num_tokens);
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim,
num_tokens);
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(),
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
kModelDim, num_tokens);
AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim);
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(),
grad.attn_vec_einsum_w.data(),
backward.att_out.data(),
kHeads, kModelDim, kQKVDim, num_tokens);
MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
backward.att_out.data(), backward.qkv.data(),
backward.att.data(), num_tokens, kHeads, kQKVDim,
kSeqLen);
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(),
num_tokens, kHeads, kSeqLen);
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
}
for (int pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * kQKVDim, kQKVDim, -pos);
}
}
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), grad.qkv_einsum_w.data(),
backward.pre_att_rms_out.data(),
(kHeads + 2) * kQKVDim, kModelDim, num_tokens);
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
backward.pre_att_rms_out.data(),
grad.pre_attention_norm_scale.data(),
backward.input.data(), kModelDim, num_tokens);
AddFromT(backward.attention_out.data(), backward.input.data(),
num_tokens * kModelDim);
}
template<typename T>
void SoftcapVJPT(const T* y, T* dy, size_t N) {
T cap = 30.0;
T inv_cap = T(1.0) / cap;
for (size_t i = 0; i < N; ++i) {
T scaled = y[i] * inv_cap;
dy[i] *= (T(1.0) - scaled * scaled);
}
}
template<typename T>
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
T scaling = -1.0 / std::log(2.0);
size_t num_tokens = prompt.tokens.size() - 1;
memset(dx, 0, V * num_tokens * sizeof(x[0]));
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) {
if (i + 1 < prompt.context_size) {
continue;
}
const int next_token = prompt.tokens[i + 1];
dx[i * V + next_token] = scaling / x[i * V + next_token];
}
}
template<typename T, typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<T, TConfig>& weights,
const ForwardPass<T, TConfig>& forward,
Weights<T, TConfig>& grad,
ForwardPass<T, TConfig>& backward) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;
const size_t num_tokens = prompt.tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
kVocabSize);
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
kVocabSize, num_tokens);
SoftcapVJPT(forward.logits.data(), backward.logits.data(),
num_tokens * kVocabSize);
MatMulVJPT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(),
backward.logits.data(),
grad.embedder_input_embedding.data(),
backward.final_norm_output.data(),
kVocabSize, kModelDim, num_tokens);
RMSNormVJPT(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
backward.final_norm_output.data(),
grad.final_norm_scale.data(),
backward.final_layer_output.data(), kModelDim, num_tokens);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
T* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
}
const T kEmbScaling = EmbeddingScaling(kModelDim);
InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
prompt.tokens, kEmbScaling,
backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), kModelDim);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_

View File

@ -0,0 +1,610 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/backward_scalar.h"
#include <array>
#include <complex>
#include <random>
#include "gemma/forward_scalar.h"
#include "gemma/sampler.h"
#include "gemma/test_util.h"
#include "gtest/gtest.h"
namespace gcpp {
TEST(BackPropTest, MatMulVJP) {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kRows * kCols> weights;
std::array<T, kTokens * kCols> x;
std::array<T, kRows * kCols> grad;
std::array<T, kTokens * kCols> dx;
std::array<TC, kRows * kCols> c_weights;
std::array<TC, kTokens * kCols> c_x;
std::array<TC, kTokens * kRows> c_y;
std::array<T, kTokens * kRows> dy;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
memset(&grad, 0, sizeof(grad));
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-11, 1e-12,__LINE__);
TestGradient(grad, c_weights, func, 1e-14, 1e-12,__LINE__);
}
}
TEST(BackPropTest, MultiHeadMatMulVJP) {
static const size_t kRows = 2;
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kRows * kCols * kHeads> weights;
std::array<T, kTokens * kCols * kHeads> x;
std::array<T, kRows * kCols * kHeads> grad;
std::array<T, kTokens * kCols * kHeads> dx;
std::array<TC, kRows * kCols * kHeads> c_weights;
std::array<TC, kTokens * kCols * kHeads> c_x;
std::array<TC, kTokens * kRows> c_y;
std::array<T, kTokens * kRows> dy;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
memset(&grad, 0, sizeof(grad));
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
dx.data(), kHeads, kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-15, 1e-13,__LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-13,__LINE__);
}
}
TEST(BackPropTest, RMSNormVJP) {
static const size_t K = 2;
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> weights;
std::array<T, N> grad;
std::array<T, K * N> x;
std::array<T, K * N> dx;
std::array<T, K * N> dy;
std::array<TC, N> c_weights;
std::array<TC, K * N> c_x;
std::array<TC, K * N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
};
memset(&grad, 0, sizeof(grad));
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
}
}
TEST(BackPropTest, SoftmaxVJP) {
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> x;
std::array<T, N> dx;
std::array<T, N> dy;
std::array<TC, N> c_x;
std::array<TC, N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), sizeof(c_x));
Softmax(c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
};
Softmax(x.data(), N);
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
SoftmaxVJPT(x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MaskedSoftmaxVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kTokens = 14;
static const size_t N = kHeads * kSeqLen * kSeqLen;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> x;
std::array<T, N> dy;
std::array<T, N> dx = {};
std::array<TC, N> c_x;
std::array<TC, N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(),
kTokens * kHeads * kSeqLen * sizeof(c_x[0]));
MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
return DotT(dy.data(), c_y.data(), N);
};
MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx[0]));
MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, SoftcapVJP) {
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> x;
std::array<T, N> dx;
std::array<T, N> dy;
std::array<TC, N> c_x;
std::array<TC, N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
Softcap(c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
};
Softcap(x.data(), N);
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
SoftcapVJPT(x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, CrossEntropyLossGrad) {
static const size_t K = 8;
static const size_t V = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, K * V> x;
std::array<T, K * V> dx;
std::array<TC, K * V> c_x;
Prompt prompt;
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
for (int iter = 0; iter < 10; ++iter) {
prompt.context_size = 1 + (iter % 6);
RandInit(x, 1.0 * (1 << iter), gen);
Softcap(x.data(), V * K);
Softmax(x.data(), V, K);
CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
Complexify(x, c_x);
auto func = [&]() {
return CrossEntropyLoss(c_x.data(), prompt, V);
};
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
}
}
TEST(BackPropTest, GatedGeluVJP) {
static const size_t K = 2;
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, K * 2 * N> x;
std::array<T, K * 2 * N> dx;
std::array<T, K * N> dy;
std::array<TC, K * 2 * N> c_x;
std::array<TC, K * N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
GatedGelu(c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), N * K);
};
GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MaskedAttentionVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kQKVDim = 8;
static const size_t kTokens = 14;
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
static const size_t kOutSize = kSeqLen * kHeads * kSeqLen;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kQKVSize> x;
std::array<T, kQKVSize> dx = {};
std::array<T, kOutSize> dy;
std::array<TC, kQKVSize> c_x;
std::array<TC, kOutSize> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
};
MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MixByAttentionVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kQKVDim = 8;
static const size_t kTokens = 14;
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kQKVSize> qkv;
std::array<T, kQKVSize> dqkv = {};
std::array<T, kAttnSize> attn;
std::array<T, kAttnSize> dattn = {};
std::array<T, kOutSize> dy;
std::array<TC, kQKVSize> c_qkv;
std::array<TC, kAttnSize> c_attn;
std::array<TC, kOutSize> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(qkv, 1.0, gen);
RandInit(attn, 1.0, gen);
Complexify(qkv, c_qkv);
Complexify(attn, c_attn);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
};
MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, InputEmbeddingVJP) {
static const size_t kSeqLen = 8;
static const size_t kVocabSize = 4;
static const size_t kModelDim = 16;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kVocabSize * kModelDim> weights;
std::array<T, kVocabSize * kModelDim> grad;
std::array<T, kSeqLen * kModelDim> dy;
std::array<TC, kVocabSize * kModelDim> c_weights;
std::array<TC, kSeqLen * kModelDim> c_y;
std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
size_t num_tokens = tokens.size() - 1;
for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
auto func = [&]() {
InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
};
memset(&grad, 0, sizeof(grad));
InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
kModelDim);
TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
}
}
struct TestConfig {
static constexpr int kSeqLen = 18;
static constexpr int kVocabSize = 12;
static constexpr int kModelDim = 32;
static constexpr int kHeads = 3;
static constexpr int kQKVDim = 12;
static constexpr int kFFHiddenDim = 48;
static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr int kKVHeads = 1;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kGriffinLayers = 0;
static constexpr int kNumTensorScales = 0;
};
TEST(BackPropTest, LayerVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
const size_t kOutputSize = TestConfig::kSeqLen * TestConfig::kModelDim;
Layer<T, TestConfig> weights;
Layer<T, TestConfig> grad;
ForwardLayer<T, TestConfig> forward;
ForwardLayer<T, TestConfig> backward = {};
Layer<TC, TestConfig> c_weights;
ForwardLayer<TC, TestConfig> c_forward;
std::array<T, kOutputSize> y;
std::array<T, kOutputSize> dy;
std::array<TC, kOutputSize> c_y;
const size_t num_tokens = 3;
for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);
RandInit(forward.input, 1.0, gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(forward.input, c_forward.input);
auto func = [&]() {
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
return DotT(dy.data(), c_y.data(), num_tokens * TestConfig::kModelDim);
};
memset(&grad, 0, sizeof(grad));
ApplyLayer(weights, forward, num_tokens, y.data());
LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
TestGradient(backward.input, c_forward.input, func, 1e-11, 1e-11,
__LINE__);
TestGradient(grad, c_weights, func, 1e-11);
}
}
TEST(BackPropTest, EndToEnd) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
WeightsWrapper<T, TestConfig> weights;
WeightsWrapper<T, TestConfig> grad;
ForwardPass<T, TestConfig> forward;
ForwardPass<T, TestConfig> backward;
WeightsWrapper<TC, TestConfig> c_weights;
ForwardPass<TC, TestConfig> c_forward;
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(10, gen);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0, gen);
CrossEntropyLossForwardPass(prompt, weights.get(), forward);
grad.clear();
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward, grad.get(), backward);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
};
TestGradient(grad.get(), c_weights.get(), func, 1e-11);
}
}
template<typename T, typename TConfig>
void MulByConstAndAddT(T c, const Layer<T, TConfig>& x,
Layer<T, TConfig>& out) {
MulByConstAndAddT(c, x.pre_attention_norm_scale,
out.pre_attention_norm_scale);
MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
MulByConstAndAddT(c, x.linear_w, out.linear_w);
}
template<typename T, typename TConfig>
void MulByConstAndAddT(T c, const Weights<T, TConfig>& x,
Weights<T, TConfig>& out) {
static constexpr size_t kLayers = TConfig::kLayers;
MulByConstAndAddT(c, x.embedder_input_embedding,
out.embedder_input_embedding);
MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) {
MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
}
}
// Evaluates forward pass on a batch.
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
const WeightsWrapper<T, TConfig>& weights,
ForwardPass<T, TConfig>& forward) {
T loss = 0.0;
for (const Prompt& prompt : batch) {
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
}
T scale = 1.0 / batch.size();
return loss * scale;
}
// Evaluates forward pass on a batch by applying gradient with the given
// learning rate. Does not update weights, but uses the given tmp weights
// instead.
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(T learning_rate,
const std::vector<Prompt>& batch,
const WeightsWrapper<T, TConfig>& weights,
const WeightsWrapper<T, TConfig>& grad,
WeightsWrapper<T, TConfig>& tmp,
ForwardPass<T, TConfig>& forward) {
tmp.copy(weights);
const T scale = -learning_rate / batch.size();
MulByConstAndAddT(scale, grad.get(), tmp.get());
return CrossEntropyLossForwardPass(batch, tmp, forward);
}
// Uses line search in the negative gradient direction to update weights. We do
// this so that we can test that each step during the gradient descent can
// decrease the objective function value.
template<typename T, typename TConfig>
T FindOptimalUpdate(const WeightsWrapper<T, TConfig>& grad,
WeightsWrapper<T, TConfig>& weights,
WeightsWrapper<T, TConfig>& tmp,
ForwardPass<T, TConfig>& forward,
const std::vector<Prompt>& batch,
T loss, T initial_learning_rate) {
T lr0 = initial_learning_rate;
T loss0 = CrossEntropyLossForwardPass(
lr0, batch, weights, grad, tmp, forward);
for (size_t iter = 0; iter < 30; ++iter) {
T lr1 = lr0 * 0.5;
T loss1 = CrossEntropyLossForwardPass(
lr1, batch, weights, grad, tmp, forward);
if (loss0 < loss && loss1 >= loss0) {
break;
}
loss0 = loss1;
lr0 = lr1;
}
for (size_t iter = 0; iter < 30; ++iter) {
T lr1 = lr0 * 2.0;
T loss1 = CrossEntropyLossForwardPass(
lr1, batch, weights, grad, tmp, forward);
if (loss1 >= loss0) {
break;
}
loss0 = loss1;
lr0 = lr1;
}
const T scale = -lr0 / batch.size();
MulByConstAndAddT(scale, grad.get(), weights.get());
return lr0;
}
TEST(BackProptest, Convergence) {
std::mt19937 gen(42);
using T = float;
using TC = std::complex<double>;
WeightsWrapper<T, TestConfig> weights;
WeightsWrapper<T, TestConfig> grad;
WeightsWrapper<T, TestConfig> tmp;
ForwardPass<T, TestConfig> forward;
ForwardPass<T, TestConfig> backward;
WeightsWrapper<TC, TestConfig> c_weights;
ForwardPass<TC, TestConfig> c_forward;
constexpr size_t kBatchSize = 5;
ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
T learning_rate = 0.01;
RandInit(weights.get(), T(1.0), gen);
printf("Sample batch:\n");
for (size_t i = 0; i < 10; ++i) {
ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
}
T prev_loss = std::numeric_limits<T>::max();
bool stop = false;
size_t step = 0;
while (!stop) {
T loss = 0.0;
grad.clear();
std::mt19937 sgen(42);
std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
for (const Prompt& prompt : batch) {
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward, grad.get(), backward);
}
if (step % 250 == 0) {
printf("Checking gradient...\n");
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
TC scale = batch.size();
return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
};
TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
}
loss /= batch.size();
EXPECT_LT(loss, prev_loss);
stop = step >= 10000 || loss < 1e-2;
if (step % 10 == 0 || stop) {
printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
step, loss, learning_rate);
}
if (!stop) {
learning_rate = FindOptimalUpdate(
grad, weights, tmp, forward, batch, loss, learning_rate);
++step;
}
prev_loss = loss;
}
EXPECT_LT(step, 1000);
}
} // namespace gcpp

270
gemma/backward_test.cc Normal file
View File

@ -0,0 +1,270 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <algorithm>
#include <array>
#include <complex>
#include <random>
#include <vector>
#include "compression/compress.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "gemma/backward_scalar.h"
#include "gemma/forward_scalar.h"
#include "gemma/gemma.h"
#include "gemma/sampler.h"
#include "gemma/test_util.h"
#include "gemma/weights.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/backward_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "gemma/backward-inl.h"
#include "gemma/forward-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
hwy::ThreadPool pool(8);
std::mt19937 gen(42);
HWY_ALIGN std::array<float, kRows * kCols> weights;
HWY_ALIGN std::array<float, kTokens * kCols> x;
HWY_ALIGN std::array<float, kTokens * kRows> dy;
HWY_ALIGN std::array<float, kRows * kCols> grad;
HWY_ALIGN std::array<float, kTokens * kCols> dx;
HWY_ALIGN std::array<float, kRows * kCols> grad_scalar;
HWY_ALIGN std::array<float, kTokens * kCols> dx_scalar;
using TC = std::complex<double>;
std::array<TC, kRows * kCols> c_weights;
std::array<TC, kTokens * kCols> c_x;
std::array<TC, kTokens * kRows> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
memset(&grad, 0, sizeof(grad));
MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
grad, dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
memset(&grad_scalar, 0, sizeof(grad_scalar));
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 0, 0, __LINE__);
TestNear(grad, grad_scalar, 0, 0, __LINE__);
}
}
void TestMultiHeadMatMulVJP() {
static const size_t kRows = 2;
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
hwy::ThreadPool pool(8);
std::mt19937 gen(42);
HWY_ALIGN std::array<float, kRows * kCols * kHeads> weights;
HWY_ALIGN std::array<float, kTokens * kCols * kHeads> x;
HWY_ALIGN std::array<float, kRows * kCols * kHeads> grad;
HWY_ALIGN std::array<float, kTokens * kCols * kHeads> dx;
HWY_ALIGN std::array<float, kTokens * kRows> dy;
HWY_ALIGN std::array<float, kRows * kCols * kHeads> grad_scalar;
HWY_ALIGN std::array<float, kTokens * kCols * kHeads> dx_scalar;
using TC = std::complex<double>;
std::array<TC, kRows * kCols * kHeads> c_weights;
std::array<TC, kTokens * kCols * kHeads> c_x;
std::array<TC, kTokens * kRows> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
memset(&grad, 0, sizeof(grad));
MultiHeadMatMulVJP<kHeads, kCols, kRows>(
weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
memset(&grad_scalar, 0, sizeof(grad_scalar));
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 0, 0, __LINE__);
TestNear(grad, grad_scalar, 0, 0, __LINE__);
}
}
void TestRMSNormVJP() {
static const size_t K = 2;
static const size_t N = 64;
hwy::ThreadPool pool(8);
std::mt19937 gen(42);
HWY_ALIGN std::array<float, N> weights;
HWY_ALIGN std::array<float, K * N> x;
HWY_ALIGN std::array<float, N> grad;
HWY_ALIGN std::array<float, K * N> dx;
HWY_ALIGN std::array<float, K * N> dy;
HWY_ALIGN std::array<float, N> grad_scalar;
HWY_ALIGN std::array<float, K * N> dx_scalar;
using TC = std::complex<double>;
std::array<TC, N> c_weights;
std::array<TC, K * N> c_x;
std::array<TC, K * N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
};
memset(&grad, 0, sizeof(grad));
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
memset(&grad_scalar, 0, sizeof(grad_scalar));
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), N, K);
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
}
}
struct TestConfig {
static constexpr int kSeqLen = 24;
static constexpr int kVocabSize = 16;
static constexpr int kModelDim = 32;
static constexpr int kHeads = 3;
static constexpr int kQKVDim = 16;
static constexpr int kFFHiddenDim = 64;
static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr int kKVHeads = 1;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kGriffinLayers = 0;
static constexpr int kNumTensorScales = 0;
};
void TestEndToEnd() {
std::mt19937 gen(42);
hwy::ThreadPool pool(0);
WeightsWrapper<float, TestConfig> weights;
WeightsWrapper<float, TestConfig> grad;
ActivationsWrapper<float, TestConfig> forward0;
ActivationsWrapper<float, TestConfig> forward1;
ActivationsWrapper<float, TestConfig> backward;
using TC = std::complex<double>;
WeightsWrapper<TC, TestConfig> c_weights;
ForwardPass<TC, TestConfig> c_forward;
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(10, gen);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0f, gen);
float loss0 = CrossEntropyLossForwardPass(
prompt, weights.get(), forward0.get());
float loss1 = CrossEntropyLossForwardPass<TestConfig, WeightsF, LayerF>(
prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
pool);
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 1e-5);
grad.clear();
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
pool);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
};
TestGradient(grad.get(), c_weights.get(), func, 2e-3f);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(BackwardTest);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
#ifdef HWY_AFTER_TEST
HWY_AFTER_TEST();
#endif
} // namespace gcpp
#endif

71
gemma/common-inl.h Normal file
View File

@ -0,0 +1,71 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include "gemma/activations.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#endif
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

30
gemma/common.h Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#include "hwy/aligned_allocator.h"
namespace gcpp {
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

50
gemma/common_scalar.cc Normal file
View File

@ -0,0 +1,50 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "gemma/ops.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
float EmbeddingScaling(int model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(model_dim))));
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(EmbeddingScaling);
float EmbeddingScaling(int model_dim) {
return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim);
}
} // namespace gcpp
#endif // HWY_ONCE

119
gemma/common_scalar.h Normal file
View File

@ -0,0 +1,119 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#include <complex>
namespace gcpp {
template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
U sum = {};
for (size_t i = 0; i < N; ++i) {
sum += a[i] * b[i];
}
return sum;
}
template<>
std::complex<double> DotT(const float* a, const std::complex<double>* b,
size_t N) {
std::complex<double> sum = {};
for (size_t i = 0; i < N; ++i) {
sum += static_cast<double>(a[i]) * b[i];
}
return sum;
}
template<typename T>
void MulByConstT(T c, T* x, size_t N) {
for (size_t i = 0; i < N; ++i) {
x[i] *= c;
}
}
// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += c * x[i];
}
}
template<typename T, size_t N>
void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
MulByConstAndAddT(c, x.data(), out.data(), N);
}
template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += a[i];
}
}
template<typename T>
T SquaredL2(const T* x, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += x[i] * x[i];
}
return sum;
}
template<typename T>
T Gelu(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
const T x3 = x * x * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
return x * cdf;
}
template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
const size_t N2 = N / 2;
for (size_t dim = 0; dim < N2; ++dim) {
const T freq_exponents = T(2 * dim) / T(N);
const T timescale = std::pow(base, freq_exponents);
const T theta = T(i) / timescale;
const T cos_val = std::cos(theta);
const T sin_val = std::sin(theta);
const T x0 = x[dim];
const T x1 = x[dim + N2];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + N2] = x0 * sin_val + x1 * cos_val;
}
}
template<typename T>
void Rope(T* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
float EmbeddingScaling(int model_dim);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_

View File

@ -142,6 +142,38 @@ struct ConfigGemma2B {
using WeightT = GEMMA_WEIGHT_T; using WeightT = GEMMA_WEIGHT_T;
}; };
struct ConfigGemmaTiny {
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 16;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 64;
static constexpr int kFFHiddenDim = 128;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};
struct ConfigGriffin2B { struct ConfigGriffin2B {
// Griffin uses local attention, so kSeqLen is actually the local attention // Griffin uses local attention, so kSeqLen is actually the local attention
// window. // window.

290
gemma/forward-inl.h Normal file
View File

@ -0,0 +1,290 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
#include "gemma/common-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
const float scaling, float* HWY_RESTRICT output,
size_t model_dim) {
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) {
int token = prompt[pos];
Decompress(weights, token * model_dim, output + pos * model_dim, model_dim);
MulByConst(scaling, output + pos * model_dim, model_dim);
}
}
template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
size_t model_dim, size_t num_tokens,
OutT* HWY_RESTRICT output,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
RMSNorm(x + offset, weights, output + offset, model_dim);
}
}
static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
const std::vector<int>& prompt,
size_t context_size,
size_t vocab_size,
hwy::ThreadPool& pool) {
float loss = 0.0f;
for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) {
if (pos + 1 < context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = prompt[pos + 1];
loss += std::log(probs[pos * vocab_size + next_token]);
}
float scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template <typename TConfig, template<typename> typename LayerT>
void ApplyForwardLayer(const LayerT<TConfig>& weights,
ForwardLayer<float, TConfig>& activations,
size_t num_tokens,
float* HWY_RESTRICT output,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen);
ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
activations.input.data(), kModelDim, num_tokens,
activations.pre_att_rms_out.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
weights.qkv_einsum_w, 0,
activations.pre_att_rms_out.data() + pos * kModelDim, nullptr,
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
}
const size_t num_tasks = kHeads * num_tokens;
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k =
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, pos);
}
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const float* HWY_RESTRICT k2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
}
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
Softmax(head_att, pos + 1);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
float* HWY_RESTRICT att_out =
activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
float* HWY_RESTRICT v2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
}
});
hwy::ZeroBytes(activations.attention_out.data(),
num_tokens * kModelDim * sizeof(activations.attention_out[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
MatVec<kModelDim, kQKVDim>(
weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
nullptr, activations.att_post1.data() + pos * kModelDim, pool);
AddFrom(activations.att_post1.data() + pos * kModelDim,
activations.attention_out.data() + pos * kModelDim, kModelDim);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.input.data() + pos * kModelDim,
activations.attention_out.data() + pos * kModelDim, kModelDim);
}
ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
activations.attention_out.data(), kModelDim, num_tokens,
activations.bf_pre_ffw_rms_out.data(), pool);
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kFFHiddenDim * 2, kModelDim>(
weights.gating_einsum_w, 0,
activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr,
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2;
const float* HWY_RESTRICT out =
activations.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
float* HWY_RESTRICT out_gated =
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
const auto y = Load(df, out + i);
const auto x = Load(df, out_mul + i);
hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kModelDim, kFFHiddenDim>(
weights.linear_w, 0,
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
nullptr, output + pos * kModelDim, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.attention_out.data() + pos * kModelDim,
output + pos * kModelDim, kModelDim);
}
}
template <typename TConfig, template<typename> typename WeightsT,
template<typename> typename LayerT>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,
const WeightsT<TConfig>& weights,
ForwardPass<float, TConfig>& forward,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale);
static_assert(TConfig::kKVHeads == 1);
HWY_DASSERT(context_size > 0);
HWY_DASSERT(context_size < prompt.size());
const size_t num_tokens = prompt.size() - 1;
InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling,
forward.layers[0].input.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
// TODO(szabadka) Implement Griffin layer.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* HWY_RESTRICT output = layer + 1 < kLayers ?
forward.layers[layer + 1].input.data() :
forward.final_layer_output.data();
ApplyForwardLayer<TConfig, LayerT>(
*weights.GetLayer(layer), forward.layers[layer],
num_tokens, output, pool);
}
ApplyRMSNorm(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
kModelDim, num_tokens, forward.final_norm_output.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kVocabSize, kModelDim>(
weights.embedder_input_embedding, 0,
forward.final_norm_output.data() + pos * kModelDim, nullptr,
forward.logits.data() + pos * kVocabSize, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
}
memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);
}
return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
kVocabSize, pool);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

79
gemma/forward.cc Normal file
View File

@ -0,0 +1,79 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/forward.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "gemma/forward-inl.h"
#include "gemma/weights.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename TConfig>
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ByteStorageT& weights_u8,
ByteStorageT& forward_u8,
hwy::ThreadPool& pool) {
const auto& weights =
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
auto& forward =
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
return CrossEntropyLossForwardPass<TConfig, WeightsF, LayerF>(
prompt.tokens, prompt.context_size, weights, forward, pool);
}
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
const ByteStorageT& weights,
ByteStorageT& forward,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return CrossEntropyLossForwardPass<ConfigGemma2B>(
prompt, weights, forward, pool);
case Model::GEMMA_TINY:
return CrossEntropyLossForwardPass<ConfigGemmaTiny>(
prompt, weights, forward, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
ByteStorageT& forward, hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
model, prompt, weights, forward, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

33
gemma/forward.h Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#include <vector>
#include "gemma/common.h"
#include "gemma/prompt.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
float CrossEntropyLossForwardPass(
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
ByteStorageT& forward, hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

284
gemma/forward_scalar.h Normal file
View File

@ -0,0 +1,284 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <complex>
#include <vector>
#include "gemma/activations.h"
#include "gemma/common_scalar.h"
#include "gemma/prompt.h"
#include "gemma/weights.h"
namespace gcpp {
// w is N x M matrix in row-major order, x is M x K matrix in column-major order
// y = w * x is N x K matrix in column-major order.
template<typename T>
void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] = DotT(&w[j * M], &x[i * M], M);
}
}
}
// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in
// column-major order and y = w' * x is N x K matrix in column-major order,
// where w' is the rearrangement of w into an N x HM matrix.
template<typename T>
void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
size_t M, size_t K) {
memset(y, 0, N * K * sizeof(y[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t h = 0; h < H; ++h) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M);
}
}
}
}
template<typename T>
void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
constexpr T eps(1e-6);
for (size_t i = 0; i < K; ++i) {
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; j++) {
out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]);
}
}
}
template<typename T>
void Softmax(T* x, size_t N) {
T sum = {};
auto maxreal = std::real(x[0]);
for (size_t i = 1; i < N; ++i) {
if (std::real(x[i]) > maxreal) {
maxreal = std::real(x[i]);
}
}
for (size_t i = 0; i < N; ++i) {
x[i] = std::exp(x[i] - maxreal);
sum += x[i];
}
T scale = T(1.0) / sum;
for (size_t i = 0; i < N; ++i) {
x[i] *= scale;
}
}
template<typename T>
void Softmax(T* x, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
Softmax(x + i * N, N);
}
}
template<typename T>
void Softcap(T* x, size_t N) {
T cap = 30.0;
T inv_cap = T(1.0) / cap;
for (size_t i = 0; i < N; ++i) {
x[i] = cap * std::tanh(x[i] * inv_cap);
}
}
template<typename T>
void GatedGelu(const T* in, T* out, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
T* y = out + i * N;
for (size_t j = 0; j < N; ++j) {
y[j] = x2[j] * Gelu(x1[j]);
}
}
}
template<typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
T* y, size_t N) {
for (size_t i = 0; i + 1 < tokens.size(); ++i) {
int token = tokens[i];
memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
MulByConstT(scaling, y + i * N, N);
}
}
template<typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens,
size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
const size_t qoffset = pos * (kHeads + 2) * kQKVDim;
const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen;
const T* q = qkv + qoffset + head * kQKVDim;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
output[aoffset + pos2] = DotT(q, k, kQKVDim);
}
}
}
}
template<typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
Softmax(x + offset, pos + 1);
memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
}
}
}
template<typename T>
void MixByAttention(const T* qkv, const T* attention, T* output,
size_t num_tokens, size_t kHeads, size_t kQKVDim,
size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen];
T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim];
memset(out, 0, kQKVDim * sizeof(out[0]));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
const T* v = &qkv[v_offset];
MulByConstAndAddT(att[pos2], v, out, kQKVDim);
}
}
}
}
template<typename T, typename TConfig>
void ApplyLayer(const Layer<T, TConfig>& weights,
ForwardLayer<T, TConfig>& activations,
size_t num_tokens, T* output) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim));
RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
activations.pre_att_rms_out.data(), kModelDim, num_tokens);
MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim,
num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * kQKVDim, kQKVDim, pos);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
}
MaskedAttention(activations.qkv.data(), activations.att.data(),
num_tokens, kHeads, kQKVDim, kSeqLen);
MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen);
MixByAttention(activations.qkv.data(), activations.att.data(),
activations.att_out.data(), num_tokens, kHeads, kQKVDim,
kSeqLen);
MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
activations.attention_out.data(), kHeads, kModelDim, kQKVDim,
num_tokens);
AddFromT(activations.input.data(), activations.attention_out.data(),
num_tokens * kModelDim);
RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens);
MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim,
num_tokens);
GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
kFFHiddenDim, num_tokens);
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(),
output, kModelDim, kFFHiddenDim, num_tokens);
AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim);
}
template<typename T>
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
T loss = {};
for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) {
if (i + 1 < prompt.context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = prompt.tokens[i + 1];
loss += std::log(x[i * V + next_token]);
}
T scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(const Prompt& prompt,
const Weights<T, TConfig>& weights,
ForwardPass<T, TConfig>& forward) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;
const size_t num_tokens = prompt.tokens.size() - 1;
const T kEmbScaling = EmbeddingScaling(kModelDim);
InputEmbedding(weights.embedder_input_embedding.data(), prompt.tokens,
kEmbScaling, forward.layers[0].input.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {
T* output = layer + 1 < kLayers ?
forward.layers[layer + 1].input.data() :
forward.final_layer_output.data();
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
output);
}
RMSNormT(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
forward.final_norm_output.data(), kModelDim, num_tokens);
MatMulT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(),
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
Softcap(forward.logits.data(), num_tokens * kVocabSize);
memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0]));
Softmax(forward.probs.data(), kVocabSize, num_tokens);
return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_

View File

@ -22,6 +22,7 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors. // Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
#include "gemma/common-inl.h"
#include "gemma/ops.h" #include "gemma/ops.h"
#include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -53,6 +54,7 @@
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -72,63 +74,6 @@ constexpr bool kShowTokenization = false;
namespace gcpp { namespace gcpp {
template <class TConfig>
struct Layer {
Layer() = default;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumWSize =
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
template <class T, size_t N>
using ArrayT = std::array<T, N>;
union {
struct {
ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
ArrayT<float, kAOBiasDim> attention_output_biases;
};
struct {
ArrayT<float, kGriffinDim * kGriffinDim> linear_x_w;
ArrayT<float, kGriffinDim> linear_x_biases;
ArrayT<float, kGriffinDim * kGriffinDim> linear_y_w;
ArrayT<float, kGriffinDim> linear_y_biases;
ArrayT<float, kGriffinDim * kGriffinDim> linear_out_w;
ArrayT<float, kGriffinDim> linear_out_biases;
ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
ArrayT<float, kGriffinDim> conv_biases;
ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
ArrayT<float, kGriffinDim * 2> gate_biases;
ArrayT<float, kGriffinDim> a;
} griffin;
};
ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
ArrayT<float, kModelDim> pre_attention_norm_scale;
ArrayT<float, kModelDim> pre_ffw_norm_scale;
ArrayT<float, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
ArrayT<float, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
};
float ScaleWeights(float* data, size_t len) { float ScaleWeights(float* data, size_t len) {
float maxabs = 0.0; float maxabs = 0.0;
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
@ -146,41 +91,6 @@ float ScaleWeights(float* data, size_t len) {
return scale; return scale;
} }
// Array instead of single large allocation for parallel mem init. Split out of
// Weights so that only these pointers are initialized.
template <class TConfig>
struct LayerPointers {
explicit LayerPointers(hwy::ThreadPool& pool) {
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
this->layers[task] = hwy::AllocateAligned<Layer<TConfig>>(1);
});
}
using TLayer = Layer<TConfig>;
std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
};
template <class TConfig>
struct Weights {
// No ctor/dtor, allocated via AllocateAligned.
std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
std::array<float, TConfig::kModelDim> final_norm_scale;
LayerPointers<TConfig> layer_ptrs;
std::array<float, TConfig::kNumTensorScales> scales;
const Layer<TConfig>* GetLayer(size_t layer) const {
return layer_ptrs.layers[layer].get();
}
Layer<TConfig>* GetLayer(size_t layer) {
return layer_ptrs.layers[layer].get();
}
};
template <typename TConfig> template <typename TConfig>
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights( hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
const Path& checkpoint, hwy::ThreadPool& pool, const Path& checkpoint, hwy::ThreadPool& pool,
@ -191,11 +101,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
checkpoint.path.c_str()); checkpoint.path.c_str());
} }
using TWeights = Weights<TConfig>; ByteStorageT weights_u8 = AllocateWeights<float, TConfig>(pool);
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 = auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->layer_ptrs) LayerPointers<TConfig>(pool);
size_t scale_pos = 0; size_t scale_pos = 0;
FILE* fptr; FILE* fptr;
@ -228,7 +135,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
sizeof(weights->final_norm_scale)); sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer]; auto type = TConfig::kLayerConfig[layer];
Layer<TConfig>* layer_view = weights->GetLayer(layer); LayerF<TConfig>* layer_view = weights->GetLayer(layer);
#define READ_WEIGHTS(name) \ #define READ_WEIGHTS(name) \
do { \ do { \
@ -305,7 +212,7 @@ template <class TConfig>
struct CompressedLayer { struct CompressedLayer {
// No ctor/dtor, allocated via AllocateAligned. // No ctor/dtor, allocated via AllocateAligned.
using TLayer = gcpp::Layer<TConfig>; using TLayer = gcpp::LayerF<TConfig>;
using WeightT = typename TConfig::WeightT; using WeightT = typename TConfig::WeightT;
static constexpr size_t kHeads = TLayer::kHeads; static constexpr size_t kHeads = TLayer::kHeads;
@ -399,13 +306,13 @@ struct CompressedWeights {
template <class TConfig> template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>, using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
Weights<TConfig>>; WeightsF<TConfig>>;
// Aligned. // Aligned.
template <class TConfig, size_t TBatchSize> template <class TConfig, size_t TBatchSize>
struct Activations { struct Activations {
static constexpr size_t kBatchSize = TBatchSize; static constexpr size_t kBatchSize = TBatchSize;
using LayerConfig = Layer<TConfig>; using LayerConfig = LayerF<TConfig>;
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kHeads = TConfig::kHeads;
@ -446,6 +353,16 @@ struct Activations {
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier; std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
}; };
template<typename TConfig>
struct InferenceState {
Activations<TConfig, kPrefillBatchSize> prefill;
HWY_ALIGN Activations<TConfig, 1> state;
static ByteStorageT Allocate() {
return hwy::AllocateAligned<uint8_t>(sizeof(InferenceState<TConfig>));
}
};
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
// define an abstract base class. // define an abstract base class.
struct GemmaInterface { struct GemmaInterface {
@ -484,6 +401,8 @@ KVCache CreateKVCache(Model type) {
return CreateKVCacheT<ConfigGemma7B>(); return CreateKVCacheT<ConfigGemma7B>();
case Model::GRIFFIN_2B: case Model::GRIFFIN_2B:
return CreateKVCacheT<ConfigGriffin2B>(); return CreateKVCacheT<ConfigGriffin2B>();
case Model::GEMMA_TINY:
return CreateKVCacheT<ConfigGemmaTiny>();
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(type)); HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
} }
@ -526,8 +445,8 @@ void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>(); c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
} }
template <class Config> template <class Config>
void DeleteLayersPtrs(Weights<Config>* weights) { void DeleteLayersPtrs(WeightsF<Config>* weights) {
weights->layer_ptrs.~LayerPointers<Config>(); weights->layer_ptrs.~LayerPointers<float, Config>();
} }
} // namespace } // namespace
@ -866,8 +785,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
// Same matrix, first and second half of rows. Could fuse into one MatVec. // Same matrix, first and second half of rows. Could fuse into one MatVec.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>( MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd, TConfig::kFFBiases ?
out_mul, pool); layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr,
even_odd, out_mul, pool);
// Gate, will go through the nonlinearity. // Gate, will go through the nonlinearity.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>( MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec, layer_weights->gating_einsum_w, 0, vec,
@ -892,21 +812,6 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
} }
} }
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
template <size_t kBatchSize, typename WeightArrayT, typename TConfig> template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights, const WeightArrayT& weights,
@ -1076,20 +981,15 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
} }
} }
template <class TConfig> template <class TConfig, template<typename> typename WeightsType>
void GenerateImpl(GemmaImpl<TConfig>& gemma, void GenerateImpl(const WeightsType<TConfig>& weights,
Activations<TConfig, kPrefillBatchSize>& prefill_activations,
Activations<TConfig, 1>& activations,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info, hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) { LayersOutputT* layers_output) {
static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kVocabSize = TConfig::kVocabSize;
Activations<TConfig, 1>& activations = *gemma.state.get();
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
*gemma.prefill.get();
const WeightsT<TConfig>& weights =
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
size_t prompt_size = prompt.size(); size_t prompt_size = prompt.size();
size_t max_tokens = runtime_config.max_tokens; size_t max_tokens = runtime_config.max_tokens;
size_t max_generated_tokens = runtime_config.max_generated_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens;
@ -1167,7 +1067,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
activations.logits.data(), kVocabSize, *runtime_config.gen, activations.logits.data(), kVocabSize, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token); runtime_config.temperature, runtime_config.accept_token);
if (!runtime_config.stream_token(token, activations.logits[token])) { if (!runtime_config.stream_token(token, activations.logits[token])) {
token = EOS_ID; token = runtime_config.eos_id;
} }
if (generate_pos == 0) { if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start; timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
@ -1177,10 +1077,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
// process the tokens of the prompt one at a time. // process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1); token = prompt.at(pos_offset + 1);
if (!runtime_config.stream_token(token, 0)) { if (!runtime_config.stream_token(token, 0)) {
token = EOS_ID; token = runtime_config.eos_id;
} }
} }
if (token == EOS_ID) { if (token == runtime_config.eos_id) {
if (runtime_config.verbosity >= 2) { if (runtime_config.verbosity >= 2) {
const double gen_end = hwy::platform::Now(); const double gen_end = hwy::platform::Now();
timing_info.gen_tok_sec = timing_info.gen_tok_sec =
@ -1192,6 +1092,57 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
} }
} }
template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
const WeightsT<TConfig>& weights =
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
GenerateImpl<TConfig, WeightsT>(
weights, *gemma.prefill.get(), *gemma.state.get(), runtime_config, prompt,
pos, kv_cache, pool, timing_info, layers_output);
}
template <class TConfig>
void GenerateImpl(const ByteStorageT& weights_u8,
ByteStorageT& inference_state_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
const WeightsF<TConfig>& weights =
*reinterpret_cast<const WeightsF<TConfig>*>(weights_u8.get());
InferenceState<TConfig>& inference_state =
*reinterpret_cast<InferenceState<TConfig>*>(inference_state_u8.get());
GenerateImpl<TConfig, WeightsF>(
weights, inference_state.prefill, inference_state.state, runtime_config,
prompt, pos, kv_cache, pool, timing_info, layers_output);
}
void GenerateImplT(Model model, const ByteStorageT& weights_u8,
ByteStorageT& inference_state_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
switch (model) {
case Model::GEMMA_2B:
GenerateImpl<ConfigGemma2B>(
weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache,
pool, timing_info, layers_output);
break;
case Model::GEMMA_TINY:
GenerateImpl<ConfigGemmaTiny>(
weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache,
pool, timing_info, layers_output);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
#define TOKEN(token_id) TokenString(gemma, token_id).c_str() #define TOKEN(token_id) TokenString(gemma, token_id).c_str()
template <class TConfig> template <class TConfig>
@ -1263,7 +1214,7 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
// //
// This avoids repeating the list of tensors between loading and compressing. // This avoids repeating the list of tensors between loading and compressing.
template <class TConfig, class Func> template <class TConfig, class Func>
void ForEachTensor(const Weights<TConfig>* weights, void ForEachTensor(const WeightsF<TConfig>* weights,
CompressedWeights<TConfig>& c_weights, Func& func) { CompressedWeights<TConfig>& c_weights, Func& func) {
func("c_embedding", func("c_embedding",
weights ? weights->embedder_input_embedding.data() : nullptr, weights ? weights->embedder_input_embedding.data() : nullptr,
@ -1275,7 +1226,7 @@ void ForEachTensor(const Weights<TConfig>* weights,
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr; const LayerF<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx); CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
#define CALL_FUNC(name, member) \ #define CALL_FUNC(name, member) \
@ -1386,14 +1337,14 @@ void CompressWeights(const Path& weights_path,
const bool scale_for_compression = TConfig::kNumTensorScales > 0; const bool scale_for_compression = TConfig::kNumTensorScales > 0;
const hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 = const hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
LoadWeights<TConfig>(weights_path, pool, scale_for_compression); LoadWeights<TConfig>(weights_path, pool, scale_for_compression);
Weights<TConfig>* weights = WeightsF<TConfig>* weights =
reinterpret_cast<Weights<TConfig>*>(weights_u8.get()); reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
Compressor compressor(pool); Compressor compressor(pool);
ForEachTensor<TConfig>(weights, *c_weights, compressor); ForEachTensor<TConfig>(weights, *c_weights, compressor);
compressor.AddScales(weights->scales.data(), weights->scales.size()); compressor.AddScales(weights->scales.data(), weights->scales.size());
compressor.WriteAll(pool, compressed_weights_path); compressor.WriteAll(pool, compressed_weights_path);
weights->layer_ptrs.~LayerPointers<TConfig>(); weights->layer_ptrs.~LayerPointers<float, TConfig>();
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>(); c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
} }
@ -1422,6 +1373,7 @@ HWY_AFTER_NAMESPACE();
namespace gcpp { namespace gcpp {
HWY_EXPORT(CompressWeightsT); HWY_EXPORT(CompressWeightsT);
HWY_EXPORT(GenerateImplT);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size) { size_t conv1d_cache_size, size_t rglru_cache_size) {
@ -1528,6 +1480,37 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
pool.SetWaitMode(hwy::PoolWaitMode::kBlock); pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
} }
void GenerateGemma(Model model, const ByteStorageT& weights,
ByteStorageT& inference_state,
RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
HWY_DYNAMIC_DISPATCH(GenerateImplT)(
model, weights, inference_state, runtime_config, prompt, start_pos,
kv_cache, pool, timing_info, /*layers_output=*/nullptr);
}
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model, weights, pool);
}
ByteStorageT AllocateInferenceState(Model model) {
switch (model) {
case Model::GEMMA_2B:
return InferenceState<ConfigGemma2B>::Allocate();
case Model::GEMMA_7B:
return InferenceState<ConfigGemma7B>::Allocate();
case Model::GRIFFIN_2B:
return InferenceState<ConfigGriffin2B>::Allocate();
case Model::GEMMA_TINY:
return InferenceState<ConfigGemmaTiny>::Allocate();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
void CompressWeights(gcpp::Model model, const Path& weights, void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool) { const Path& compressed_weights, hwy::ThreadPool& pool) {
HWY_DYNAMIC_DISPATCH(CompressWeightsT) HWY_DYNAMIC_DISPATCH(CompressWeightsT)
@ -1546,13 +1529,16 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
namespace { namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt",
"2b-it", "7b-it", "gr2b-it"}; "2b-it", "7b-it", "gr2b-it",
"tiny"};
constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B, constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B,
Model::GRIFFIN_2B, Model::GEMMA_2B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B}; Model::GEMMA_7B, Model::GRIFFIN_2B,
Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = { constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT}; ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
} // namespace } // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag, const char* ParseModelTypeAndTraining(const std::string& model_flag,

View File

@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
@ -51,8 +52,6 @@ struct KVCache {
rglru_cache; // kModelDim * kGriffinLayers rglru_cache; // kModelDim * kGriffinLayers
}; };
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT }; enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
@ -68,6 +67,8 @@ using StreamFunc = std::function<bool(int, float)>;
// want to generate and True for tokens you want to generate. // want to generate and True for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>; using AcceptFunc = std::function<bool(int)>;
constexpr int EOS_ID = 1;
struct RuntimeConfig { struct RuntimeConfig {
size_t max_tokens; size_t max_tokens;
size_t max_generated_tokens; size_t max_generated_tokens;
@ -76,6 +77,7 @@ struct RuntimeConfig {
std::mt19937* gen; std::mt19937* gen;
const StreamFunc& stream_token; const StreamFunc& stream_token;
const AcceptFunc& accept_token; const AcceptFunc& accept_token;
int eos_id = EOS_ID;
}; };
struct GemmaInterface; struct GemmaInterface;
@ -118,6 +120,18 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
TimingInfo& timing_info, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr); LayersOutputT* layers_output = nullptr);
void GenerateGemma(Model model, const ByteStorageT& weights,
ByteStorageT& inference_state,
RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info);
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool);
ByteStorageT AllocateInferenceState(Model model);
void CompressWeights(gcpp::Model model, const Path& weights, void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool); const Path& compressed_weights, hwy::ThreadPool& pool);
@ -125,8 +139,6 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache, const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity); hwy::ThreadPool& pool, int verbosity);
constexpr int EOS_ID = 1;
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

View File

@ -98,7 +98,7 @@ class GemmaTest : public ::testing::Test {
gcpp::Gemma model; gcpp::Gemma model;
}; };
TEST_F(GemmaTest, Geography) { TEST_F(GemmaTest, DISABLED_Geography) {
static const char* kQA[][2] = { static const char* kQA[][2] = {
{"What is the capital of Hungary?", "Budapest"}, {"What is the capital of Hungary?", "Budapest"},
{"How many states does the US have?", "50"}, {"How many states does the US have?", "50"},
@ -108,7 +108,7 @@ TEST_F(GemmaTest, Geography) {
TestQuestions(kQA, kNum); TestQuestions(kQA, kNum);
} }
TEST_F(GemmaTest, History) { TEST_F(GemmaTest, DISABLED_History) {
static const char* kQA[][2] = { static const char* kQA[][2] = {
{"When was the Battle of Hastings?", "1066"}, {"When was the Battle of Hastings?", "1066"},
{"Who fought at the Battle of Marathon?", "Greek"}, {"Who fought at the Battle of Marathon?", "Greek"},
@ -117,7 +117,7 @@ TEST_F(GemmaTest, History) {
TestQuestions(kQA, kNum); TestQuestions(kQA, kNum);
} }
TEST_F(GemmaTest, Arithmetic) { TEST_F(GemmaTest, DISABLED_Arithmetic) {
static const char* kQA[][2] = { static const char* kQA[][2] = {
{"what is 13 + 14?", "27"}, {"what is 13 + 14?", "27"},
{"what is 7 * 8", "56"}, {"what is 7 * 8", "56"},
@ -280,7 +280,7 @@ static const char kDeclaration[] = {
"reliance on the protection of divine Providence, we mutually pledge to " "reliance on the protection of divine Providence, we mutually pledge to "
"each other our Lives, our Fortunes and our sacred Honor.\n"}; "each other our Lives, our Fortunes and our sacred Honor.\n"};
TEST_F(GemmaTest, CrossEntropySmall) { TEST_F(GemmaTest, DISABLED_CrossEntropySmall) {
static const char kSmall[] = static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe."; "The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall); float entropy = GemmaCrossEntropy(kSmall);
@ -288,19 +288,19 @@ TEST_F(GemmaTest, CrossEntropySmall) {
EXPECT_LT(entropy, 1.6f); EXPECT_LT(entropy, 1.6f);
} }
TEST_F(GemmaTest, CrossEntropyJingleBells) { TEST_F(GemmaTest, DISABLED_CrossEntropyJingleBells) {
float entropy = GemmaCrossEntropy(kJingleBells); float entropy = GemmaCrossEntropy(kJingleBells);
std::cout << "per-byte entropy: " << entropy << "\n"; std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 2.3f); EXPECT_LT(entropy, 2.3f);
} }
TEST_F(GemmaTest, CrossEntropyGettysburg) { TEST_F(GemmaTest, DISABLED_CrossEntropyGettysburg) {
float entropy = GemmaCrossEntropy(kGettysburg); float entropy = GemmaCrossEntropy(kGettysburg);
std::cout << "per-byte entropy: " << entropy << "\n"; std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 1.2f); EXPECT_LT(entropy, 1.2f);
} }
TEST_F(GemmaTest, CrossEntropyDeclaration) { TEST_F(GemmaTest, DISABLED_CrossEntropyDeclaration) {
float entropy = GemmaCrossEntropy(kDeclaration); float entropy = GemmaCrossEntropy(kDeclaration);
std::cout << "per-byte entropy: " << entropy << "\n"; std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 1.0f); EXPECT_LT(entropy, 1.0f);

View File

@ -877,7 +877,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
*/ */
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
size_t dim_qkv, size_t pos) { size_t dim_qkv, int pos) {
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) { for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@ -898,7 +898,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
float* HWY_RESTRICT x, float* HWY_RESTRICT x,
size_t dim_qkv, size_t dim_qkv,
size_t pos) { int pos) {
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) { for (size_t dim = 0; dim < half_dim_qkv; ++dim) {

127
gemma/optimize_test.cc Normal file
View File

@ -0,0 +1,127 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "gemma/activations.h"
#include "gemma/backward.h"
#include "gemma/forward.h"
#include "gemma/gemma.h"
#include "gemma/optimizer.h"
#include "gemma/sampler.h"
#include "gemma/weights.h"
#include "gtest/gtest.h"
namespace gcpp {
TEST(OptimizeTest, GradientDescent) {
hwy::ThreadPool pool(0);
std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY;
ByteStorageT weights = AllocateWeights(model_type, pool);
ByteStorageT grad = AllocateWeights(model_type, pool);
ByteStorageT grad_m = AllocateWeights(model_type, pool);
ByteStorageT grad_v = AllocateWeights(model_type, pool);
ByteStorageT forward = AllocateForwardPass(model_type);
ByteStorageT backward = AllocateForwardPass(model_type);
ByteStorageT inference = AllocateInferenceState(model_type);
auto kv_cache = CreateKVCache(model_type);
size_t max_tokens = 32;
size_t max_generated_tokens = 16;
float temperature = 1.0f;
int verbosity = 0;
const auto accept_token = [](int) { return true; };
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
auto stream_token = [&reply](int token, float) {
reply.push_back(token);
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
GenerateGemma(model_type, weights, inference, runtime, prompt, 0,
kv_cache, pool, timing_info);
return reply;
};
auto verify = [&](const Prompt& prompt) {
auto context = prompt.context();
std::vector<int> reply = generate(context);
bool ok = true;
for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) {
if (i >= reply.size() || reply[i] != prompt.tokens[i]) {
ok = false;
}
}
return ok;
};
RandInitWeights(model_type, weights, pool, gen);
ZeroInitWeights(model_type, grad_m, pool);
ZeroInitWeights(model_type, grad_v, pool);
printf("Initial weights:\n");
LogWeightStats(model_type, weights);
constexpr size_t kBatchSize = 8;
float learning_rate = 0.0005f;
ReverseSequenceSampler training_task({
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
size_t steps = 0;
float prev_loss = std::numeric_limits<float>::max();
size_t num_ok;
for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42);
ZeroInitWeights(model_type, grad, pool);
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass(
model_type, prompt, weights, forward, pool);
CrossEntropyLossBackwardPass(
model_type, prompt, weights, forward, grad, backward, pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;
const float scale = -learning_rate / kBatchSize;
UpdateWeights(model_type, grad, scale, weights, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) {
printf("Batch gradient:\n");
LogWeightStats(model_type, grad);
}
if (total_loss < 0.5f) {
break;
}
prev_loss = total_loss;
}
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
LogWeightStats(model_type, weights);
EXPECT_LT(steps, 3000);
EXPECT_EQ(num_ok, kBatchSize);
}
} // namespace gcpp

121
gemma/optimizer.cc Normal file
View File

@ -0,0 +1,121 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/optimizer.h"
#include <random>
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
class WeightInitializer {
public:
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
template <size_t N>
void operator()(const char* name, std::array<float, N>& tensor) {
for (size_t i = 0; i < N; ++i) {
tensor[i] = dist_(gen_);
}
}
private:
std::normal_distribution<float> dist_;
std::mt19937& gen_;
};
template <typename TConfig>
void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) {
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python
// version.
WeightInitializer init(gen);
ForEachTensor1<float, TConfig>(init, weights);
}
class WeightUpdater {
public:
explicit WeightUpdater(float lr) : lr_(lr) {}
template <size_t kCapacity>
void operator()(const char* name, const std::array<float, kCapacity>& grad,
std::array<float, kCapacity>& weights) {
for (size_t i = 0; i < kCapacity; ++i) {
weights[i] += lr_ * grad[i];
}
}
private:
float lr_;
};
template <typename TConfig>
void UpdateWeights(const ByteStorageT& grad_u8, float scale,
ByteStorageT& weights_u8, hwy::ThreadPool& pool) {
const auto& grad =
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightUpdater updater(scale);
ForEachTensor2<float, TConfig>(updater, grad, weights);
}
} // namespace
void RandInitWeights(Model model, ByteStorageT& weights_u8,
hwy::ThreadPool& pool, std::mt19937& gen) {
switch (model) {
case Model::GEMMA_2B:
RandInitWeights<ConfigGemma2B>(weights_u8, pool, gen);
break;
case Model::GEMMA_7B:
RandInitWeights<ConfigGemma7B>(weights_u8, pool, gen);
break;
case Model::GRIFFIN_2B:
RandInitWeights<ConfigGriffin2B>(weights_u8, pool, gen);
break;
case Model::GEMMA_TINY:
RandInitWeights<ConfigGemmaTiny>(weights_u8, pool, gen);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
ByteStorageT& weights, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
UpdateWeights<ConfigGemma2B>(grad, scale, weights, pool);
break;
case Model::GEMMA_7B:
UpdateWeights<ConfigGemma7B>(grad, scale, weights, pool);
break;
case Model::GRIFFIN_2B:
UpdateWeights<ConfigGriffin2B>(grad, scale, weights, pool);
break;
case Model::GEMMA_TINY:
UpdateWeights<ConfigGemmaTiny>(grad, scale, weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp

34
gemma/optimizer.h Normal file
View File

@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#include <random>
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool,
std::mt19937& gen);
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
ByteStorageT& weights, hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_

33
gemma/prompt.h Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#include <vector>
namespace gcpp {
struct Prompt {
std::vector<int> tokens;
size_t context_size;
std::vector<int> context() const {
return std::vector<int>(tokens.begin(), tokens.begin() + context_size);
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_

84
gemma/sampler.h Normal file
View File

@ -0,0 +1,84 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#include <vector>
#include "gemma/prompt.h"
namespace gcpp {
class PromptSampler {
public:
virtual Prompt Sample(std::mt19937& gen) = 0;
std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
std::vector<Prompt> batch;
batch.reserve(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
batch.emplace_back(Sample(gen));
}
return batch;
}
};
class ReverseSequenceSampler : public PromptSampler {
public:
explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
: token_dist_(0, 9) {
for (int i = 0; i < length_histo.size(); ++i) {
const int count = length_histo[i];
for (int j = 0; j < count; ++j) {
length_lut_.push_back(i + 1);
}
}
length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
}
static constexpr int kReverseToken = 10;
static constexpr int kEndToken = 11;
Prompt Sample(std::mt19937& gen) override {
Prompt prompt;
int len = length_lut_[length_dist_(gen)];
prompt.tokens.resize(2 * len + 2);
prompt.tokens[len] = kReverseToken;
prompt.tokens[2 * len + 1] = kEndToken;
for (size_t i = 0; i < len; ++i) {
prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
}
prompt.context_size = len + 1;
return prompt;
}
static void LogPrompt(const Prompt& prompt) {
static const char* kVocab[] = {
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
};
for (int token : prompt.tokens) printf("%s", kVocab[token]);
printf(" [context_size: %zu]\n", prompt.context_size);
}
private:
std::uniform_int_distribution<> token_dist_;
std::uniform_int_distribution<> length_dist_;
std::vector<int> length_lut_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_

195
gemma/test_util.h Normal file
View File

@ -0,0 +1,195 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#include <array>
#include <complex>
#include <random>
#include "gemma/weights.h"
#include "gtest/gtest.h"
namespace gcpp {
template<typename T, size_t kLen>
void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
std::normal_distribution<T> dist(0.0, stddev);
for (size_t i = 0; i < kLen; ++i) {
x[i] = dist(gen);
}
}
template<typename T, typename TConfig>
void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
RandInit(w.pre_attention_norm_scale, stddev, gen);
RandInit(w.attn_vec_einsum_w, stddev, gen);
RandInit(w.qkv_einsum_w, stddev, gen);
RandInit(w.pre_ffw_norm_scale, stddev, gen);
RandInit(w.gating_einsum_w, stddev, gen);
RandInit(w.linear_w, stddev, gen);
}
template<typename T, typename TConfig>
void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
static constexpr size_t kLayers = TConfig::kLayers;
RandInit(w.embedder_input_embedding, stddev, gen);
RandInit(w.final_norm_scale, stddev, gen);
for (size_t i = 0; i < kLayers; ++i) {
RandInit(*w.GetLayer(i), stddev, gen);
}
}
template<typename T, typename U, size_t kLen>
void Complexify(const std::array<T, kLen>& x,
std::array<std::complex<U>, kLen>& c_x) {
for (size_t i = 0; i < kLen; ++i) {
c_x[i] = std::complex<U>(x[i], 0.0);
}
}
template<typename T, typename U, typename TConfig>
void Complexify(const Layer<T, TConfig>& w,
Layer<std::complex<U>, TConfig>& c_w) {
Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
Complexify(w.linear_w, c_w.linear_w);
}
template<typename T, typename U, typename TConfig>
void Complexify(const Weights<T, TConfig>& w,
Weights<std::complex<U>, TConfig>& c_w) {
static constexpr size_t kLayers = TConfig::kLayers;
Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
Complexify(w.final_norm_scale, c_w.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) {
Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
}
}
template<typename T, typename U, size_t N>
void TestNear(const std::array<T, N>& actual, const std::array<U, N>& expected,
double max_abs_err, double max_rel_err, int line) {
double sum0 = 0;
double sum1 = 0;
double sum01 = 0;
for (size_t i = 0; i < N; ++i) {
sum0 += actual[i] * actual[i];
sum1 += expected[i] * expected[i];
sum01 += actual[i] * expected[i];
ASSERT_NEAR(actual[i], expected[i],
std::max(max_abs_err, std::abs(expected[i]) * max_rel_err))
<< "line: " << line << " dim=" << N << " i=" << i;
}
if (sum0 > 1e-40) {
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
ASSERT_NEAR(norm_dot, 1.0, 1e-7)
<< "line: " << line << " sum0: " << sum0 << " sum1: " << sum1
<< " sum01: " << sum01;
}
}
// Compute gradient with the finite difference method in the complex plane.
// If f : R->R is the tested function and F : C->C is its extenstion on the
// complex plane so that F is complex differentiable in x, then
//
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
//
// which means that
//
// F'(x) ~= Imag(F(x + ih)) / h
//
// This method is more numerically stable than the real-valued finite difference
// method since we don't need to substract floating point numbers that are near
// to each other.
template<typename T, typename U, size_t N, typename FUNC>
void TestGradient(const std::array<T, N>& grad,
std::array<std::complex<U>, N>& x, FUNC func,
U step, T max_abs_err, T max_rel_err, int line) {
std::array<T, N> exp_grad;
const U inv_step = 1.0 / step;
for (size_t i = 0; i < N; ++i) {
const U x0 = std::real(x[i]);
const std::complex<U> x1 = std::complex<U>(x0, step);
x[i] = x1;
const std::complex<U> f1 = func();
exp_grad [i] = std::imag(f1) * inv_step;
x[i] = x0;
}
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
}
template<size_t N, typename FUNC>
void TestGradient(const std::array<float, N>& grad,
std::array<std::complex<float>, N>& x, FUNC func,
float max_abs_err, float max_rel_error, int line) {
TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
}
template<size_t N, typename FUNC>
void TestGradient(const std::array<float, N>& grad,
std::array<std::complex<double>, N>& x, FUNC func,
float max_abs_err, float max_rel_error, int line) {
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}
template<size_t N, typename FUNC>
void TestGradient(const std::array<double, N>& grad,
std::array<std::complex<double>, N>& x, FUNC func,
double max_abs_err, double max_rel_error, int line) {
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}
template<typename T, typename U, typename TConfig, typename FUNC>
void TestGradient(const Layer<T, TConfig>& grad,
Layer<std::complex<U>, TConfig>& c_weights,
FUNC func, T max_err) {
TestGradient(grad.pre_attention_norm_scale,
c_weights.pre_attention_norm_scale,
func, max_err, max_err, __LINE__);
TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
func, max_err, max_err, __LINE__);
TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.linear_w, c_weights.linear_w,
func, max_err, max_err, __LINE__);
}
template<typename T, typename U, typename TConfig, typename FUNC>
void TestGradient(const Weights<T, TConfig>& grad,
Weights<std::complex<U>, TConfig>& c_weights,
FUNC func, T max_err) {
TestGradient(grad.embedder_input_embedding,
c_weights.embedder_input_embedding,
func, 2 * max_err, max_err, __LINE__);
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
func, max_err, max_err, __LINE__);
for (int i = 0; i < TConfig::kLayers; ++i) {
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
}
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_

116
gemma/weights.cc Normal file
View File

@ -0,0 +1,116 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/weights.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return AllocateWeights<float, ConfigGemma2B>(pool);
case Model::GEMMA_7B:
return AllocateWeights<float, ConfigGemma7B>(pool);
case Model::GRIFFIN_2B:
return AllocateWeights<float, ConfigGriffin2B>(pool);
case Model::GEMMA_TINY:
return AllocateWeights<float, ConfigGemmaTiny>(pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
namespace {
template <typename TConfig>
void ZeroInitWeightsT(ByteStorageT& weights, hwy::ThreadPool& pool) {
ZeroInit<float, TConfig>(
*reinterpret_cast<Weights<float, TConfig>*>(weights.get()));
}
} // namespace
void ZeroInitWeights(Model model, ByteStorageT& weights,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
ZeroInitWeightsT<ConfigGemma2B>(weights, pool);
break;
case Model::GEMMA_7B:
ZeroInitWeightsT<ConfigGemma7B>(weights, pool);
break;
case Model::GRIFFIN_2B:
ZeroInitWeightsT<ConfigGriffin2B>(weights, pool);
break;
case Model::GEMMA_TINY:
ZeroInitWeightsT<ConfigGemmaTiny>(weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
namespace {
void LogVec(const char* name, const float* data, size_t len) {
float minval = std::numeric_limits<float>::max();
float maxval = std::numeric_limits<float>::min();
double sum = 0.0f;
for (size_t i = 0; i < len; ++i) {
minval = std::min(minval, data[i]);
maxval = std::max(maxval, data[i]);
sum += data[i];
}
float avg = sum / len;
printf("%-20s %12zu %13.10f %8.5f %13.10f\n",
name, len, minval, avg, maxval);
}
class WeightLogger {
public:
template <size_t N>
void operator()(const char* name, const std::array<float, N>& tensor) {
LogVec(name, tensor.data(), N);
total_weights += N;
}
size_t total_weights = 0;
};
template <typename TConfig>
void LogWeightStats(const ByteStorageT& weights_u8) {
const auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightLogger logger;
ForEachTensor1<float, TConfig>(logger, weights);
printf("%-20s %12zu\n", "Total", logger.total_weights);
}
} // namespace
void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
switch (model) {
case Model::GEMMA_2B:
return LogWeightStats<ConfigGemma2B>(weights);
case Model::GEMMA_7B:
return LogWeightStats<ConfigGemma7B>(weights);
case Model::GRIFFIN_2B:
return LogWeightStats<ConfigGriffin2B>(weights);
case Model::GEMMA_TINY:
return LogWeightStats<ConfigGemmaTiny>(weights);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp

288
gemma/weights.h Normal file
View File

@ -0,0 +1,288 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
template <typename T, class TConfig>
struct Layer {
Layer() {}
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumWSize =
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
union {
struct {
std::array<T, kAttVecEinsumWSize> attn_vec_einsum_w;
std::array<T, kQKVEinsumWSize> qkv_einsum_w;
std::array<T, kAOBiasDim> attention_output_biases;
};
struct {
std::array<T, kGriffinDim * kGriffinDim> linear_x_w;
std::array<T, kGriffinDim> linear_x_biases;
std::array<T, kGriffinDim * kGriffinDim> linear_y_w;
std::array<T, kGriffinDim> linear_y_biases;
std::array<T, kGriffinDim * kGriffinDim> linear_out_w;
std::array<T, kGriffinDim> linear_out_biases;
std::array<T, kConv1dWidth * kGriffinDim> conv_w;
std::array<T, kGriffinDim> conv_biases;
std::array<T, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
std::array<T, kGriffinDim * 2> gate_biases;
std::array<T, kGriffinDim> a;
} griffin;
};
std::array<T, kGatingEinsumWSize> gating_einsum_w;
std::array<T, kModelDim * kFFHiddenDim> linear_w;
std::array<T, kModelDim> pre_attention_norm_scale;
std::array<T, kModelDim> pre_ffw_norm_scale;
std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
};
template <class TConfig>
using LayerF = Layer<float, TConfig>;
// Array instead of single large allocation for parallel mem init. Split out of
// Weights so that only these pointers are initialized.
template <typename T, class TConfig>
struct LayerPointers {
explicit LayerPointers(hwy::ThreadPool& pool) {
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
this->layers[task] = hwy::AllocateAligned<Layer<T, TConfig>>(1);
});
}
using TLayer = Layer<T, TConfig>;
std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
};
template <typename T, class TConfig>
struct Weights {
// No ctor/dtor, allocated via AllocateAligned.
std::array<T, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
std::array<T, TConfig::kModelDim> final_norm_scale;
LayerPointers<T, TConfig> layer_ptrs;
std::array<T, TConfig::kNumTensorScales> scales;
const Layer<T, TConfig>* GetLayer(size_t layer) const {
return layer_ptrs.layers[layer].get();
}
Layer<T, TConfig>* GetLayer(size_t layer) {
return layer_ptrs.layers[layer].get();
}
};
template <class TConfig>
using WeightsF = Weights<float, TConfig>;
template <typename T, typename TConfig>
ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
using TWeights = Weights<T, TConfig>;
ByteStorageT weights_u8 = hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
return weights_u8;
}
#define CALL_TOP_FUNC1(name, member) func(name, weights1.member)
#define CALL_TOP_FUNC2(name, member) \
func(name, weights1.member, weights2.member)
#define CALL_TOP_FUNC3(name, member) \
func(name, weights1.member, weights2.member, weights3.member)
#define CALL_TOP_FUNC4(name, member) \
func(name, weights1.member, weights2.memeber, \
weights3.member, weights4.member)
#define CALL_LAYER_FUNC1(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member)
#define CALL_LAYER_FUNC2(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member)
#define CALL_LAYER_FUNC3(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member, layer3.member)
#define CALL_LAYER_FUNC4(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member, layer4.member)
#define CALL_ALL_LAYER_FUNC(N) \
if (type == LayerAttentionType::kGemma) { \
CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \
CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \
} else { \
CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \
CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \
CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \
CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \
CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \
CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \
CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \
CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \
CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \
CALL_LAYER_FUNC ## N("gr_a", griffin.a); \
} \
CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
CALL_LAYER_FUNC ## N("linear_w", linear_w); \
CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
if (TConfig::kPostNormScale) { \
CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
} \
CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \
if (TConfig::kFFBiases) { \
CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \
CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \
} \
if (TConfig::kSoftmaxAttnOutputBiases && \
type == LayerAttentionType::kGemma) { \
CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
}
template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
CALL_TOP_FUNC1("embedding", embedder_input_embedding);
CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
CALL_ALL_LAYER_FUNC(1)
}
}
template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
CALL_TOP_FUNC1("embedding", embedder_input_embedding);
CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
CALL_ALL_LAYER_FUNC(1)
}
}
template <typename T, typename TConfig, class Func>
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
Weights<T, TConfig>& weights2) {
CALL_TOP_FUNC2("embedding", embedder_input_embedding);
CALL_TOP_FUNC2("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
CALL_ALL_LAYER_FUNC(2)
}
}
#undef CALL_TOP_FUNC1
#undef CALL_TOP_FUNC2
#undef CALL_TOP_FUNC3
#undef CALL_TOP_FUNC4
#undef CALL_LAYER_FUNC1
#undef CALL_LAYER_FUNC2
#undef CALL_LAYER_FUNC3
#undef CALL_LAYER_FUNC4
#undef CALL_ALL_LAYER_FUNC
template<typename T, typename TConfig>
void ZeroInit(Weights<T, TConfig>& w) {
memset(&w.embedder_input_embedding, 0, sizeof(w.embedder_input_embedding));
memset(&w.final_norm_scale, 0, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
memset(w.GetLayer(i), 0, sizeof(*w.GetLayer(i)));
}
}
template<typename T, typename TConfig>
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
memcpy(&dst.embedder_input_embedding, &src.embedder_input_embedding,
sizeof(src.embedder_input_embedding));
memcpy(&dst.final_norm_scale, &src.final_norm_scale,
sizeof(src.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
memcpy(dst.GetLayer(i), src.GetLayer(i), sizeof(*dst.GetLayer(i)));
}
}
template<typename T, typename TConfig>
class WeightsWrapper {
public:
WeightsWrapper()
: pool_(0), data_(AllocateWeights<T, TConfig>(pool_)),
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
const Weights<T, TConfig>& get() const { return *weights_; }
Weights<T, TConfig>& get() { return *weights_; }
void clear() { ZeroInit(get()); }
void copy(const WeightsWrapper<T, TConfig>& other) {
Copy(get(), other.get());
}
private:
hwy::ThreadPool pool_;
ByteStorageT data_;
Weights<T, TConfig>* weights_;
};
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool);
void ZeroInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool);
void LogWeightStats(Model model, const ByteStorageT& weights);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_