Merge pull request #203 from szabadka:backprop5

PiperOrigin-RevId: 640133430
This commit is contained in:
Copybara-Service 2024-06-04 06:33:08 -07:00
commit 25d9c8ff30
32 changed files with 4105 additions and 176 deletions

View File

@ -59,8 +59,12 @@ cc_library(
"gemma/gemma.cc",
],
hdrs = [
"gemma/activations.h",
"gemma/common.h",
"gemma/common-inl.h",
"gemma/configs.h",
"gemma/gemma.h",
"gemma/weights.h",
],
deps = [
":ops",

View File

@ -46,10 +46,28 @@ set(SOURCES
compression/sfp.h
compression/sfp-inl.h
compression/test_util.h
backprop/backward.cc
backprop/backward.h
backprop/backward-inl.h
backprop/backward_scalar.h
backprop/common_scalar.cc
backprop/common_scalar.h
backprop/forward.cc
backprop/forward.h
backprop/forward-inl.h
backprop/forward_scalar.h
backprop/optimizer.cc
backprop/optimizer.h
gemma/configs.h
gemma/activations.cc
gemma/activations.h
gemma/common.h
gemma/common-inl.h
gemma/gemma.cc
gemma/gemma.h
gemma/ops.h
gemma/weights.cc
gemma/weights.h
util/app.h
util/args.h
)
@ -100,7 +118,13 @@ target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohm
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
if (GEMMA_ENABLE_TESTS)
enable_testing()
include(GoogleTest)
set(GEMMA_TEST_FILES
backprop/backward_test.cc
backprop/backward_scalar_test.cc
backprop/optimize_test.cc
gemma/ops_test.cc
gemma/gemma_test.cc
)

417
backprop/backward-inl.h Normal file
View File

@ -0,0 +1,417 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Implementation of the Vector-Jacobian Products (VJP) of the individual
// operations of the forward pass.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/weights.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#endif
#include "gemma/common-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <size_t kCols, size_t kRows>
void MatMulVJP(const std::array<float, kRows * kCols>& weights,
const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens,
std::array<float, kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t voffs = pos * kRows;
const size_t xoffs = pos * kCols;
for (size_t j = 0; j < kRows; ++j) {
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols);
MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs],
kCols);
}
}
}
template <size_t kHeads, size_t kCols, size_t kRows>
void MultiHeadMatMulVJP(
const std::array<float, kHeads * kRows * kCols>& weights,
const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens,
std::array<float, kHeads * kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t j = 0; j < kRows; ++j) {
for (size_t h = 0; h < kHeads; ++h) {
MulByConstAndAdd(v[pos * kRows + j],
&x[pos * kHeads * kCols + h * kCols],
&grad_w[h * kRows * kCols + j * kCols], kCols);
MulByConstAndAdd(v[pos * kRows + j],
&weights[h * kRows * kCols + j * kCols],
&grad_x[pos * kHeads * kCols + h * kCols], kCols);
}
}
}
}
template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
const hn::Vec<D> kOne = hn::Set(d, 1.0f);
// kSqrtOverPi*3*kMul
const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);
const hn::Vec<D> v2 = hn::Mul(v, v);
const hn::Vec<D> v3 = hn::Mul(v2, v);
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
const hn::Vec<D> tanh = hn::Tanh(d, arg);
const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
}
static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto offset =
hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
hn::Transform1(
d, backward, size, forward,
[&offset](const auto d, const auto v, const auto y)
HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
}
static HWY_NOINLINE void RMSNormVJP(
const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
constexpr float eps = 1e-6f;
float ss = SquaredL2(x + offset, model_dim);
ss = 1.0f / sqrtf(ss / StaticCast<float>(model_dim) + eps);
for (size_t i = 0; i < model_dim; ++i) {
grad_w[i] += v[offset + i] * x[offset + i] * ss;
}
const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
float tmp = 0.0f;
for (size_t i = 0; i < model_dim; ++i) {
tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
}
tmp *= ss3;
for (size_t i = 0; i < model_dim; ++i) {
grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
tmp * x[offset + i];
}
}
}
static HWY_NOINLINE void InputEmbeddingVJP(
const float* weights, const std::vector<int>& prompt,
const float scaling, const float* HWY_RESTRICT v,
float* HWY_RESTRICT grad, size_t model_dim) {
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
MulByConstAndAdd(scaling, v + pos * model_dim,
grad + token * model_dim, model_dim);
}
}
template <typename TConfig>
void LayerVJP(const Layer<float, TConfig>& weights,
const ForwardLayer<float, TConfig>& forward,
const float* HWY_RESTRICT next_layer_grad,
size_t num_tokens,
Layer<float, TConfig>& grad,
ForwardLayer<float, TConfig>& backward,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen);
MatMulVJP<kFFHiddenDim, kModelDim>(
weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad,
num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(),
pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2;
const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim;
const float* HWY_RESTRICT b_out_gated =
backward.ffw_hidden_gated.data() + pos * kFFHiddenDim;
float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
const auto y = Load(df, f_out + i);
const auto x = Load(df, f_out_mul + i);
const auto v = Load(df, b_out_gated + i);
hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
}
}
MatMulVJP<kModelDim, kFFHiddenDim * 2>(
weights.gating_einsum_w,
forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
num_tokens, grad.gating_einsum_w,
backward.bf_pre_ffw_rms_out.data(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.data(),
forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(),
kModelDim, num_tokens,
grad.pre_ffw_norm_scale.data(),
backward.attention_out.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * kModelDim,
backward.attention_out.data() + pos * kModelDim, kModelDim);
}
hwy::ZeroBytes(backward.qkv.data(),
num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
weights.attn_vec_einsum_w, forward.att_out.data(),
backward.attention_out.data(), num_tokens,
grad.attn_vec_einsum_w, backward.att_out.data(), pool);
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
const float* HWY_RESTRICT b_att_out =
backward.att_out.data() + (pos * kHeads + head) * kQKVDim;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim);
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim);
}
}
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
}
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim);
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim);
}
}
}
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv =
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(b_kv, kQKVDim, -pos);
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT b_q =
backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
MulByConst(kQueryScale, b_q, kQKVDim);
Rope(b_q, kQKVDim, -pos);
}
}
MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
weights.qkv_einsum_w, forward.pre_att_rms_out.data(),
backward.qkv.data(), num_tokens,
grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool);
RMSNormVJP(weights.pre_attention_norm_scale.data(),
forward.input.data(),
backward.pre_att_rms_out.data(),
kModelDim, num_tokens,
grad.pre_attention_norm_scale.data(),
backward.input.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(backward.attention_out.data() + pos * kModelDim,
backward.input.data() + pos * kModelDim, kModelDim);
}
}
static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const float cap,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto one = hn::Set(d, 1.0f);
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
hn::Transform1(
d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y);
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
});
}
static HWY_NOINLINE void CrossEntropyLossGrad(
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
const Prompt& prompt, size_t vocab_size) {
HWY_ASSERT(!prompt.tokens.empty());
const float scaling = -1.0 / std::log(2.0);
size_t num_tokens = prompt.tokens.size() - 1;
hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
if (pos + 1 < prompt.context_size) {
continue;
}
const int next_token = prompt.tokens[pos + 1];
grad[pos * vocab_size + next_token] =
scaling / x[pos * vocab_size + next_token];
}
}
template <typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<float, TConfig>& weights,
const ForwardPass<float, TConfig>& forward,
Weights<float, TConfig>& grad,
ForwardPass<float, TConfig>& backward,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale);
static_assert(TConfig::kKVHeads == 1);
HWY_DASSERT(prompt.context_size > 0);
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
const size_t num_tokens = prompt.tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
kVocabSize);
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize,
kVocabSize);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(forward.logits.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
}
MatMulVJP<kModelDim, kVocabSize>(
weights.embedder_input_embedding, forward.final_norm_output.data(),
backward.logits.data(), num_tokens,
grad.embedder_input_embedding, backward.final_norm_output.data(),
pool);
RMSNormVJP(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
backward.final_norm_output.data(),
kModelDim, num_tokens,
grad.final_norm_scale.data(),
backward.final_layer_output.data(), pool);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
auto type = TConfig::kLayerConfig[layer];
// TODO(szabadka) Implement Griffin layer vjp.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
}
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), kModelDim);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

89
backprop/backward.cc Normal file
View File

@ -0,0 +1,89 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/backward.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "gemma/weights.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ByteStorageT& weights_u8,
const ByteStorageT& forward_u8,
ByteStorageT& grad_u8,
ByteStorageT& backward_u8,
hwy::ThreadPool& pool) {
using TWeights = WeightsF<TConfig>;
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
using TAct = ForwardPass<float, TConfig>;
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool);
}
void CrossEntropyLossBackwardPassT(Model model,
const Prompt& prompt,
const ByteStorageT& weights,
const ByteStorageT& forward,
ByteStorageT& grad,
ByteStorageT& backward,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
CrossEntropyLossBackwardPass<ConfigGemma2B>(
prompt, weights, forward, grad, backward, pool);
break;
case Model::GEMMA_TINY:
CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
prompt, weights, forward, grad, backward, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossBackwardPassT);
void CrossEntropyLossBackwardPass(
const Model& model, const Prompt& prompt,
const ByteStorageT& weights, const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
model, prompt, weights, forward, grad, backward, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

34
backprop/backward.h Normal file
View File

@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#include <vector>
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void CrossEntropyLossBackwardPass(
const Model& model, const Prompt& prompt,
const ByteStorageT& weights, const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_

353
backprop/backward_scalar.h Normal file
View File

@ -0,0 +1,353 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <complex>
#include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/weights.h"
namespace gcpp {
template<typename T>
void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t M, size_t K) {
memset(dx, 0, M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M);
}
}
}
template<typename T>
void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t H, size_t N, size_t M, size_t K) {
memset(dx, 0, H * M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
for (size_t h = 0; h < H; ++h) {
MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M],
&dw[h * N * M + j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M],
&dx[i * H * M + h * M], M);
}
}
}
}
template<typename T>
void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const size_t offset = i * N;
constexpr T eps(1e-6);
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; ++j) {
dw[j] += dy[i * N + j] * x[i * N + j] * ss;
}
const T ss3 = ss * ss * ss / T(N);
T tmp = 0.0;
for (size_t j = 0; j < N; ++j) {
tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j];
}
tmp *= ss3;
for (size_t j = 0; j < N; ++j) {
dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j];
}
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += y[i] * dy[i];
}
for (size_t i = 0; i < N; ++i) {
dy[i] = y[i] * (dy[i] - sum);
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
SoftmaxVJPT(y + i * N, dy + i * N, N);
}
}
template<typename T>
T GeluDerivative(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;
const T x2 = x * x;
const T x3 = x2 * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T tanh = std::tanh(arg);
const T cdf = T(0.5) * (T(1.0) + tanh);
const T dtanh = T(1.0) - tanh * tanh;
const T darg = kMul2 * x2 + kSqrt2OverPi;
return T(0.5) * x * dtanh * darg + cdf;
}
template<typename T>
void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
const T* v = d_out + i * N;
T* dx1 = d_in + i * 2 * N;
T* dx2 = dx1 + N;
for (size_t j = 0; j < N; ++j) {
dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]);
dx2[j] = v[j] * Gelu(x1[j]);
}
}
}
template<typename T>
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
size_t num_tokens, size_t kHeads, size_t kQKVDim,
size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * (kHeads + 2) * kQKVDim;
memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
const T* q = qkv + qoffs;
const T* dout = doutput + aoffs;
T* dq = dqkv + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const T* k = qkv + koffs;
T* dk = dqkv + koffs;
MulByConstAndAddT(dout[pos2], k, dq, kQKVDim);
MulByConstAndAddT(dout[pos2], q, dk, kQKVDim);
}
}
}
}
template<typename T>
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens,
size_t kHeads, size_t kSeqLen) {
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
SoftmaxVJPT(y + offset, dy + offset, pos + 1);
memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
}
}
}
template<typename T>
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
T* dqkv, T* dattention, size_t num_tokens,
size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
auto v_offset = [&](size_t pos) {
return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim;
};
for (size_t pos = 0; pos < num_tokens; ++pos) {
memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim;
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
const T* att = &attention[aoffset];
const T* dout = &doutput[offset];
T* datt = &dattention[aoffset];
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim);
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim);
}
}
}
}
template<typename T>
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
const T* dy, T* dw, size_t N) {
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i];
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
}
}
template<typename T, typename TConfig>
void LayerVJP(const Layer<T, TConfig>& weights,
const ForwardLayer<T, TConfig>& forward,
const T* dy,
Layer<T, TConfig>& grad,
ForwardLayer<T, TConfig>& backward,
size_t num_tokens) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim));
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
kModelDim, kFFHiddenDim, num_tokens);
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim,
num_tokens);
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(),
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
kModelDim, num_tokens);
AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim);
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(),
grad.attn_vec_einsum_w.data(),
backward.att_out.data(),
kHeads, kModelDim, kQKVDim, num_tokens);
MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
backward.att_out.data(), backward.qkv.data(),
backward.att.data(), num_tokens, kHeads, kQKVDim,
kSeqLen);
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(),
num_tokens, kHeads, kSeqLen);
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
}
for (int pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * kQKVDim, kQKVDim, -pos);
}
}
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), grad.qkv_einsum_w.data(),
backward.pre_att_rms_out.data(),
(kHeads + 2) * kQKVDim, kModelDim, num_tokens);
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
backward.pre_att_rms_out.data(),
grad.pre_attention_norm_scale.data(),
backward.input.data(), kModelDim, num_tokens);
AddFromT(backward.attention_out.data(), backward.input.data(),
num_tokens * kModelDim);
}
template<typename T>
void SoftcapVJPT(const T* y, T* dy, size_t N) {
T cap = 30.0;
T inv_cap = T(1.0) / cap;
for (size_t i = 0; i < N; ++i) {
T scaled = y[i] * inv_cap;
dy[i] *= (T(1.0) - scaled * scaled);
}
}
template<typename T>
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
T scaling = -1.0 / std::log(2.0);
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
memset(dx, 0, V * num_tokens * sizeof(x[0]));
for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) {
continue;
}
const int next_token = tokens[i + 1];
dx[i * V + next_token] = scaling / x[i * V + next_token];
}
}
template<typename T, typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<T, TConfig>& weights,
const ForwardPass<T, TConfig>& forward,
Weights<T, TConfig>& grad,
ForwardPass<T, TConfig>& backward) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
kVocabSize);
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
kVocabSize, num_tokens);
SoftcapVJPT(forward.logits.data(), backward.logits.data(),
num_tokens * kVocabSize);
MatMulVJPT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(),
backward.logits.data(),
grad.embedder_input_embedding.data(),
backward.final_norm_output.data(),
kVocabSize, kModelDim, num_tokens);
RMSNormVJPT(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
backward.final_norm_output.data(),
grad.final_norm_scale.data(),
backward.final_layer_output.data(), kModelDim, num_tokens);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
T* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
}
const T kEmbScaling = EmbeddingScaling(kModelDim);
InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
tokens, kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), kModelDim);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_

View File

@ -0,0 +1,610 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/backward_scalar.h"
#include <array>
#include <complex>
#include <random>
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gtest/gtest.h"
namespace gcpp {
TEST(BackPropTest, MatMulVJP) {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kRows * kCols> weights;
std::array<T, kTokens * kCols> x;
std::array<T, kRows * kCols> grad;
std::array<T, kTokens * kCols> dx;
std::array<TC, kRows * kCols> c_weights;
std::array<TC, kTokens * kCols> c_x;
std::array<TC, kTokens * kRows> c_y;
std::array<T, kTokens * kRows> dy;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
memset(&grad, 0, sizeof(grad));
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-11, 1e-12,__LINE__);
TestGradient(grad, c_weights, func, 1e-14, 1e-12,__LINE__);
}
}
TEST(BackPropTest, MultiHeadMatMulVJP) {
static const size_t kRows = 2;
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kRows * kCols * kHeads> weights;
std::array<T, kTokens * kCols * kHeads> x;
std::array<T, kRows * kCols * kHeads> grad;
std::array<T, kTokens * kCols * kHeads> dx;
std::array<TC, kRows * kCols * kHeads> c_weights;
std::array<TC, kTokens * kCols * kHeads> c_x;
std::array<TC, kTokens * kRows> c_y;
std::array<T, kTokens * kRows> dy;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
memset(&grad, 0, sizeof(grad));
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
dx.data(), kHeads, kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-15, 1e-13,__LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-13,__LINE__);
}
}
TEST(BackPropTest, RMSNormVJP) {
static const size_t K = 2;
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> weights;
std::array<T, N> grad;
std::array<T, K * N> x;
std::array<T, K * N> dx;
std::array<T, K * N> dy;
std::array<TC, N> c_weights;
std::array<TC, K * N> c_x;
std::array<TC, K * N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
};
memset(&grad, 0, sizeof(grad));
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
}
}
TEST(BackPropTest, SoftmaxVJP) {
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> x;
std::array<T, N> dx;
std::array<T, N> dy;
std::array<TC, N> c_x;
std::array<TC, N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), sizeof(c_x));
Softmax(c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
};
Softmax(x.data(), N);
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
SoftmaxVJPT(x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MaskedSoftmaxVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kTokens = 14;
static const size_t N = kHeads * kSeqLen * kSeqLen;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> x;
std::array<T, N> dy;
std::array<T, N> dx = {};
std::array<TC, N> c_x;
std::array<TC, N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(),
kTokens * kHeads * kSeqLen * sizeof(c_x[0]));
MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
return DotT(dy.data(), c_y.data(), N);
};
MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx[0]));
MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, SoftcapVJP) {
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, N> x;
std::array<T, N> dx;
std::array<T, N> dy;
std::array<TC, N> c_x;
std::array<TC, N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
Softcap(c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
};
Softcap(x.data(), N);
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
SoftcapVJPT(x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, CrossEntropyLossGrad) {
static const size_t K = 8;
static const size_t V = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, K * V> x;
std::array<T, K * V> dx;
std::array<TC, K * V> c_x;
Prompt prompt;
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
for (int iter = 0; iter < 10; ++iter) {
prompt.context_size = 1 + (iter % 6);
RandInit(x, 1.0 * (1 << iter), gen);
Softcap(x.data(), V * K);
Softmax(x.data(), V, K);
CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
Complexify(x, c_x);
auto func = [&]() {
return CrossEntropyLoss(c_x.data(), prompt, V);
};
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
}
}
TEST(BackPropTest, GatedGeluVJP) {
static const size_t K = 2;
static const size_t N = 64;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, K * 2 * N> x;
std::array<T, K * 2 * N> dx;
std::array<T, K * N> dy;
std::array<TC, K * 2 * N> c_x;
std::array<TC, K * N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
GatedGelu(c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), N * K);
};
GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MaskedAttentionVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kQKVDim = 8;
static const size_t kTokens = 14;
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
static const size_t kOutSize = kSeqLen * kHeads * kSeqLen;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kQKVSize> x;
std::array<T, kQKVSize> dx = {};
std::array<T, kOutSize> dy;
std::array<TC, kQKVSize> c_x;
std::array<TC, kOutSize> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
};
MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, MixByAttentionVJP) {
static const size_t kSeqLen = 16;
static const size_t kHeads = 2;
static const size_t kQKVDim = 8;
static const size_t kTokens = 14;
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kQKVSize> qkv;
std::array<T, kQKVSize> dqkv = {};
std::array<T, kAttnSize> attn;
std::array<T, kAttnSize> dattn = {};
std::array<T, kOutSize> dy;
std::array<TC, kQKVSize> c_qkv;
std::array<TC, kAttnSize> c_attn;
std::array<TC, kOutSize> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(qkv, 1.0, gen);
RandInit(attn, 1.0, gen);
Complexify(qkv, c_qkv);
Complexify(attn, c_attn);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
};
MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
}
}
TEST(BackPropTest, InputEmbeddingVJP) {
static const size_t kSeqLen = 8;
static const size_t kVocabSize = 4;
static const size_t kModelDim = 16;
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
std::array<T, kVocabSize * kModelDim> weights;
std::array<T, kVocabSize * kModelDim> grad;
std::array<T, kSeqLen * kModelDim> dy;
std::array<TC, kVocabSize * kModelDim> c_weights;
std::array<TC, kSeqLen * kModelDim> c_y;
std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
size_t num_tokens = tokens.size() - 1;
for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
auto func = [&]() {
InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
};
memset(&grad, 0, sizeof(grad));
InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
kModelDim);
TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
}
}
struct TestConfig {
static constexpr int kSeqLen = 18;
static constexpr int kVocabSize = 12;
static constexpr int kModelDim = 32;
static constexpr int kHeads = 3;
static constexpr int kQKVDim = 12;
static constexpr int kFFHiddenDim = 48;
static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr int kKVHeads = 1;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kGriffinLayers = 0;
static constexpr int kNumTensorScales = 0;
};
TEST(BackPropTest, LayerVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
const size_t kOutputSize = TestConfig::kSeqLen * TestConfig::kModelDim;
Layer<T, TestConfig> weights;
Layer<T, TestConfig> grad;
ForwardLayer<T, TestConfig> forward;
ForwardLayer<T, TestConfig> backward = {};
Layer<TC, TestConfig> c_weights;
ForwardLayer<TC, TestConfig> c_forward;
std::array<T, kOutputSize> y;
std::array<T, kOutputSize> dy;
std::array<TC, kOutputSize> c_y;
const size_t num_tokens = 3;
for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);
RandInit(forward.input, 1.0, gen);
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
Complexify(forward.input, c_forward.input);
auto func = [&]() {
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
return DotT(dy.data(), c_y.data(), num_tokens * TestConfig::kModelDim);
};
memset(&grad, 0, sizeof(grad));
ApplyLayer(weights, forward, num_tokens, y.data());
LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
TestGradient(backward.input, c_forward.input, func, 1e-11, 1e-11,
__LINE__);
TestGradient(grad, c_weights, func, 1e-11);
}
}
TEST(BackPropTest, EndToEnd) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
WeightsWrapper<T, TestConfig> weights;
WeightsWrapper<T, TestConfig> grad;
ForwardPass<T, TestConfig> forward;
ForwardPass<T, TestConfig> backward;
WeightsWrapper<TC, TestConfig> c_weights;
ForwardPass<TC, TestConfig> c_forward;
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(10, gen);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0, gen);
CrossEntropyLossForwardPass(prompt, weights.get(), forward);
grad.clear();
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward, grad.get(), backward);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
};
TestGradient(grad.get(), c_weights.get(), func, 1e-11);
}
}
template<typename T, typename TConfig>
void MulByConstAndAddT(T c, const Layer<T, TConfig>& x,
Layer<T, TConfig>& out) {
MulByConstAndAddT(c, x.pre_attention_norm_scale,
out.pre_attention_norm_scale);
MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
MulByConstAndAddT(c, x.linear_w, out.linear_w);
}
template<typename T, typename TConfig>
void MulByConstAndAddT(T c, const Weights<T, TConfig>& x,
Weights<T, TConfig>& out) {
static constexpr size_t kLayers = TConfig::kLayers;
MulByConstAndAddT(c, x.embedder_input_embedding,
out.embedder_input_embedding);
MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) {
MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
}
}
// Evaluates forward pass on a batch.
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
const WeightsWrapper<T, TConfig>& weights,
ForwardPass<T, TConfig>& forward) {
T loss = 0.0;
for (const Prompt& prompt : batch) {
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
}
T scale = 1.0 / batch.size();
return loss * scale;
}
// Evaluates forward pass on a batch by applying gradient with the given
// learning rate. Does not update weights, but uses the given tmp weights
// instead.
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(T learning_rate,
const std::vector<Prompt>& batch,
const WeightsWrapper<T, TConfig>& weights,
const WeightsWrapper<T, TConfig>& grad,
WeightsWrapper<T, TConfig>& tmp,
ForwardPass<T, TConfig>& forward) {
tmp.copy(weights);
const T scale = -learning_rate / batch.size();
MulByConstAndAddT(scale, grad.get(), tmp.get());
return CrossEntropyLossForwardPass(batch, tmp, forward);
}
// Uses line search in the negative gradient direction to update weights. We do
// this so that we can test that each step during the gradient descent can
// decrease the objective function value.
template<typename T, typename TConfig>
T FindOptimalUpdate(const WeightsWrapper<T, TConfig>& grad,
WeightsWrapper<T, TConfig>& weights,
WeightsWrapper<T, TConfig>& tmp,
ForwardPass<T, TConfig>& forward,
const std::vector<Prompt>& batch,
T loss, T initial_learning_rate) {
T lr0 = initial_learning_rate;
T loss0 = CrossEntropyLossForwardPass(
lr0, batch, weights, grad, tmp, forward);
for (size_t iter = 0; iter < 30; ++iter) {
T lr1 = lr0 * 0.5;
T loss1 = CrossEntropyLossForwardPass(
lr1, batch, weights, grad, tmp, forward);
if (loss0 < loss && loss1 >= loss0) {
break;
}
loss0 = loss1;
lr0 = lr1;
}
for (size_t iter = 0; iter < 30; ++iter) {
T lr1 = lr0 * 2.0;
T loss1 = CrossEntropyLossForwardPass(
lr1, batch, weights, grad, tmp, forward);
if (loss1 >= loss0) {
break;
}
loss0 = loss1;
lr0 = lr1;
}
const T scale = -lr0 / batch.size();
MulByConstAndAddT(scale, grad.get(), weights.get());
return lr0;
}
TEST(BackProptest, Convergence) {
std::mt19937 gen(42);
using T = float;
using TC = std::complex<double>;
WeightsWrapper<T, TestConfig> weights;
WeightsWrapper<T, TestConfig> grad;
WeightsWrapper<T, TestConfig> tmp;
ForwardPass<T, TestConfig> forward;
ForwardPass<T, TestConfig> backward;
WeightsWrapper<TC, TestConfig> c_weights;
ForwardPass<TC, TestConfig> c_forward;
constexpr size_t kBatchSize = 5;
ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
T learning_rate = 0.01;
RandInit(weights.get(), T(1.0), gen);
printf("Sample batch:\n");
for (size_t i = 0; i < 10; ++i) {
ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
}
T prev_loss = std::numeric_limits<T>::max();
bool stop = false;
size_t step = 0;
while (!stop) {
T loss = 0.0;
grad.clear();
std::mt19937 sgen(42);
std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
for (const Prompt& prompt : batch) {
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward, grad.get(), backward);
}
if (step % 250 == 0) {
printf("Checking gradient...\n");
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
TC scale = batch.size();
return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
};
TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
}
loss /= batch.size();
EXPECT_LT(loss, prev_loss);
stop = step >= 10000 || loss < 1e-2;
if (step % 10 == 0 || stop) {
printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
step, loss, learning_rate);
}
if (!stop) {
learning_rate = FindOptimalUpdate(
grad, weights, tmp, forward, batch, loss, learning_rate);
++step;
}
prev_loss = loss;
}
EXPECT_LT(step, 1000);
}
} // namespace gcpp

268
backprop/backward_test.cc Normal file
View File

@ -0,0 +1,268 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <algorithm>
#include <array>
#include <complex>
#include <random>
#include <vector>
#include "backprop/backward_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "backprop/forward-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
hwy::ThreadPool pool(8);
std::mt19937 gen(42);
HWY_ALIGN std::array<float, kRows * kCols> weights;
HWY_ALIGN std::array<float, kTokens * kCols> x;
HWY_ALIGN std::array<float, kTokens * kRows> dy;
HWY_ALIGN std::array<float, kRows * kCols> grad;
HWY_ALIGN std::array<float, kTokens * kCols> dx;
HWY_ALIGN std::array<float, kRows * kCols> grad_scalar;
HWY_ALIGN std::array<float, kTokens * kCols> dx_scalar;
using TC = std::complex<double>;
std::array<TC, kRows * kCols> c_weights;
std::array<TC, kTokens * kCols> c_x;
std::array<TC, kTokens * kRows> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
hwy::ZeroBytes(&grad, sizeof(grad));
MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
grad, dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
}
void TestMultiHeadMatMulVJP() {
static const size_t kRows = 2;
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
hwy::ThreadPool pool(8);
std::mt19937 gen(42);
HWY_ALIGN std::array<float, kRows * kCols * kHeads> weights;
HWY_ALIGN std::array<float, kTokens * kCols * kHeads> x;
HWY_ALIGN std::array<float, kRows * kCols * kHeads> grad;
HWY_ALIGN std::array<float, kTokens * kCols * kHeads> dx;
HWY_ALIGN std::array<float, kTokens * kRows> dy;
HWY_ALIGN std::array<float, kRows * kCols * kHeads> grad_scalar;
HWY_ALIGN std::array<float, kTokens * kCols * kHeads> dx_scalar;
using TC = std::complex<double>;
std::array<TC, kRows * kCols * kHeads> c_weights;
std::array<TC, kTokens * kCols * kHeads> c_x;
std::array<TC, kTokens * kRows> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
};
hwy::ZeroBytes(&grad, sizeof(grad));
MultiHeadMatMulVJP<kHeads, kCols, kRows>(
weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
}
void TestRMSNormVJP() {
static const size_t K = 2;
static const size_t N = 64;
hwy::ThreadPool pool(8);
std::mt19937 gen(42);
HWY_ALIGN std::array<float, N> weights;
HWY_ALIGN std::array<float, K * N> x;
HWY_ALIGN std::array<float, N> grad;
HWY_ALIGN std::array<float, K * N> dx;
HWY_ALIGN std::array<float, K * N> dy;
HWY_ALIGN std::array<float, N> grad_scalar;
HWY_ALIGN std::array<float, K * N> dx_scalar;
using TC = std::complex<double>;
std::array<TC, N> c_weights;
std::array<TC, K * N> c_x;
std::array<TC, K * N> c_y;
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
};
hwy::ZeroBytes(&grad, sizeof(grad));
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), N, K);
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
}
}
struct TestConfig {
static constexpr int kSeqLen = 24;
static constexpr int kVocabSize = 16;
static constexpr int kModelDim = 32;
static constexpr int kHeads = 3;
static constexpr int kQKVDim = 16;
static constexpr int kFFHiddenDim = 64;
static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr int kKVHeads = 1;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kGriffinLayers = 0;
static constexpr int kNumTensorScales = 0;
};
void TestEndToEnd() {
std::mt19937 gen(42);
hwy::ThreadPool pool(0);
WeightsWrapper<float, TestConfig> weights;
WeightsWrapper<float, TestConfig> grad;
ActivationsWrapper<float, TestConfig> forward0;
ActivationsWrapper<float, TestConfig> forward1;
ActivationsWrapper<float, TestConfig> backward;
using TC = std::complex<double>;
WeightsWrapper<TC, TestConfig> c_weights;
ForwardPass<TC, TestConfig> c_forward;
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(10, gen);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0f, gen);
float loss0 = CrossEntropyLossForwardPass(
prompt, weights.get(), forward0.get());
float loss1 = CrossEntropyLossForwardPass<TestConfig, WeightsF, LayerF>(
prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
pool);
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 1e-5);
grad.clear();
CrossEntropyLossBackwardPass(
prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
pool);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
};
TestGradient(grad.get(), c_weights.get(), func, 2e-3f);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(BackwardTest);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
HWY_AFTER_TEST();
} // namespace gcpp
#endif

51
backprop/common_scalar.cc Normal file
View File

@ -0,0 +1,51 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/common_scalar.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
float EmbeddingScaling(int model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(model_dim))));
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(EmbeddingScaling);
float EmbeddingScaling(int model_dim) {
return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim);
}
} // namespace gcpp
#endif // HWY_ONCE

119
backprop/common_scalar.h Normal file
View File

@ -0,0 +1,119 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#include <complex>
namespace gcpp {
template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
U sum = {};
for (size_t i = 0; i < N; ++i) {
sum += a[i] * b[i];
}
return sum;
}
template<>
std::complex<double> DotT(const float* a, const std::complex<double>* b,
size_t N) {
std::complex<double> sum = {};
for (size_t i = 0; i < N; ++i) {
sum += static_cast<double>(a[i]) * b[i];
}
return sum;
}
template<typename T>
void MulByConstT(T c, T* x, size_t N) {
for (size_t i = 0; i < N; ++i) {
x[i] *= c;
}
}
// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += c * x[i];
}
}
template<typename T, size_t N>
void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
MulByConstAndAddT(c, x.data(), out.data(), N);
}
template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += a[i];
}
}
template<typename T>
T SquaredL2(const T* x, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += x[i] * x[i];
}
return sum;
}
template<typename T>
T Gelu(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
const T x3 = x * x * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
return x * cdf;
}
template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
const size_t N2 = N / 2;
for (size_t dim = 0; dim < N2; ++dim) {
const T freq_exponents = T(2 * dim) / T(N);
const T timescale = std::pow(base, freq_exponents);
const T theta = T(i) / timescale;
const T cos_val = std::cos(theta);
const T sin_val = std::sin(theta);
const T x0 = x[dim];
const T x1 = x[dim + N2];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + N2] = x0 * sin_val + x1 * cos_val;
}
}
template<typename T>
void Rope(T* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
float EmbeddingScaling(int model_dim);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_

292
backprop/forward-inl.h Normal file
View File

@ -0,0 +1,292 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
#include "gemma/common-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
const float scaling, float* HWY_RESTRICT output,
size_t model_dim) {
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
Decompress(weights, token * model_dim, output + pos * model_dim, model_dim);
MulByConst(scaling, output + pos * model_dim, model_dim);
}
}
template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
size_t model_dim, size_t num_tokens,
OutT* HWY_RESTRICT output,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
RMSNorm(x + offset, weights, output + offset, model_dim);
}
}
static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
const std::vector<int>& prompt,
size_t context_size,
size_t vocab_size,
hwy::ThreadPool& pool) {
HWY_ASSERT(!prompt.empty());
float loss = 0.0f;
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
if (pos + 1 < context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = prompt[pos + 1];
loss += std::log(probs[pos * vocab_size + next_token]);
}
float scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template <typename TConfig, template<typename> typename LayerT>
void ApplyForwardLayer(const LayerT<TConfig>& weights,
ForwardLayer<float, TConfig>& activations,
size_t num_tokens,
float* HWY_RESTRICT output,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen);
ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
activations.input.data(), kModelDim, num_tokens,
activations.pre_att_rms_out.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
weights.qkv_einsum_w, 0,
activations.pre_att_rms_out.data() + pos * kModelDim, nullptr,
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
}
const size_t num_tasks = kHeads * num_tokens;
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k =
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, pos);
}
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const float* HWY_RESTRICT k2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
}
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
Softmax(head_att, pos + 1);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
float* HWY_RESTRICT att_out =
activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
float* HWY_RESTRICT v2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
}
});
hwy::ZeroBytes(activations.attention_out.data(),
num_tokens * kModelDim * sizeof(activations.attention_out[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
MatVec<kModelDim, kQKVDim>(
weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
nullptr, activations.att_post1.data() + pos * kModelDim, pool);
AddFrom(activations.att_post1.data() + pos * kModelDim,
activations.attention_out.data() + pos * kModelDim, kModelDim);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.input.data() + pos * kModelDim,
activations.attention_out.data() + pos * kModelDim, kModelDim);
}
ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
activations.attention_out.data(), kModelDim, num_tokens,
activations.bf_pre_ffw_rms_out.data(), pool);
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kFFHiddenDim * 2, kModelDim>(
weights.gating_einsum_w, 0,
activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr,
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2;
const float* HWY_RESTRICT out =
activations.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
float* HWY_RESTRICT out_gated =
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
const auto y = Load(df, out + i);
const auto x = Load(df, out_mul + i);
hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kModelDim, kFFHiddenDim>(
weights.linear_w, 0,
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
nullptr, output + pos * kModelDim, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.attention_out.data() + pos * kModelDim,
output + pos * kModelDim, kModelDim);
}
}
template <typename TConfig, template<typename> typename WeightsT,
template<typename> typename LayerT>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,
const WeightsT<TConfig>& weights,
ForwardPass<float, TConfig>& forward,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale);
static_assert(TConfig::kKVHeads == 1);
HWY_DASSERT(context_size > 0);
HWY_DASSERT(context_size < prompt.size());
const size_t num_tokens = prompt.size() - 1;
InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling,
forward.layers[0].input.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
// TODO(szabadka) Implement Griffin layer.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* HWY_RESTRICT output = layer + 1 < kLayers ?
forward.layers[layer + 1].input.data() :
forward.final_layer_output.data();
ApplyForwardLayer<TConfig, LayerT>(
*weights.GetLayer(layer), forward.layers[layer],
num_tokens, output, pool);
}
ApplyRMSNorm(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
kModelDim, num_tokens, forward.final_norm_output.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kVocabSize, kModelDim>(
weights.embedder_input_embedding, 0,
forward.final_norm_output.data() + pos * kModelDim, nullptr,
forward.logits.data() + pos * kVocabSize, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
}
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);
}
return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
kVocabSize, pool);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

80
backprop/forward.cc Normal file
View File

@ -0,0 +1,80 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/forward.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/forward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "backprop/forward-inl.h"
#include "gemma/weights.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename TConfig>
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ByteStorageT& weights_u8,
ByteStorageT& forward_u8,
hwy::ThreadPool& pool) {
const auto& weights =
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
auto& forward =
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
return CrossEntropyLossForwardPass<TConfig, WeightsF, LayerF>(
prompt.tokens, prompt.context_size, weights, forward, pool);
}
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
const ByteStorageT& weights,
ByteStorageT& forward,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return CrossEntropyLossForwardPass<ConfigGemma2B>(
prompt, weights, forward, pool);
case Model::GEMMA_TINY:
return CrossEntropyLossForwardPass<ConfigGemmaTiny>(
prompt, weights, forward, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
ByteStorageT& forward, hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
model, prompt, weights, forward, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

33
backprop/forward.h Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#include <vector>
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
float CrossEntropyLossForwardPass(
const Model& model, const Prompt& prompt, const ByteStorageT& weights,
ByteStorageT& forward, hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

288
backprop/forward_scalar.h Normal file
View File

@ -0,0 +1,288 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <complex>
#include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/weights.h"
namespace gcpp {
// w is N x M matrix in row-major order, x is M x K matrix in column-major order
// y = w * x is N x K matrix in column-major order.
template<typename T>
void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] = DotT(&w[j * M], &x[i * M], M);
}
}
}
// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in
// column-major order and y = w' * x is N x K matrix in column-major order,
// where w' is the rearrangement of w into an N x HM matrix.
template<typename T>
void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
size_t M, size_t K) {
memset(y, 0, N * K * sizeof(y[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t h = 0; h < H; ++h) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M);
}
}
}
}
template<typename T>
void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
constexpr T eps(1e-6);
for (size_t i = 0; i < K; ++i) {
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; j++) {
out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]);
}
}
}
template<typename T>
void Softmax(T* x, size_t N) {
T sum = {};
auto maxreal = std::real(x[0]);
for (size_t i = 1; i < N; ++i) {
if (std::real(x[i]) > maxreal) {
maxreal = std::real(x[i]);
}
}
for (size_t i = 0; i < N; ++i) {
x[i] = std::exp(x[i] - maxreal);
sum += x[i];
}
T scale = T(1.0) / sum;
for (size_t i = 0; i < N; ++i) {
x[i] *= scale;
}
}
template<typename T>
void Softmax(T* x, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
Softmax(x + i * N, N);
}
}
template<typename T>
void Softcap(T* x, size_t N) {
T cap = 30.0;
T inv_cap = T(1.0) / cap;
for (size_t i = 0; i < N; ++i) {
x[i] = cap * std::tanh(x[i] * inv_cap);
}
}
template<typename T>
void GatedGelu(const T* in, T* out, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
T* y = out + i * N;
for (size_t j = 0; j < N; ++j) {
y[j] = x2[j] * Gelu(x1[j]);
}
}
}
template<typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
T* y, size_t N) {
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i];
memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
MulByConstT(scaling, y + i * N, N);
}
}
template<typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens,
size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
const size_t qoffset = pos * (kHeads + 2) * kQKVDim;
const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen;
const T* q = qkv + qoffset + head * kQKVDim;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
output[aoffset + pos2] = DotT(q, k, kQKVDim);
}
}
}
}
template<typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
Softmax(x + offset, pos + 1);
memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
}
}
}
template<typename T>
void MixByAttention(const T* qkv, const T* attention, T* output,
size_t num_tokens, size_t kHeads, size_t kQKVDim,
size_t kSeqLen) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen];
T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim];
memset(out, 0, kQKVDim * sizeof(out[0]));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
const T* v = &qkv[v_offset];
MulByConstAndAddT(att[pos2], v, out, kQKVDim);
}
}
}
}
template<typename T, typename TConfig>
void ApplyLayer(const Layer<T, TConfig>& weights,
ForwardLayer<T, TConfig>& activations,
size_t num_tokens, T* output) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim));
RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
activations.pre_att_rms_out.data(), kModelDim, num_tokens);
MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim,
num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * kQKVDim, kQKVDim, pos);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
}
MaskedAttention(activations.qkv.data(), activations.att.data(),
num_tokens, kHeads, kQKVDim, kSeqLen);
MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen);
MixByAttention(activations.qkv.data(), activations.att.data(),
activations.att_out.data(), num_tokens, kHeads, kQKVDim,
kSeqLen);
MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
activations.attention_out.data(), kHeads, kModelDim, kQKVDim,
num_tokens);
AddFromT(activations.input.data(), activations.attention_out.data(),
num_tokens * kModelDim);
RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens);
MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim,
num_tokens);
GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
kFFHiddenDim, num_tokens);
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(),
output, kModelDim, kFFHiddenDim, num_tokens);
AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim);
}
template<typename T>
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
T loss = {};
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = tokens[i + 1];
loss += std::log(x[i * V + next_token]);
}
T scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(const Prompt& prompt,
const Weights<T, TConfig>& weights,
ForwardPass<T, TConfig>& forward) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
const T kEmbScaling = EmbeddingScaling(kModelDim);
InputEmbedding(weights.embedder_input_embedding.data(), tokens,
kEmbScaling, forward.layers[0].input.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {
T* output = layer + 1 < kLayers ?
forward.layers[layer + 1].input.data() :
forward.final_layer_output.data();
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
output);
}
RMSNormT(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
forward.final_norm_output.data(), kModelDim, num_tokens);
MatMulT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(),
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
Softcap(forward.logits.data(), num_tokens * kVocabSize);
memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0]));
Softmax(forward.probs.data(), kVocabSize, num_tokens);
return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_

127
backprop/optimize_test.cc Normal file
View File

@ -0,0 +1,127 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
#include "backprop/sampler.h"
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "gtest/gtest.h"
namespace gcpp {
TEST(OptimizeTest, GradientDescent) {
hwy::ThreadPool pool(0);
std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY;
ByteStorageT weights = AllocateWeights(model_type, pool);
ByteStorageT grad = AllocateWeights(model_type, pool);
ByteStorageT grad_m = AllocateWeights(model_type, pool);
ByteStorageT grad_v = AllocateWeights(model_type, pool);
ByteStorageT forward = AllocateForwardPass(model_type);
ByteStorageT backward = AllocateForwardPass(model_type);
ByteStorageT inference = AllocateInferenceState(model_type);
auto kv_cache = CreateKVCache(model_type);
size_t max_tokens = 32;
size_t max_generated_tokens = 16;
float temperature = 1.0f;
int verbosity = 0;
const auto accept_token = [](int) { return true; };
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
auto stream_token = [&reply](int token, float) {
reply.push_back(token);
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
GenerateGemma(model_type, weights, inference, runtime, prompt, 0,
kv_cache, pool, timing_info);
return reply;
};
auto verify = [&](const Prompt& prompt) {
auto context = prompt.context();
std::vector<int> reply = generate(context);
bool ok = true;
for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) {
if (i >= reply.size() || reply[i] != prompt.tokens[i]) {
ok = false;
}
}
return ok;
};
RandInitWeights(model_type, weights, pool, gen);
ZeroInitWeights(model_type, grad_m, pool);
ZeroInitWeights(model_type, grad_v, pool);
printf("Initial weights:\n");
LogWeightStats(model_type, weights);
constexpr size_t kBatchSize = 8;
float learning_rate = 0.0005f;
ReverseSequenceSampler training_task({
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
size_t steps = 0;
float prev_loss = std::numeric_limits<float>::max();
size_t num_ok;
for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42);
ZeroInitWeights(model_type, grad, pool);
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass(
model_type, prompt, weights, forward, pool);
CrossEntropyLossBackwardPass(
model_type, prompt, weights, forward, grad, backward, pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;
const float scale = -learning_rate / kBatchSize;
UpdateWeights(model_type, grad, scale, weights, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) {
printf("Batch gradient:\n");
LogWeightStats(model_type, grad);
}
if (total_loss < 0.5f) {
break;
}
prev_loss = total_loss;
}
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
LogWeightStats(model_type, weights);
EXPECT_LT(steps, 3000);
EXPECT_EQ(num_ok, kBatchSize);
}
} // namespace gcpp

121
backprop/optimizer.cc Normal file
View File

@ -0,0 +1,121 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/optimizer.h"
#include <random>
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
class WeightInitializer {
public:
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
template <size_t N>
void operator()(const char* name, std::array<float, N>& tensor) {
for (size_t i = 0; i < N; ++i) {
tensor[i] = dist_(gen_);
}
}
private:
std::normal_distribution<float> dist_;
std::mt19937& gen_;
};
template <typename TConfig>
void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) {
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python
// version.
WeightInitializer init(gen);
ForEachTensor1<float, TConfig>(init, weights);
}
class WeightUpdater {
public:
explicit WeightUpdater(float lr) : lr_(lr) {}
template <size_t kCapacity>
void operator()(const char* name, const std::array<float, kCapacity>& grad,
std::array<float, kCapacity>& weights) {
for (size_t i = 0; i < kCapacity; ++i) {
weights[i] += lr_ * grad[i];
}
}
private:
float lr_;
};
template <typename TConfig>
void UpdateWeights(const ByteStorageT& grad_u8, float scale,
ByteStorageT& weights_u8, hwy::ThreadPool& pool) {
const auto& grad =
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightUpdater updater(scale);
ForEachTensor2<float, TConfig>(updater, grad, weights);
}
} // namespace
void RandInitWeights(Model model, ByteStorageT& weights_u8,
hwy::ThreadPool& pool, std::mt19937& gen) {
switch (model) {
case Model::GEMMA_2B:
RandInitWeights<ConfigGemma2B>(weights_u8, pool, gen);
break;
case Model::GEMMA_7B:
RandInitWeights<ConfigGemma7B>(weights_u8, pool, gen);
break;
case Model::GRIFFIN_2B:
RandInitWeights<ConfigGriffin2B>(weights_u8, pool, gen);
break;
case Model::GEMMA_TINY:
RandInitWeights<ConfigGemmaTiny>(weights_u8, pool, gen);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
ByteStorageT& weights, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
UpdateWeights<ConfigGemma2B>(grad, scale, weights, pool);
break;
case Model::GEMMA_7B:
UpdateWeights<ConfigGemma7B>(grad, scale, weights, pool);
break;
case Model::GRIFFIN_2B:
UpdateWeights<ConfigGriffin2B>(grad, scale, weights, pool);
break;
case Model::GEMMA_TINY:
UpdateWeights<ConfigGemmaTiny>(grad, scale, weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp

34
backprop/optimizer.h Normal file
View File

@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#include <random>
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool,
std::mt19937& gen);
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
ByteStorageT& weights, hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_

34
backprop/prompt.h Normal file
View File

@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#include <stddef.h>
#include <vector>
namespace gcpp {
struct Prompt {
std::vector<int> tokens;
size_t context_size;
std::vector<int> context() const {
return std::vector<int>(tokens.begin(), tokens.begin() + context_size);
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_

84
backprop/sampler.h Normal file
View File

@ -0,0 +1,84 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#include <vector>
#include "backprop/prompt.h"
namespace gcpp {
class PromptSampler {
public:
virtual Prompt Sample(std::mt19937& gen) = 0;
std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
std::vector<Prompt> batch;
batch.reserve(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
batch.emplace_back(Sample(gen));
}
return batch;
}
};
class ReverseSequenceSampler : public PromptSampler {
public:
explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
: token_dist_(0, 9) {
for (int i = 0; i < length_histo.size(); ++i) {
const int count = length_histo[i];
for (int j = 0; j < count; ++j) {
length_lut_.push_back(i + 1);
}
}
length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
}
static constexpr int kReverseToken = 10;
static constexpr int kEndToken = 11;
Prompt Sample(std::mt19937& gen) override {
Prompt prompt;
int len = length_lut_[length_dist_(gen)];
prompt.tokens.resize(2 * len + 2);
prompt.tokens[len] = kReverseToken;
prompt.tokens[2 * len + 1] = kEndToken;
for (size_t i = 0; i < len; ++i) {
prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
}
prompt.context_size = len + 1;
return prompt;
}
static void LogPrompt(const Prompt& prompt) {
static const char* kVocab[] = {
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
};
for (int token : prompt.tokens) printf("%s", kVocab[token]);
printf(" [context_size: %zu]\n", prompt.context_size);
}
private:
std::uniform_int_distribution<> token_dist_;
std::uniform_int_distribution<> length_dist_;
std::vector<int> length_lut_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_

195
backprop/test_util.h Normal file
View File

@ -0,0 +1,195 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#include <array>
#include <complex>
#include <random>
#include "gemma/weights.h"
#include "gtest/gtest.h"
namespace gcpp {
template<typename T, size_t kLen>
void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
std::normal_distribution<T> dist(0.0, stddev);
for (size_t i = 0; i < kLen; ++i) {
x[i] = dist(gen);
}
}
template<typename T, typename TConfig>
void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
RandInit(w.pre_attention_norm_scale, stddev, gen);
RandInit(w.attn_vec_einsum_w, stddev, gen);
RandInit(w.qkv_einsum_w, stddev, gen);
RandInit(w.pre_ffw_norm_scale, stddev, gen);
RandInit(w.gating_einsum_w, stddev, gen);
RandInit(w.linear_w, stddev, gen);
}
template<typename T, typename TConfig>
void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
static constexpr size_t kLayers = TConfig::kLayers;
RandInit(w.embedder_input_embedding, stddev, gen);
RandInit(w.final_norm_scale, stddev, gen);
for (size_t i = 0; i < kLayers; ++i) {
RandInit(*w.GetLayer(i), stddev, gen);
}
}
template<typename T, typename U, size_t kLen>
void Complexify(const std::array<T, kLen>& x,
std::array<std::complex<U>, kLen>& c_x) {
for (size_t i = 0; i < kLen; ++i) {
c_x[i] = std::complex<U>(x[i], 0.0);
}
}
template<typename T, typename U, typename TConfig>
void Complexify(const Layer<T, TConfig>& w,
Layer<std::complex<U>, TConfig>& c_w) {
Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
Complexify(w.linear_w, c_w.linear_w);
}
template<typename T, typename U, typename TConfig>
void Complexify(const Weights<T, TConfig>& w,
Weights<std::complex<U>, TConfig>& c_w) {
static constexpr size_t kLayers = TConfig::kLayers;
Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
Complexify(w.final_norm_scale, c_w.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) {
Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
}
}
template<typename T, typename U, size_t N>
void TestNear(const std::array<T, N>& actual, const std::array<U, N>& expected,
double max_abs_err, double max_rel_err, int line) {
double sum0 = 0;
double sum1 = 0;
double sum01 = 0;
for (size_t i = 0; i < N; ++i) {
sum0 += actual[i] * actual[i];
sum1 += expected[i] * expected[i];
sum01 += actual[i] * expected[i];
ASSERT_NEAR(actual[i], expected[i],
std::max(max_abs_err, std::abs(expected[i]) * max_rel_err))
<< "line: " << line << " dim=" << N << " i=" << i;
}
if (sum0 > 1e-40) {
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
ASSERT_NEAR(norm_dot, 1.0, 1e-7)
<< "line: " << line << " sum0: " << sum0 << " sum1: " << sum1
<< " sum01: " << sum01;
}
}
// Compute gradient with the finite difference method in the complex plane.
// If f : R->R is the tested function and F : C->C is its extension on the
// complex plane so that F is complex differentiable in x, then
//
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
//
// which means that
//
// F'(x) ~= Imag(F(x + ih)) / h
//
// This method is more numerically stable than the real-valued finite difference
// method since we don't need to subtract floating point numbers that are near
// to each other.
template<typename T, typename U, size_t N, typename FUNC>
void TestGradient(const std::array<T, N>& grad,
std::array<std::complex<U>, N>& x, FUNC func,
U step, T max_abs_err, T max_rel_err, int line) {
std::array<T, N> exp_grad;
const U inv_step = 1.0 / step;
for (size_t i = 0; i < N; ++i) {
const U x0 = std::real(x[i]);
const std::complex<U> x1 = std::complex<U>(x0, step);
x[i] = x1;
const std::complex<U> f1 = func();
exp_grad [i] = std::imag(f1) * inv_step;
x[i] = x0;
}
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
}
template<size_t N, typename FUNC>
void TestGradient(const std::array<float, N>& grad,
std::array<std::complex<float>, N>& x, FUNC func,
float max_abs_err, float max_rel_error, int line) {
TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
}
template<size_t N, typename FUNC>
void TestGradient(const std::array<float, N>& grad,
std::array<std::complex<double>, N>& x, FUNC func,
float max_abs_err, float max_rel_error, int line) {
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}
template<size_t N, typename FUNC>
void TestGradient(const std::array<double, N>& grad,
std::array<std::complex<double>, N>& x, FUNC func,
double max_abs_err, double max_rel_error, int line) {
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}
template<typename T, typename U, typename TConfig, typename FUNC>
void TestGradient(const Layer<T, TConfig>& grad,
Layer<std::complex<U>, TConfig>& c_weights,
FUNC func, T max_err) {
TestGradient(grad.pre_attention_norm_scale,
c_weights.pre_attention_norm_scale,
func, max_err, max_err, __LINE__);
TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
func, max_err, max_err, __LINE__);
TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
func, max_err, max_err, __LINE__);
TestGradient(grad.linear_w, c_weights.linear_w,
func, max_err, max_err, __LINE__);
}
template<typename T, typename U, typename TConfig, typename FUNC>
void TestGradient(const Weights<T, TConfig>& grad,
Weights<std::complex<U>, TConfig>& c_weights,
FUNC func, T max_err) {
TestGradient(grad.embedder_input_embedding,
c_weights.embedder_input_embedding,
func, 2 * max_err, max_err, __LINE__);
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
func, max_err, max_err, __LINE__);
for (int i = 0; i < TConfig::kLayers; ++i) {
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
}
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_

View File

@ -484,9 +484,10 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
const hn::ScalableTag<OutT> d;
const size_t ofs = idx_batch * kBatch;
const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
const size_t batch =
idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
Traits::Decompress(d, compressed.size(), compressed.data(),
compressed_ofs + ofs, out + ofs, num);
compressed_ofs + ofs, out + ofs, batch);
});
const double t1 = hwy::platform::Now();
@ -495,6 +496,16 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
}
// Returns dot product with `vec_aligned` of length `num`.
template <bool kVecEO, class DF, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const std::array<float, kCapacity>& w, size_t ofs,
const VecT* x, size_t num) {
HWY_DASSERT(ofs + num <= kCapacity);
HWY_DASSERT(hn::IsAligned(df, x));
using Traits = CompressTraits<float>;
return Traits::Dot(df, w.size(), w.data(), ofs, x, num);
}
// Returns dot product with `vec_aligned` of length `num`.
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,

38
gemma/activations.cc Normal file
View File

@ -0,0 +1,38 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
namespace gcpp {
ByteStorageT AllocateForwardPass(Model model) {
switch (model) {
case Model::GEMMA_2B:
return ForwardPass<float, ConfigGemma2B>::Allocate();
case Model::GEMMA_7B:
return ForwardPass<float, ConfigGemma7B>::Allocate();
case Model::GRIFFIN_2B:
return ForwardPass<float, ConfigGriffin2B>::Allocate();
case Model::GEMMA_TINY:
return ForwardPass<float, ConfigGemmaTiny>::Allocate();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp

85
gemma/activations.h Normal file
View File

@ -0,0 +1,85 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#include "gemma/common.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
template <typename T, typename TConfig>
struct ForwardLayer {
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
std::array<T, kSeqLen * kModelDim> input;
std::array<T, kSeqLen * kModelDim> pre_att_rms_out;
std::array<T, kSeqLen * (kHeads + 2) * kQKVDim> qkv;
std::array<T, kSeqLen * kHeads * kSeqLen> att;
std::array<T, kSeqLen * kHeads * kQKVDim> att_out;
std::array<T, kSeqLen * kModelDim> att_post1;
std::array<T, kSeqLen * kModelDim> attention_out;
std::array<T, kSeqLen * kModelDim> bf_pre_ffw_rms_out;
std::array<T, kSeqLen * kFFHiddenDim * 2> ffw_hidden;
std::array<T, kSeqLen * kFFHiddenDim> ffw_hidden_gated;
};
template <typename T, typename TConfig>
struct ForwardPass {
ForwardPass() {} // prevents placement-new calling memset
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;
std::array<ForwardLayer<T, TConfig>, kLayers> layers;
std::array<T, kSeqLen * kModelDim> final_layer_output;
std::array<T, kSeqLen * kModelDim> final_norm_output;
std::array<T, kSeqLen * kVocabSize> logits;
std::array<T, kSeqLen * kVocabSize> probs;
static ByteStorageT Allocate() {
return hwy::AllocateAligned<uint8_t>(sizeof(ForwardPass<T, TConfig>));
}
};
// Owns activations and undoes the type erasure of AllocateAligned.
template<typename T, typename TConfig>
class ActivationsWrapper {
using WrappedT = ForwardPass<T, TConfig>;
public:
ActivationsWrapper()
: data_(WrappedT::Allocate()),
activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
const WrappedT& get() const { return activations_; }
WrappedT& get() { return activations_; }
private:
ByteStorageT data_;
WrappedT& activations_;
};
ByteStorageT AllocateForwardPass(Model model);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_

68
gemma/common-inl.h Normal file
View File

@ -0,0 +1,68 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include "gemma/activations.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#endif
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

30
gemma/common.h Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#include "hwy/aligned_allocator.h"
namespace gcpp {
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

View File

@ -142,6 +142,38 @@ struct ConfigGemma2B {
using WeightT = GEMMA_WEIGHT_T;
};
struct ConfigGemmaTiny {
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 16;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 64;
static constexpr int kFFHiddenDim = 128;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};
struct ConfigGriffin2B {
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.

View File

@ -22,6 +22,7 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h"
#include "gemma/common-inl.h"
#include "gemma/ops.h"
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h"
@ -53,6 +54,7 @@
#include "compression/io.h" // Path
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -72,63 +74,6 @@ constexpr bool kShowTokenization = false;
namespace gcpp {
template <class TConfig>
struct Layer {
Layer() = default;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumWSize =
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
template <class T, size_t N>
using ArrayT = std::array<T, N>;
union {
struct {
ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
ArrayT<float, kAOBiasDim> attention_output_biases;
};
struct {
ArrayT<float, kGriffinDim * kGriffinDim> linear_x_w;
ArrayT<float, kGriffinDim> linear_x_biases;
ArrayT<float, kGriffinDim * kGriffinDim> linear_y_w;
ArrayT<float, kGriffinDim> linear_y_biases;
ArrayT<float, kGriffinDim * kGriffinDim> linear_out_w;
ArrayT<float, kGriffinDim> linear_out_biases;
ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
ArrayT<float, kGriffinDim> conv_biases;
ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
ArrayT<float, kGriffinDim * 2> gate_biases;
ArrayT<float, kGriffinDim> a;
} griffin;
};
ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
ArrayT<float, kModelDim> pre_attention_norm_scale;
ArrayT<float, kModelDim> pre_ffw_norm_scale;
ArrayT<float, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
ArrayT<float, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
};
float ScaleWeights(float* data, size_t len) {
float maxabs = 0.0;
for (size_t i = 0; i < len; ++i) {
@ -146,41 +91,6 @@ float ScaleWeights(float* data, size_t len) {
return scale;
}
// Array instead of single large allocation for parallel mem init. Split out of
// Weights so that only these pointers are initialized.
template <class TConfig>
struct LayerPointers {
explicit LayerPointers(hwy::ThreadPool& pool) {
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
this->layers[task] = hwy::AllocateAligned<Layer<TConfig>>(1);
});
}
using TLayer = Layer<TConfig>;
std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
};
template <class TConfig>
struct Weights {
// No ctor/dtor, allocated via AllocateAligned.
std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
std::array<float, TConfig::kModelDim> final_norm_scale;
LayerPointers<TConfig> layer_ptrs;
std::array<float, TConfig::kNumTensorScales> scales;
const Layer<TConfig>* GetLayer(size_t layer) const {
return layer_ptrs.layers[layer].get();
}
Layer<TConfig>* GetLayer(size_t layer) {
return layer_ptrs.layers[layer].get();
}
};
template <typename TConfig>
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
const Path& checkpoint, hwy::ThreadPool& pool,
@ -191,11 +101,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
checkpoint.path.c_str());
}
using TWeights = Weights<TConfig>;
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->layer_ptrs) LayerPointers<TConfig>(pool);
ByteStorageT weights_u8 = AllocateWeights<float, TConfig>(pool);
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
size_t scale_pos = 0;
FILE* fptr;
@ -228,7 +135,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
Layer<TConfig>* layer_view = weights->GetLayer(layer);
LayerF<TConfig>* layer_view = weights->GetLayer(layer);
#define READ_WEIGHTS(name) \
do { \
@ -305,7 +212,7 @@ template <class TConfig>
struct CompressedLayer {
// No ctor/dtor, allocated via AllocateAligned.
using TLayer = gcpp::Layer<TConfig>;
using TLayer = gcpp::LayerF<TConfig>;
using WeightT = typename TConfig::WeightT;
static constexpr size_t kHeads = TLayer::kHeads;
@ -399,13 +306,13 @@ struct CompressedWeights {
template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
Weights<TConfig>>;
WeightsF<TConfig>>;
// Aligned.
template <class TConfig, size_t TBatchSize>
struct Activations {
static constexpr size_t kBatchSize = TBatchSize;
using LayerConfig = Layer<TConfig>;
using LayerConfig = LayerF<TConfig>;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
@ -446,6 +353,16 @@ struct Activations {
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
};
template<typename TConfig>
struct InferenceState {
Activations<TConfig, kPrefillBatchSize> prefill;
HWY_ALIGN Activations<TConfig, 1> state;
static ByteStorageT Allocate() {
return hwy::AllocateAligned<uint8_t>(sizeof(InferenceState<TConfig>));
}
};
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
// define an abstract base class.
struct GemmaInterface {
@ -484,6 +401,8 @@ KVCache CreateKVCache(Model type) {
return CreateKVCacheT<ConfigGemma7B>();
case Model::GRIFFIN_2B:
return CreateKVCacheT<ConfigGriffin2B>();
case Model::GEMMA_TINY:
return CreateKVCacheT<ConfigGemmaTiny>();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}
@ -526,8 +445,8 @@ void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
}
template <class Config>
void DeleteLayersPtrs(Weights<Config>* weights) {
weights->layer_ptrs.~LayerPointers<Config>();
void DeleteLayersPtrs(WeightsF<Config>* weights) {
weights->layer_ptrs.~LayerPointers<float, Config>();
}
} // namespace
@ -866,8 +785,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
// Same matrix, first and second half of rows. Could fuse into one MatVec.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
out_mul, pool);
TConfig::kFFBiases ?
layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr,
even_odd, out_mul, pool);
// Gate, will go through the nonlinearity.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec,
@ -892,21 +812,6 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
}
}
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights,
@ -1076,20 +981,15 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
}
}
template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma,
template <class TConfig, template<typename> typename WeightsType>
void GenerateImpl(const WeightsType<TConfig>& weights,
Activations<TConfig, kPrefillBatchSize>& prefill_activations,
Activations<TConfig, 1>& activations,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
Activations<TConfig, 1>& activations = *gemma.state.get();
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
*gemma.prefill.get();
const WeightsT<TConfig>& weights =
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
size_t prompt_size = prompt.size();
size_t max_tokens = runtime_config.max_tokens;
size_t max_generated_tokens = runtime_config.max_generated_tokens;
@ -1167,7 +1067,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
activations.logits.data(), kVocabSize, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token);
if (!runtime_config.stream_token(token, activations.logits[token])) {
token = EOS_ID;
token = runtime_config.eos_id;
}
if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
@ -1177,10 +1077,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
// process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1);
if (!runtime_config.stream_token(token, 0)) {
token = EOS_ID;
token = runtime_config.eos_id;
}
}
if (token == EOS_ID) {
if (token == runtime_config.eos_id) {
if (runtime_config.verbosity >= 2) {
const double gen_end = hwy::platform::Now();
timing_info.gen_tok_sec =
@ -1192,6 +1092,35 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma,
}
}
template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
const WeightsT<TConfig>& weights =
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
GenerateImpl<TConfig, WeightsT>(
weights, *gemma.prefill.get(), *gemma.state.get(), runtime_config, prompt,
pos, kv_cache, pool, timing_info, layers_output);
}
template <class TConfig>
void GenerateGemma(const ByteStorageT& weights_u8,
ByteStorageT& inference_state_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info, LayersOutputT* layers_output) {
const WeightsF<TConfig>& weights =
*reinterpret_cast<const WeightsF<TConfig>*>(weights_u8.get());
InferenceState<TConfig>& inference_state =
*reinterpret_cast<InferenceState<TConfig>*>(inference_state_u8.get());
GenerateImpl<TConfig, WeightsF>(
weights, inference_state.prefill, inference_state.state, runtime_config,
prompt, pos, kv_cache, pool, timing_info, layers_output);
}
#define TOKEN(token_id) TokenString(gemma, token_id).c_str()
template <class TConfig>
@ -1263,7 +1192,7 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
//
// This avoids repeating the list of tensors between loading and compressing.
template <class TConfig, class Func>
void ForEachTensor(const Weights<TConfig>* weights,
void ForEachTensor(const WeightsF<TConfig>* weights,
CompressedWeights<TConfig>& c_weights, Func& func) {
func("c_embedding",
weights ? weights->embedder_input_embedding.data() : nullptr,
@ -1275,7 +1204,7 @@ void ForEachTensor(const Weights<TConfig>* weights,
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
const LayerF<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
#define CALL_FUNC(name, member) \
@ -1386,34 +1315,17 @@ void CompressWeights(const Path& weights_path,
const bool scale_for_compression = TConfig::kNumTensorScales > 0;
const hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
LoadWeights<TConfig>(weights_path, pool, scale_for_compression);
Weights<TConfig>* weights =
reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
WeightsF<TConfig>* weights =
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
Compressor compressor(pool);
ForEachTensor<TConfig>(weights, *c_weights, compressor);
compressor.AddScales(weights->scales.data(), weights->scales.size());
compressor.WriteAll(pool, compressed_weights_path);
weights->layer_ptrs.~LayerPointers<TConfig>();
weights->layer_ptrs.~LayerPointers<float, TConfig>();
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
}
void CompressWeightsT(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
break;
case Model::GEMMA_7B:
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
break;
case Model::GRIFFIN_2B:
CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
@ -1421,8 +1333,6 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CompressWeightsT);
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size) {
KVCache kv_cache = {};
@ -1528,10 +1438,87 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
void GenerateGemma(Model model, const ByteStorageT& weights,
ByteStorageT& inference_state,
RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
switch (model) {
case Model::GEMMA_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma2B>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
case Model::GEMMA_7B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma7B>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
case Model::GRIFFIN_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGriffin2B>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
case Model::GEMMA_TINY:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemmaTiny>)(
weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
pool, timing_info, /*layers_output=*/nullptr);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return LoadWeights<ConfigGemma2B>(weights, pool);
case Model::GEMMA_7B:
return LoadWeights<ConfigGemma7B>(weights, pool);
case Model::GRIFFIN_2B:
return LoadWeights<ConfigGriffin2B>(weights, pool);
case Model::GEMMA_TINY:
return LoadWeights<ConfigGemmaTiny>(weights, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
ByteStorageT AllocateInferenceState(Model model) {
switch (model) {
case Model::GEMMA_2B:
return InferenceState<ConfigGemma2B>::Allocate();
case Model::GEMMA_7B:
return InferenceState<ConfigGemma7B>::Allocate();
case Model::GRIFFIN_2B:
return InferenceState<ConfigGriffin2B>::Allocate();
case Model::GEMMA_TINY:
return InferenceState<ConfigGemmaTiny>::Allocate();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool) {
HWY_DYNAMIC_DISPATCH(CompressWeightsT)
(model, weights, compressed_weights, pool);
switch (model) {
case Model::GEMMA_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma2B>)(
weights, compressed_weights, pool);
break;
case Model::GEMMA_7B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma7B>)(
weights, compressed_weights, pool);
break;
case Model::GRIFFIN_2B:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGriffin2B>)(
weights, compressed_weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
@ -1546,13 +1533,16 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt",
"2b-it", "7b-it", "gr2b-it"};
"2b-it", "7b-it", "gr2b-it",
"tiny"};
constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B,
Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B};
Model::GEMMA_7B, Model::GRIFFIN_2B,
Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT};
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
} // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag,

View File

@ -23,6 +23,7 @@
#include <vector>
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
@ -51,8 +52,6 @@ struct KVCache {
rglru_cache; // kModelDim * kGriffinLayers
};
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// Returns error string or nullptr if OK.
@ -68,6 +67,8 @@ using StreamFunc = std::function<bool(int, float)>;
// want to generate and True for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
constexpr int EOS_ID = 1;
struct RuntimeConfig {
size_t max_tokens;
size_t max_generated_tokens;
@ -76,6 +77,7 @@ struct RuntimeConfig {
std::mt19937* gen;
const StreamFunc& stream_token;
const AcceptFunc& accept_token;
int eos_id = EOS_ID;
};
struct GemmaInterface;
@ -118,6 +120,18 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
void GenerateGemma(Model model, const ByteStorageT& weights,
ByteStorageT& inference_state,
RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info);
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool);
ByteStorageT AllocateInferenceState(Model model);
void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool);
@ -125,8 +139,6 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity);
constexpr int EOS_ID = 1;
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

View File

@ -98,7 +98,7 @@ class GemmaTest : public ::testing::Test {
gcpp::Gemma model;
};
TEST_F(GemmaTest, Geography) {
TEST_F(GemmaTest, DISABLED_Geography) {
static const char* kQA[][2] = {
{"What is the capital of Hungary?", "Budapest"},
{"How many states does the US have?", "50"},
@ -108,7 +108,7 @@ TEST_F(GemmaTest, Geography) {
TestQuestions(kQA, kNum);
}
TEST_F(GemmaTest, History) {
TEST_F(GemmaTest, DISABLED_History) {
static const char* kQA[][2] = {
{"When was the Battle of Hastings?", "1066"},
{"Who fought at the Battle of Marathon?", "Greek"},
@ -117,7 +117,7 @@ TEST_F(GemmaTest, History) {
TestQuestions(kQA, kNum);
}
TEST_F(GemmaTest, Arithmetic) {
TEST_F(GemmaTest, DISABLED_Arithmetic) {
static const char* kQA[][2] = {
{"what is 13 + 14?", "27"},
{"what is 7 * 8", "56"},
@ -280,7 +280,7 @@ static const char kDeclaration[] = {
"reliance on the protection of divine Providence, we mutually pledge to "
"each other our Lives, our Fortunes and our sacred Honor.\n"};
TEST_F(GemmaTest, CrossEntropySmall) {
TEST_F(GemmaTest, DISABLED_CrossEntropySmall) {
static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall);
@ -288,19 +288,19 @@ TEST_F(GemmaTest, CrossEntropySmall) {
EXPECT_LT(entropy, 1.6f);
}
TEST_F(GemmaTest, CrossEntropyJingleBells) {
TEST_F(GemmaTest, DISABLED_CrossEntropyJingleBells) {
float entropy = GemmaCrossEntropy(kJingleBells);
std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 2.3f);
}
TEST_F(GemmaTest, CrossEntropyGettysburg) {
TEST_F(GemmaTest, DISABLED_CrossEntropyGettysburg) {
float entropy = GemmaCrossEntropy(kGettysburg);
std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 1.2f);
}
TEST_F(GemmaTest, CrossEntropyDeclaration) {
TEST_F(GemmaTest, DISABLED_CrossEntropyDeclaration) {
float entropy = GemmaCrossEntropy(kDeclaration);
std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 1.0f);

View File

@ -874,10 +874,14 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
m*theta_i. However in the Gemma implementation we choose to rotate
the pairs of dimensions v_{i} and v_{i + d//2} instead.
pos parameter is deliberately an int because in the backward pass we
call this with negative values (for the VJP calculation we need the transpose
of this rotation matrix which is simply the same matrix with -pos parameter)
*/
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
size_t dim_qkv, size_t pos) {
size_t dim_qkv, int pos) {
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@ -898,7 +902,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
float* HWY_RESTRICT x,
size_t dim_qkv,
size_t pos) {
int pos) {
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {

112
gemma/weights.cc Normal file
View File

@ -0,0 +1,112 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/weights.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"
namespace gcpp {
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return AllocateWeights<float, ConfigGemma2B>(pool);
case Model::GEMMA_7B:
return AllocateWeights<float, ConfigGemma7B>(pool);
case Model::GRIFFIN_2B:
return AllocateWeights<float, ConfigGriffin2B>(pool);
case Model::GEMMA_TINY:
return AllocateWeights<float, ConfigGemmaTiny>(pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
namespace {
template <typename TConfig>
void ZeroInitWeightsT(ByteStorageT& weights, hwy::ThreadPool& pool) {
ZeroInit<float, TConfig>(
*reinterpret_cast<Weights<float, TConfig>*>(weights.get()));
}
} // namespace
void ZeroInitWeights(Model model, ByteStorageT& weights,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
ZeroInitWeightsT<ConfigGemma2B>(weights, pool);
break;
case Model::GEMMA_7B:
ZeroInitWeightsT<ConfigGemma7B>(weights, pool);
break;
case Model::GRIFFIN_2B:
ZeroInitWeightsT<ConfigGriffin2B>(weights, pool);
break;
case Model::GEMMA_TINY:
ZeroInitWeightsT<ConfigGemmaTiny>(weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
namespace {
void LogVec(const char* name, const float* data, size_t len) {
hwy::Stats stats;
for (size_t i = 0; i < len; ++i) {
stats.Notify(data[i]);
}
printf("%-20s %12zu %13.10f %8.5f %13.10f\n",
name, len, stats.Min(), stats.Mean(), stats.Max());
}
class WeightLogger {
public:
template <size_t N>
void operator()(const char* name, const std::array<float, N>& tensor) {
LogVec(name, tensor.data(), N);
total_weights += N;
}
size_t total_weights = 0;
};
template <typename TConfig>
void LogWeightStats(const ByteStorageT& weights_u8) {
const auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightLogger logger;
ForEachTensor1<float, TConfig>(logger, weights);
printf("%-20s %12zu\n", "Total", logger.total_weights);
}
} // namespace
void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
switch (model) {
case Model::GEMMA_2B:
return LogWeightStats<ConfigGemma2B>(weights);
case Model::GEMMA_7B:
return LogWeightStats<ConfigGemma7B>(weights);
case Model::GRIFFIN_2B:
return LogWeightStats<ConfigGriffin2B>(weights);
case Model::GEMMA_TINY:
return LogWeightStats<ConfigGemmaTiny>(weights);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp

290
gemma/weights.h Normal file
View File

@ -0,0 +1,290 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
template <typename T, class TConfig>
struct Layer {
Layer() {}
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumWSize =
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
union {
struct {
std::array<T, kAttVecEinsumWSize> attn_vec_einsum_w;
std::array<T, kQKVEinsumWSize> qkv_einsum_w;
std::array<T, kAOBiasDim> attention_output_biases;
};
struct {
std::array<T, kGriffinDim * kGriffinDim> linear_x_w;
std::array<T, kGriffinDim> linear_x_biases;
std::array<T, kGriffinDim * kGriffinDim> linear_y_w;
std::array<T, kGriffinDim> linear_y_biases;
std::array<T, kGriffinDim * kGriffinDim> linear_out_w;
std::array<T, kGriffinDim> linear_out_biases;
std::array<T, kConv1dWidth * kGriffinDim> conv_w;
std::array<T, kGriffinDim> conv_biases;
std::array<T, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
std::array<T, kGriffinDim * 2> gate_biases;
std::array<T, kGriffinDim> a;
} griffin;
};
std::array<T, kGatingEinsumWSize> gating_einsum_w;
std::array<T, kModelDim * kFFHiddenDim> linear_w;
std::array<T, kModelDim> pre_attention_norm_scale;
std::array<T, kModelDim> pre_ffw_norm_scale;
std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
};
template <class TConfig>
using LayerF = Layer<float, TConfig>;
// Array instead of single large allocation for parallel mem init. Split out of
// Weights so that only these pointers are initialized.
template <typename T, class TConfig>
struct LayerPointers {
explicit LayerPointers(hwy::ThreadPool& pool) {
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
this->layers[task] = hwy::AllocateAligned<Layer<T, TConfig>>(1);
});
}
using TLayer = Layer<T, TConfig>;
std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
};
template <typename T, class TConfig>
struct Weights {
// No ctor/dtor, allocated via AllocateAligned.
std::array<T, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
std::array<T, TConfig::kModelDim> final_norm_scale;
LayerPointers<T, TConfig> layer_ptrs;
std::array<T, TConfig::kNumTensorScales> scales;
const Layer<T, TConfig>* GetLayer(size_t layer) const {
return layer_ptrs.layers[layer].get();
}
Layer<T, TConfig>* GetLayer(size_t layer) {
return layer_ptrs.layers[layer].get();
}
};
template <class TConfig>
using WeightsF = Weights<float, TConfig>;
template <typename T, typename TConfig>
ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
using TWeights = Weights<T, TConfig>;
ByteStorageT weights_u8 = hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
return weights_u8;
}
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
#define GEMMA_CALL_TOP_FUNC2(name, member) \
func(name, weights1.member, weights2.member)
#define GEMMA_CALL_TOP_FUNC3(name, member) \
func(name, weights1.member, weights2.member, weights3.member)
#define GEMMA_CALL_TOP_FUNC4(name, member) \
func(name, weights1.member, weights2.memeber, \
weights3.member, weights4.member)
#define GEMMA_CALL_LAYER_FUNC1(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member)
#define GEMMA_CALL_LAYER_FUNC2(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member)
#define GEMMA_CALL_LAYER_FUNC3(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member, layer3.member)
#define GEMMA_CALL_LAYER_FUNC4(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member, layer4.member)
#define GEMMA_CALL_ALL_LAYER_FUNC(N) \
if (type == LayerAttentionType::kGemma) { \
GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \
GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \
} else { \
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \
GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \
GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \
GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \
GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \
GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a); \
} \
GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \
GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
if (TConfig::kPostNormScale) { \
GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
} \
GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \
if (TConfig::kFFBiases) { \
GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \
GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \
} \
if (TConfig::kSoftmaxAttnOutputBiases && \
type == LayerAttentionType::kGemma) { \
GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
}
template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(1)
}
}
template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(1)
}
}
template <typename T, typename TConfig, class Func>
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
Weights<T, TConfig>& weights2) {
GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(2)
}
}
#undef GEMMA_CALL_TOP_FUNC1
#undef GEMMA_CALL_TOP_FUNC2
#undef GEMMA_CALL_TOP_FUNC3
#undef GEMMA_CALL_TOP_FUNC4
#undef GEMMA_CALL_LAYER_FUNC1
#undef GEMMA_CALL_LAYER_FUNC2
#undef GEMMA_CALL_LAYER_FUNC3
#undef GEMMA_CALL_LAYER_FUNC4
#undef GEMMA_CALL_ALL_LAYER_FUNC
template<typename T, typename TConfig>
void ZeroInit(Weights<T, TConfig>& w) {
hwy::ZeroBytes(&w.embedder_input_embedding,
sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
}
}
template<typename T, typename TConfig>
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
sizeof(src.embedder_input_embedding));
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
sizeof(src.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i)));
}
}
// Owns weights and undoes the type erasure of AllocateWeights.
template<typename T, typename TConfig>
class WeightsWrapper {
public:
WeightsWrapper()
: pool_(0), data_(AllocateWeights<T, TConfig>(pool_)),
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
const Weights<T, TConfig>& get() const { return *weights_; }
Weights<T, TConfig>& get() { return *weights_; }
void clear() { ZeroInit(get()); }
void copy(const WeightsWrapper<T, TConfig>& other) {
Copy(get(), other.get());
}
private:
hwy::ThreadPool pool_;
ByteStorageT data_;
Weights<T, TConfig>* weights_;
};
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool);
void ZeroInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool);
void LogWeightStats(Model model, const ByteStorageT& weights);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_