mirror of https://github.com/google/gemma.cpp.git
Merge pull request #203 from szabadka:backprop5
PiperOrigin-RevId: 640133430
commit 25d9c8ff30
@@ -59,8 +59,12 @@ cc_library(
        "gemma/gemma.cc",
    ],
    hdrs = [
        "gemma/activations.h",
        "gemma/common.h",
        "gemma/common-inl.h",
        "gemma/configs.h",
        "gemma/gemma.h",
        "gemma/weights.h",
    ],
    deps = [
        ":ops",
@@ -46,10 +46,28 @@ set(SOURCES
  compression/sfp.h
  compression/sfp-inl.h
  compression/test_util.h
  backprop/backward.cc
  backprop/backward.h
  backprop/backward-inl.h
  backprop/backward_scalar.h
  backprop/common_scalar.cc
  backprop/common_scalar.h
  backprop/forward.cc
  backprop/forward.h
  backprop/forward-inl.h
  backprop/forward_scalar.h
  backprop/optimizer.cc
  backprop/optimizer.h
  gemma/configs.h
  gemma/activations.cc
  gemma/activations.h
  gemma/common.h
  gemma/common-inl.h
  gemma/gemma.cc
  gemma/gemma.h
  gemma/ops.h
  gemma/weights.cc
  gemma/weights.h
  util/app.h
  util/args.h
)
@@ -100,7 +118,13 @@ target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
if (GEMMA_ENABLE_TESTS)

enable_testing()
include(GoogleTest)

set(GEMMA_TEST_FILES
  backprop/backward_test.cc
  backprop/backward_scalar_test.cc
  backprop/optimize_test.cc
  gemma/ops_test.cc
  gemma/gemma_test.cc
)
@@ -0,0 +1,417 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Implementation of the Vector-Jacobian Products (VJP) of the individual
// operations of the forward pass.

// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_

#include <stddef.h>
#include <stdint.h>

#include <array>
#include <cmath>

#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/weights.h"

#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_

// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#endif

#include "gemma/common-inl.h"
#include "gemma/ops.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

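// Notation for the VJPs below: given a forward op y = f(x, w), the incoming
// cotangent v = dL/dy is turned into dL/dx (written into the backward
// activations) and dL/dw (accumulated into grad_*).
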
template <size_t kCols, size_t kRows>
void MatMulVJP(const std::array<float, kRows * kCols>& weights,
               const float* HWY_RESTRICT x,  // num_tokens * kCols
               const float* HWY_RESTRICT v,  // num_tokens * kRows
               size_t num_tokens,
               std::array<float, kRows * kCols>& grad_w,
               float* HWY_RESTRICT grad_x,  // num_tokens * kCols
               hwy::ThreadPool& pool) {
  hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t voffs = pos * kRows;
    const size_t xoffs = pos * kCols;
    for (size_t j = 0; j < kRows; ++j) {
      MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols);
      MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs],
                       kCols);
    }
  }
}

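// For each token, the forward op above was y = W x with W of shape
// kRows x kCols, so the loop accumulates the outer product
// grad_w[j, :] += v[j] * x and the pullback grad_x += W^T v.
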
template <size_t kHeads, size_t kCols, size_t kRows>
void MultiHeadMatMulVJP(
    const std::array<float, kHeads * kRows * kCols>& weights,
    const float* HWY_RESTRICT x,  // num_tokens * kHeads * kCols
    const float* HWY_RESTRICT v,  // num_tokens * kRows
    size_t num_tokens,
    std::array<float, kHeads * kRows * kCols>& grad_w,
    float* HWY_RESTRICT grad_x,  // num_tokens * kHeads * kCols
    hwy::ThreadPool& pool) {
  hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t j = 0; j < kRows; ++j) {
      for (size_t h = 0; h < kHeads; ++h) {
        MulByConstAndAdd(v[pos * kRows + j],
                         &x[pos * kHeads * kCols + h * kCols],
                         &grad_w[h * kRows * kCols + j * kCols], kCols);
        MulByConstAndAdd(v[pos * kRows + j],
                         &weights[h * kRows * kCols + j * kCols],
                         &grad_x[pos * kHeads * kCols + h * kCols], kCols);
      }
    }
  }
}

template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
  const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
  const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
  const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
  const hn::Vec<D> kOne = hn::Set(d, 1.0f);
  // kSqrt2OverPi * 3 * kMul
  const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);

  const hn::Vec<D> v2 = hn::Mul(v, v);
  const hn::Vec<D> v3 = hn::Mul(v2, v);
  const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
  const hn::Vec<D> tanh = hn::Tanh(d, arg);
  const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
  const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
  const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
  return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
}

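// DGelu above is the derivative of the tanh-based GELU approximation used in
// the forward pass: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))),
// hence dgelu(x) = cdf(x) + 0.5 * x * (1 - tanh^2(arg)) * darg/dx.
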
static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
                                    float* HWY_RESTRICT backward,
                                    const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  using D = hn::ScalableTag<float>;
  const D d;

  const auto offset =
      hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
  hn::Transform1(
      d, backward, size, forward,
      [&offset](const auto d, const auto v, const auto y)
          HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
}

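// SoftmaxVJP above uses the softmax VJP identity: for y = softmax(x) and
// cotangent v, dL/dx_i = y_i * (v_i - sum_j v_j * y_j); the Dot computes the
// shared sum_j term once.
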
static HWY_NOINLINE void RMSNormVJP(
    const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
    const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
    float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
    hwy::ThreadPool& pool) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t offset = pos * model_dim;
    constexpr float eps = 1e-6f;
    float ss = SquaredL2(x + offset, model_dim);
    ss = 1.0f / sqrtf(ss / StaticCast<float>(model_dim) + eps);
    for (size_t i = 0; i < model_dim; ++i) {
      grad_w[i] += v[offset + i] * x[offset + i] * ss;
    }
    const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
    float tmp = 0.0f;
    for (size_t i = 0; i < model_dim; ++i) {
      tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
    }
    tmp *= ss3;
    for (size_t i = 0; i < model_dim; ++i) {
      grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
                           tmp * x[offset + i];
    }
  }
}

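// RMSNormVJP above follows from: with s = 1 / sqrt(mean(x^2) + eps) and
// y_i = (1 + w_i) * s * x_i,
//   dL/dw_i = v_i * s * x_i, and
//   dL/dx_i = s * (1 + w_i) * v_i - (s^3 / N) * x_i * sum_j (1 + w_j) v_j x_j,
// matching the three loops over model_dim.
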
static HWY_NOINLINE void InputEmbeddingVJP(
    const float* weights, const std::vector<int>& prompt,
    const float scaling, const float* HWY_RESTRICT v,
    float* HWY_RESTRICT grad, size_t model_dim) {
  HWY_ASSERT(!prompt.empty());
  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
    int token = prompt[pos];
    MulByConstAndAdd(scaling, v + pos * model_dim,
                     grad + token * model_dim, model_dim);
  }
}

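// Above, each prompt token scatters its activation gradient (times the
// embedding scaling) back into that token's embedding row; the final token is
// skipped because it is only a prediction target, never an input.
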
template <typename TConfig>
void LayerVJP(const Layer<float, TConfig>& weights,
              const ForwardLayer<float, TConfig>& forward,
              const float* HWY_RESTRICT next_layer_grad,
              size_t num_tokens,
              Layer<float, TConfig>& grad,
              ForwardLayer<float, TConfig>& backward,
              hwy::ThreadPool& pool) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kHeads = TConfig::kHeads;
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  static const float kQueryScale =
      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
  HWY_ASSERT(num_tokens <= kSeqLen);

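  // The calls below undo the forward ops of one transformer layer in reverse
  // order: FFW output matmul, gated GELU, gating matmul, pre-FFW RMSNorm,
  // residual add, attention output projection, attention mixing, masked
  // softmax, QK dot products, RoPE / query scaling, and pre-attention RMSNorm.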
  MatMulVJP<kFFHiddenDim, kModelDim>(
      weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad,
      num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(),
      pool);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t hidden_offset = pos * kFFHiddenDim * 2;
    const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
    const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim;
    const float* HWY_RESTRICT b_out_gated =
        backward.ffw_hidden_gated.data() + pos * kFFHiddenDim;
    float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
    float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim;
    namespace hn = hwy::HWY_NAMESPACE;
    using DF = hn::ScalableTag<float>;
    using VF = hn::Vec<DF>;
    DF df;
    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
      const auto y = Load(df, f_out + i);
      const auto x = Load(df, f_out_mul + i);
      const auto v = Load(df, b_out_gated + i);
      hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
      hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
    }
  }

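  // The loop above is the backward of the gated FFW nonlinearity: the forward
  // computed gated = gelu(f_out) * f_out_mul, so d(f_out_mul) = v * gelu(f_out)
  // and d(f_out) = v * f_out_mul * dgelu(f_out).
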
  MatMulVJP<kModelDim, kFFHiddenDim * 2>(
      weights.gating_einsum_w,
      forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
      num_tokens, grad.gating_einsum_w,
      backward.bf_pre_ffw_rms_out.data(), pool);
  RMSNormVJP(weights.pre_ffw_norm_scale.data(),
             forward.attention_out.data(),
             backward.bf_pre_ffw_rms_out.data(),
             kModelDim, num_tokens,
             grad.pre_ffw_norm_scale.data(),
             backward.attention_out.data(), pool);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(next_layer_grad + pos * kModelDim,
            backward.attention_out.data() + pos * kModelDim, kModelDim);
  }

  hwy::ZeroBytes(backward.qkv.data(),
                 num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));

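  // backward.qkv was zeroed above because the attention loops below
  // accumulate K and V gradients across many (head, pos, pos2) combinations.
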
  MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
      weights.attn_vec_einsum_w, forward.att_out.data(),
      backward.attention_out.data(), num_tokens,
      grad.attn_vec_einsum_w, backward.att_out.data(), pool);

  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
      const float* HWY_RESTRICT b_att_out =
          backward.att_out.data() + (pos * kHeads + head) * kQKVDim;
      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
        const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
        float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
        b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim);
        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim);
      }
    }
  }

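  // The loop above is the backward of attention mixing: the forward computed
  // att_out = sum_{pos2 <= pos} att[pos2] * v(pos2), so the attention weights
  // receive b_att[pos2] = <b_att_out, v(pos2)> and each value vector
  // accumulates b_v(pos2) += att[pos2] * b_att_out.
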
  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
      SoftmaxVJP(f_head_att, b_head_att, pos + 1);
    }
  }

  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
      const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
      const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
      const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
      float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
        const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
        float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim);
        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim);
      }
    }
  }

  for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
    float* HWY_RESTRICT b_kv =
        backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
    Rope(b_kv, kQKVDim, -pos);
  }

  for (size_t head = 0; head < kHeads; ++head) {
    for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
      float* HWY_RESTRICT b_q =
          backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
      MulByConst(kQueryScale, b_q, kQKVDim);
      Rope(b_q, kQKVDim, -pos);
    }
  }

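  // The two loops above apply the transposed RoPE: RoPE is a rotation at each
  // position and rotations are orthogonal, so the pullback is obtained by
  // rotating the gradients by -pos.
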
  MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
      weights.qkv_einsum_w, forward.pre_att_rms_out.data(),
      backward.qkv.data(), num_tokens,
      grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool);
  RMSNormVJP(weights.pre_attention_norm_scale.data(),
             forward.input.data(),
             backward.pre_att_rms_out.data(),
             kModelDim, num_tokens,
             grad.pre_attention_norm_scale.data(),
             backward.input.data(), pool);
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(backward.attention_out.data() + pos * kModelDim,
            backward.input.data() + pos * kModelDim, kModelDim);
  }
}

static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
                                    float* HWY_RESTRICT backward,
                                    const float cap,
                                    const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  using D = hn::ScalableTag<float>;
  const D d;

  const auto one = hn::Set(d, 1.0f);
  const auto vcap = hn::Set(d, cap);
  const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);

  hn::Transform1(
      d, backward, size, forward,
      [&](const auto d, const auto v, const auto y) HWY_ATTR {
        const auto scaled = hn::Mul(vinv_cap, y);
        return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
      });
}

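// The forward Softcap computed y = cap * tanh(x / cap); since tanh' = 1 -
// tanh^2, the pullback above is v * (1 - (y / cap)^2), recomputed from the
// saved output y.
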
static HWY_NOINLINE void CrossEntropyLossGrad(
    const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
    const Prompt& prompt, size_t vocab_size) {
  HWY_ASSERT(!prompt.tokens.empty());
  const float scaling = -1.0 / std::log(2.0);
  size_t num_tokens = prompt.tokens.size() - 1;
  hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    if (pos + 1 < prompt.context_size) {
      continue;
    }
    const int next_token = prompt.tokens[pos + 1];
    grad[pos * vocab_size + next_token] =
        scaling / x[pos * vocab_size + next_token];
  }
}

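// The loss above is measured in bits: L = -sum_pos log2(p[next_token]), hence
// dL/dp = -1 / (p * ln 2), i.e. the scaling of -1 / log(2); positions inside
// the given context are excluded from the loss.
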
template <typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                  const Weights<float, TConfig>& weights,
                                  const ForwardPass<float, TConfig>& forward,
                                  Weights<float, TConfig>& grad,
                                  ForwardPass<float, TConfig>& backward,
                                  hwy::ThreadPool& pool) {
  static constexpr size_t kVocabSize = TConfig::kVocabSize;
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kLayers = TConfig::kLayers;
  const float kEmbScaling = EmbeddingScaling<TConfig>();
  static_assert(!TConfig::kAbsolutePE);
  static_assert(!TConfig::kPostNormScale);
  static_assert(TConfig::kKVHeads == 1);

  HWY_DASSERT(prompt.context_size > 0);
  HWY_DASSERT(prompt.context_size < prompt.tokens.size());
  const size_t num_tokens = prompt.tokens.size() - 1;

  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
                       kVocabSize);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
               backward.logits.data() + pos * kVocabSize,
               kVocabSize);
  }

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    SoftcapVJP(forward.logits.data() + pos * kVocabSize,
               backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
  }

  MatMulVJP<kModelDim, kVocabSize>(
      weights.embedder_input_embedding, forward.final_norm_output.data(),
      backward.logits.data(), num_tokens,
      grad.embedder_input_embedding, backward.final_norm_output.data(),
      pool);

  RMSNormVJP(weights.final_norm_scale.data(),
             forward.final_layer_output.data(),
             backward.final_norm_output.data(),
             kModelDim, num_tokens,
             grad.final_norm_scale.data(),
             backward.final_layer_output.data(), pool);

  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
    auto type = TConfig::kLayerConfig[layer];
    // TODO(szabadka) Implement Griffin layer vjp.
    HWY_ASSERT(type == LayerAttentionType::kGemma);
    float* next_layer_grad = static_cast<size_t>(layer) + 1 < kLayers
                                 ? backward.layers[layer + 1].input.data()
                                 : backward.final_layer_output.data();
    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
             num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
  }

  InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
                    kEmbScaling, backward.layers[0].input.data(),
                    grad.embedder_input_embedding.data(), kModelDim);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // NOLINT
@@ -0,0 +1,89 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "backprop/backward.h"

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward.cc"  // NOLINT
#include "hwy/foreach_target.h"  // IWYU pragma: keep

#include "hwy/highway.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "gemma/weights.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                  const ByteStorageT& weights_u8,
                                  const ByteStorageT& forward_u8,
                                  ByteStorageT& grad_u8,
                                  ByteStorageT& backward_u8,
                                  hwy::ThreadPool& pool) {
  using TWeights = WeightsF<TConfig>;
  const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
  auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
  using TAct = ForwardPass<float, TConfig>;
  const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
  auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
  CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool);
}

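// ByteStorageT is type-erased storage for weights and activations; the
// wrapper above casts it back to the TConfig-specific types before calling
// the typed implementation from backward-inl.h.
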
void CrossEntropyLossBackwardPassT(Model model,
                                   const Prompt& prompt,
                                   const ByteStorageT& weights,
                                   const ByteStorageT& forward,
                                   ByteStorageT& grad,
                                   ByteStorageT& backward,
                                   hwy::ThreadPool& pool) {
  switch (model) {
    case Model::GEMMA_2B:
      CrossEntropyLossBackwardPass<ConfigGemma2B>(
          prompt, weights, forward, grad, backward, pool);
      break;
    case Model::GEMMA_TINY:
      CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
          prompt, weights, forward, grad, backward, pool);
      break;
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace gcpp {

HWY_EXPORT(CrossEntropyLossBackwardPassT);

void CrossEntropyLossBackwardPass(
    const Model& model, const Prompt& prompt,
    const ByteStorageT& weights, const ByteStorageT& forward,
    ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) {
  return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
      model, prompt, weights, forward, grad, backward, pool);
}

}  // namespace gcpp
#endif  // HWY_ONCE
@@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_

#include <vector>

#include "backprop/prompt.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

void CrossEntropyLossBackwardPass(
    const Model& model, const Prompt& prompt,
    const ByteStorageT& weights, const ByteStorageT& forward,
    ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
@@ -0,0 +1,353 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_

#include <stddef.h>
#include <string.h>

#include <cmath>
#include <complex>
#include <vector>

#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/weights.h"

namespace gcpp {
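// Scalar reference implementations of the backward pass, used by the tests to
// cross-check the SIMD versions. Dimension convention: w is N x M (row-major),
// x holds K rows of length M, and dy holds K rows of length N, with K the
// number of tokens.
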
template<typename T>
void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
                size_t N, size_t M, size_t K) {
  memset(dx, 0, M * K * sizeof(dx[0]));
  for (size_t i = 0; i < K; ++i) {
    for (size_t j = 0; j < N; ++j) {
      MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M);
      MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M);
    }
  }
}
template<typename T>
void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
                         size_t H, size_t N, size_t M, size_t K) {
  memset(dx, 0, H * M * K * sizeof(dx[0]));
  for (size_t i = 0; i < K; ++i) {
    for (size_t j = 0; j < N; ++j) {
      for (size_t h = 0; h < H; ++h) {
        MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M],
                          &dw[h * N * M + j * M], M);
        MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M],
                          &dx[i * H * M + h * M], M);
      }
    }
  }
}

template<typename T>
void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
                 size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    const size_t offset = i * N;
    constexpr T eps(1e-6);
    T ss = SquaredL2(x + offset, N);
    ss = T(1.0) / std::sqrt(ss / T(N) + eps);
    for (size_t j = 0; j < N; ++j) {
      dw[j] += dy[i * N + j] * x[i * N + j] * ss;
    }
    const T ss3 = ss * ss * ss / T(N);
    T tmp = 0.0;
    for (size_t j = 0; j < N; ++j) {
      tmp += (T(1.0) + w[j]) * dy[i * N + j] * x[i * N + j];
    }
    tmp *= ss3;
    for (size_t j = 0; j < N; ++j) {
      dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i * N + j] - tmp * x[i * N + j];
    }
  }
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N) {
  T sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += y[i] * dy[i];
  }
  for (size_t i = 0; i < N; ++i) {
    dy[i] = y[i] * (dy[i] - sum);
  }
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    SoftmaxVJPT(y + i * N, dy + i * N, N);
  }
}

template<typename T>
T GeluDerivative(T x) {
  static const T kMul = 0.044715;
  static const T kSqrt2OverPi = 0.797884560804236;
  static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;

  const T x2 = x * x;
  const T x3 = x2 * x;
  const T arg = kSqrt2OverPi * (kMul * x3 + x);
  const T tanh = std::tanh(arg);
  const T cdf = T(0.5) * (T(1.0) + tanh);
  const T dtanh = T(1.0) - tanh * tanh;
  const T darg = kMul2 * x2 + kSqrt2OverPi;
  return T(0.5) * x * dtanh * darg + cdf;
}

template<typename T>
void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    const T* x1 = in + i * 2 * N;
    const T* x2 = x1 + N;
    const T* v = d_out + i * N;
    T* dx1 = d_in + i * 2 * N;
    T* dx2 = dx1 + N;
    for (size_t j = 0; j < N; ++j) {
      dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]);
      dx2[j] = v[j] * Gelu(x1[j]);
    }
  }
}

template<typename T>
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
                        size_t num_tokens, size_t kHeads, size_t kQKVDim,
                        size_t kSeqLen) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t offset = pos * (kHeads + 2) * kQKVDim;
    memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0]));
  }
  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
      const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
      const T* q = qkv + qoffs;
      const T* dout = doutput + aoffs;
      T* dq = dqkv + qoffs;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
        const T* k = qkv + koffs;
        T* dk = dqkv + koffs;
        MulByConstAndAddT(dout[pos2], k, dq, kQKVDim);
        MulByConstAndAddT(dout[pos2], q, dk, kQKVDim);
      }
    }
  }
}

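// Above is the backward of the causal attention logits: the forward computed
// att[pos2] = <q(pos), k(pos2)> for pos2 <= pos, so dq accumulates
// dout[pos2] * k(pos2) and dk accumulates dout[pos2] * q(pos).
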
template<typename T>
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens,
                       size_t kHeads, size_t kSeqLen) {
  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
      SoftmaxVJPT(y + offset, dy + offset, pos + 1);
      memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
    }
  }
}

template<typename T>
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
                       T* dqkv, T* dattention, size_t num_tokens,
                       size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
  auto v_offset = [&](size_t pos) {
    return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim;
  };
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0]));
  }
  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim;
      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
      const T* att = &attention[aoffset];
      const T* dout = &doutput[offset];
      T* datt = &dattention[aoffset];
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim);
        MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim);
      }
    }
  }
}

template<typename T>
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
                        const T* dy, T* dw, size_t N) {
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  for (size_t i = 0; i < num_tokens; ++i) {
    int token = tokens[i];
    MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
  }
}

template<typename T, typename TConfig>
void LayerVJP(const Layer<T, TConfig>& weights,
              const ForwardLayer<T, TConfig>& forward,
              const T* dy,
              Layer<T, TConfig>& grad,
              ForwardLayer<T, TConfig>& backward,
              size_t num_tokens) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kHeads = TConfig::kHeads;
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim));

  MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
             dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
             kModelDim, kFFHiddenDim, num_tokens);

  GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
               backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);

  MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
             backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
             backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim,
             num_tokens);

  RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
              backward.bf_pre_ffw_rms_out.data(),
              grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
              kModelDim, num_tokens);

  AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim);

  MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
                      backward.attention_out.data(),
                      grad.attn_vec_einsum_w.data(),
                      backward.att_out.data(),
                      kHeads, kModelDim, kQKVDim, num_tokens);

  MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
                    backward.att_out.data(), backward.qkv.data(),
                    backward.att.data(), num_tokens, kHeads, kQKVDim,
                    kSeqLen);

  MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(),
                    num_tokens, kHeads, kSeqLen);

  MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
                     backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
    MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
  }

  for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
    for (size_t h = 0; h <= kHeads; ++h) {
      Rope(qkv + h * kQKVDim, kQKVDim, -pos);
    }
  }

  MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
             backward.qkv.data(), grad.qkv_einsum_w.data(),
             backward.pre_att_rms_out.data(),
             (kHeads + 2) * kQKVDim, kModelDim, num_tokens);
  RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
              backward.pre_att_rms_out.data(),
              grad.pre_attention_norm_scale.data(),
              backward.input.data(), kModelDim, num_tokens);

  AddFromT(backward.attention_out.data(), backward.input.data(),
           num_tokens * kModelDim);
}

template<typename T>
void SoftcapVJPT(const T* y, T* dy, size_t N) {
  T cap = 30.0;
  T inv_cap = T(1.0) / cap;
  for (size_t i = 0; i < N; ++i) {
    T scaled = y[i] * inv_cap;
    dy[i] *= (T(1.0) - scaled * scaled);
  }
}

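// As in the SIMD SoftcapVJP: forward is y = cap * tanh(x / cap), so
// dy/dx = 1 - (y / cap)^2; the cap of 30.0 matches the value applied to the
// logits in the forward pass.
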
template<typename T>
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
  T scaling = -1.0 / std::log(2.0);
  const std::vector<int> tokens = prompt.tokens;
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  memset(dx, 0, V * num_tokens * sizeof(x[0]));
  for (size_t i = 0; i < num_tokens; ++i) {
    if (i + 1 < prompt.context_size) {
      continue;
    }
    const int next_token = tokens[i + 1];
    dx[i * V + next_token] = scaling / x[i * V + next_token];
  }
}

template<typename T, typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                  const Weights<T, TConfig>& weights,
                                  const ForwardPass<T, TConfig>& forward,
                                  Weights<T, TConfig>& grad,
                                  ForwardPass<T, TConfig>& backward) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kVocabSize = TConfig::kVocabSize;
  static constexpr size_t kLayers = TConfig::kLayers;
  const std::vector<int> tokens = prompt.tokens;
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;

  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
                       kVocabSize);

  SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
              kVocabSize, num_tokens);

  SoftcapVJPT(forward.logits.data(), backward.logits.data(),
              num_tokens * kVocabSize);

  MatMulVJPT(weights.embedder_input_embedding.data(),
             forward.final_norm_output.data(),
             backward.logits.data(),
             grad.embedder_input_embedding.data(),
             backward.final_norm_output.data(),
             kVocabSize, kModelDim, num_tokens);

  RMSNormVJPT(weights.final_norm_scale.data(),
              forward.final_layer_output.data(),
              backward.final_norm_output.data(),
              grad.final_norm_scale.data(),
              backward.final_layer_output.data(), kModelDim, num_tokens);

  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
    T* next_layer_grad = static_cast<size_t>(layer) + 1 < kLayers
                             ? backward.layers[layer + 1].input.data()
                             : backward.final_layer_output.data();
    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
             *grad.GetLayer(layer), backward.layers[layer], num_tokens);
  }

  const T kEmbScaling = EmbeddingScaling(kModelDim);
  InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
                     tokens, kEmbScaling, backward.layers[0].input.data(),
                     grad.embedder_input_embedding.data(), kModelDim);
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
@@ -0,0 +1,610 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "backprop/backward_scalar.h"

#include <stdio.h>
#include <string.h>

#include <array>
#include <complex>
#include <limits>
#include <random>

#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gtest/gtest.h"

namespace gcpp {

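// The tests below check each VJP with what appears to be complex-step
// differentiation: inputs are lifted to std::complex (Complexify) and
// TestGradient compares the analytic gradient against Im(f(x + ih)) / h,
// which avoids the cancellation error of finite differences and permits the
// very tight tolerances used here.
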
TEST(BackPropTest, MatMulVJP) {
  static const size_t kRows = 8;
  static const size_t kCols = 64;
  static const size_t kTokens = 5;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, kRows * kCols> weights;
  std::array<T, kTokens * kCols> x;
  std::array<T, kRows * kCols> grad;
  std::array<T, kTokens * kCols> dx;
  std::array<TC, kRows * kCols> c_weights;
  std::array<TC, kTokens * kCols> c_x;
  std::array<TC, kTokens * kRows> c_y;
  std::array<T, kTokens * kRows> dy;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0 * (1 << iter), gen);
    RandInit(x, 1.0 * (1 << iter), gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    auto func = [&]() {
      MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
      return DotT(dy.data(), c_y.data(), kTokens * kRows);
    };
    memset(&grad, 0, sizeof(grad));
    MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
               kRows, kCols, kTokens);
    TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
    TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
  }
}

TEST(BackPropTest, MultiHeadMatMulVJP) {
  static const size_t kRows = 2;
  static const size_t kCols = 16;
  static const size_t kHeads = 4;
  static const size_t kTokens = 3;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, kRows * kCols * kHeads> weights;
  std::array<T, kTokens * kCols * kHeads> x;
  std::array<T, kRows * kCols * kHeads> grad;
  std::array<T, kTokens * kCols * kHeads> dx;
  std::array<TC, kRows * kCols * kHeads> c_weights;
  std::array<TC, kTokens * kCols * kHeads> c_x;
  std::array<TC, kTokens * kRows> c_y;
  std::array<T, kTokens * kRows> dy;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0 * (1 << iter), gen);
    RandInit(x, 1.0 * (1 << iter), gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    auto func = [&]() {
      MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
                      kCols, kTokens);
      return DotT(dy.data(), c_y.data(), kTokens * kRows);
    };
    memset(&grad, 0, sizeof(grad));
    MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
                        dx.data(), kHeads, kRows, kCols, kTokens);
    TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
    TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
  }
}

TEST(BackPropTest, RMSNormVJP) {
  static const size_t K = 2;
  static const size_t N = 64;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, N> weights;
  std::array<T, N> grad;
  std::array<T, K * N> x;
  std::array<T, K * N> dx;
  std::array<T, K * N> dy;
  std::array<TC, N> c_weights;
  std::array<TC, K * N> c_x;
  std::array<TC, K * N> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0 * (1 << iter), gen);
    RandInit(x, 1.0 * (1 << iter), gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
      return DotT(dy.data(), c_y.data(), K * N);
    };
    memset(&grad, 0, sizeof(grad));
    RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
                N, K);
    TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
    TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
  }
}

TEST(BackPropTest, SoftmaxVJP) {
  static const size_t N = 64;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, N> x;
  std::array<T, N> dx;
  std::array<T, N> dy;
  std::array<TC, N> c_x;
  std::array<TC, N> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(x, 1.0 * (1 << iter), gen);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      memcpy(c_y.data(), c_x.data(), sizeof(c_x));
      Softmax(c_y.data(), N);
      return DotT(dy.data(), c_y.data(), N);
    };
    Softmax(x.data(), N);
    memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
    SoftmaxVJPT(x.data(), dx.data(), N);
    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, MaskedSoftmaxVJP) {
  static const size_t kSeqLen = 16;
  static const size_t kHeads = 2;
  static const size_t kTokens = 14;
  static const size_t N = kHeads * kSeqLen * kSeqLen;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, N> x;
  std::array<T, N> dy;
  std::array<T, N> dx = {};
  std::array<TC, N> c_x;
  std::array<TC, N> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(x, 1.0 * (1 << iter), gen);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      memcpy(c_y.data(), c_x.data(),
             kTokens * kHeads * kSeqLen * sizeof(c_x[0]));
      MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
      return DotT(dy.data(), c_y.data(), N);
    };
    MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
    memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx[0]));
    MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
    TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, SoftcapVJP) {
  static const size_t N = 64;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, N> x;
  std::array<T, N> dx;
  std::array<T, N> dy;
  std::array<TC, N> c_x;
  std::array<TC, N> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(x, 1.0 * (1 << iter), gen);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
      Softcap(c_y.data(), N);
      return DotT(dy.data(), c_y.data(), N);
    };
    Softcap(x.data(), N);
    memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
    SoftcapVJPT(x.data(), dx.data(), N);
    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, CrossEntropyLossGrad) {
  static const size_t K = 8;
  static const size_t V = 64;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, K * V> x;
  std::array<T, K * V> dx;
  std::array<TC, K * V> c_x;
  Prompt prompt;
  prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };

  for (int iter = 0; iter < 10; ++iter) {
    prompt.context_size = 1 + (iter % 6);
    RandInit(x, 1.0 * (1 << iter), gen);
    Softcap(x.data(), V * K);
    Softmax(x.data(), V, K);
    CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
    Complexify(x, c_x);
    auto func = [&]() {
      return CrossEntropyLoss(c_x.data(), prompt, V);
    };
    TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, GatedGeluVJP) {
  static const size_t K = 2;
  static const size_t N = 64;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, K * 2 * N> x;
  std::array<T, K * 2 * N> dx;
  std::array<T, K * N> dy;
  std::array<TC, K * 2 * N> c_x;
  std::array<TC, K * N> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(x, 1.0, gen);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      GatedGelu(c_x.data(), c_y.data(), N, K);
      return DotT(dy.data(), c_y.data(), N * K);
    };
    GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, MaskedAttentionVJP) {
  static const size_t kSeqLen = 16;
  static const size_t kHeads = 2;
  static const size_t kQKVDim = 8;
  static const size_t kTokens = 14;
  static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
  static const size_t kOutSize = kSeqLen * kHeads * kSeqLen;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, kQKVSize> x;
  std::array<T, kQKVSize> dx = {};
  std::array<T, kOutSize> dy;
  std::array<TC, kQKVSize> c_x;
  std::array<TC, kOutSize> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(x, 1.0, gen);
    Complexify(x, c_x);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
                      kSeqLen);
      return DotT(dy.data(), c_y.data(), kOutSize);
    };
    MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
                       kTokens, kHeads, kQKVDim, kSeqLen);
    TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, MixByAttentionVJP) {
  static const size_t kSeqLen = 16;
  static const size_t kHeads = 2;
  static const size_t kQKVDim = 8;
  static const size_t kTokens = 14;
  static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
  static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
  static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, kQKVSize> qkv;
  std::array<T, kQKVSize> dqkv = {};
  std::array<T, kAttnSize> attn;
  std::array<T, kAttnSize> dattn = {};
  std::array<T, kOutSize> dy;
  std::array<TC, kQKVSize> c_qkv;
  std::array<TC, kAttnSize> c_attn;
  std::array<TC, kOutSize> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(qkv, 1.0, gen);
    RandInit(attn, 1.0, gen);
    Complexify(qkv, c_qkv);
    Complexify(attn, c_attn);
    RandInit(dy, 1.0, gen);
    auto func = [&]() {
      MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
                     kTokens, kHeads, kQKVDim, kSeqLen);
      return DotT(dy.data(), c_y.data(), kOutSize);
    };
    MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
                      dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
    TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
    TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
  }
}

TEST(BackPropTest, InputEmbeddingVJP) {
  static const size_t kSeqLen = 8;
  static const size_t kVocabSize = 4;
  static const size_t kModelDim = 16;
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  std::array<T, kVocabSize * kModelDim> weights;
  std::array<T, kVocabSize * kModelDim> grad;
  std::array<T, kSeqLen * kModelDim> dy;
  std::array<TC, kVocabSize * kModelDim> c_weights;
  std::array<TC, kSeqLen * kModelDim> c_y;
  std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
  size_t num_tokens = tokens.size() - 1;

  for (size_t iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0, gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    auto func = [&]() {
      InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
      return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
    };
    memset(&grad, 0, sizeof(grad));
    InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
                       kModelDim);
    TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
  }
}

struct TestConfig {
  static constexpr int kSeqLen = 18;
  static constexpr int kVocabSize = 12;
  static constexpr int kModelDim = 32;
  static constexpr int kHeads = 3;
  static constexpr int kQKVDim = 12;
  static constexpr int kFFHiddenDim = 48;
  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
      FixedLayerConfig<2>(LayerAttentionType::kGemma);
  static constexpr int kLayers = kLayerConfig.size();
  static constexpr bool kAbsolutePE = false;
  static constexpr bool kPostNormScale = false;

  static constexpr int kKVHeads = 1;
  static constexpr int kConv1dWidth = 0;
  static constexpr bool kFFBiases = false;
  static constexpr bool kSoftmaxAttnOutputBiases = false;
  static constexpr int kGemmaLayers = kLayers;
  static constexpr int kGriffinLayers = 0;
  static constexpr int kNumTensorScales = 0;
};

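// TestConfig above is a deliberately tiny two-layer Gemma-style configuration
// so that the complex-step gradient checks over full layers stay affordable.
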
TEST(BackPropTest, LayerVJP) {
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  const size_t kOutputSize = TestConfig::kSeqLen * TestConfig::kModelDim;
  Layer<T, TestConfig> weights;
  Layer<T, TestConfig> grad;
  ForwardLayer<T, TestConfig> forward;
  ForwardLayer<T, TestConfig> backward = {};
  Layer<TC, TestConfig> c_weights;
  ForwardLayer<TC, TestConfig> c_forward;
  std::array<T, kOutputSize> y;
  std::array<T, kOutputSize> dy;
  std::array<TC, kOutputSize> c_y;
  const size_t num_tokens = 3;

  for (size_t iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0, gen);
    RandInit(forward.input, 1.0, gen);
    RandInit(dy, 1.0, gen);
    Complexify(weights, c_weights);
    Complexify(forward.input, c_forward.input);
    auto func = [&]() {
      ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
      return DotT(dy.data(), c_y.data(), num_tokens * TestConfig::kModelDim);
    };
    memset(&grad, 0, sizeof(grad));
    ApplyLayer(weights, forward, num_tokens, y.data());
    LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
    TestGradient(backward.input, c_forward.input, func, 1e-11, 1e-11,
                 __LINE__);
    TestGradient(grad, c_weights, func, 1e-11);
  }
}

TEST(BackPropTest, EndToEnd) {
  std::mt19937 gen(42);
  using T = double;
  using TC = std::complex<T>;
  WeightsWrapper<T, TestConfig> weights;
  WeightsWrapper<T, TestConfig> grad;
  ForwardPass<T, TestConfig> forward;
  ForwardPass<T, TestConfig> backward;
  WeightsWrapper<TC, TestConfig> c_weights;
  ForwardPass<TC, TestConfig> c_forward;

  ReverseSequenceSampler training_task({0, 0, 1, 1});
  std::vector<Prompt> batch = training_task.SampleBatch(10, gen);

  for (const Prompt& prompt : batch) {
    ReverseSequenceSampler::LogPrompt(prompt);
    RandInit(weights.get(), 1.0, gen);
    CrossEntropyLossForwardPass(prompt, weights.get(), forward);
    grad.clear();
    CrossEntropyLossBackwardPass(
        prompt, weights.get(), forward, grad.get(), backward);

    Complexify(weights.get(), c_weights.get());
    auto func = [&]() {
      return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
    };

    TestGradient(grad.get(), c_weights.get(), func, 1e-11);
  }
}

template<typename T, typename TConfig>
void MulByConstAndAddT(T c, const Layer<T, TConfig>& x,
                       Layer<T, TConfig>& out) {
  MulByConstAndAddT(c, x.pre_attention_norm_scale,
                    out.pre_attention_norm_scale);
  MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
  MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
  MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
  MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
  MulByConstAndAddT(c, x.linear_w, out.linear_w);
}

template<typename T, typename TConfig>
void MulByConstAndAddT(T c, const Weights<T, TConfig>& x,
                       Weights<T, TConfig>& out) {
  static constexpr size_t kLayers = TConfig::kLayers;
  MulByConstAndAddT(c, x.embedder_input_embedding,
                    out.embedder_input_embedding);
  MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
  for (size_t i = 0; i < kLayers; ++i) {
    MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
  }
}

// Evaluates forward pass on a batch.
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
                              const WeightsWrapper<T, TConfig>& weights,
                              ForwardPass<T, TConfig>& forward) {
  T loss = 0.0;
  for (const Prompt& prompt : batch) {
    loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
  }
  T scale = 1.0 / batch.size();
  return loss * scale;
}

// Evaluates forward pass on a batch by applying gradient with the given
// learning rate. Does not update weights, but uses the given tmp weights
// instead.
template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(T learning_rate,
                              const std::vector<Prompt>& batch,
                              const WeightsWrapper<T, TConfig>& weights,
                              const WeightsWrapper<T, TConfig>& grad,
                              WeightsWrapper<T, TConfig>& tmp,
                              ForwardPass<T, TConfig>& forward) {
  tmp.copy(weights);
  const T scale = -learning_rate / batch.size();
  MulByConstAndAddT(scale, grad.get(), tmp.get());
  return CrossEntropyLossForwardPass(batch, tmp, forward);
}

// Uses line search in the negative gradient direction to update the weights.
// We do this so that we can test that each step of the gradient descent
// decreases the objective function value.
template<typename T, typename TConfig>
T FindOptimalUpdate(const WeightsWrapper<T, TConfig>& grad,
                    WeightsWrapper<T, TConfig>& weights,
                    WeightsWrapper<T, TConfig>& tmp,
                    ForwardPass<T, TConfig>& forward,
                    const std::vector<Prompt>& batch,
                    T loss, T initial_learning_rate) {
  T lr0 = initial_learning_rate;
  T loss0 = CrossEntropyLossForwardPass(
      lr0, batch, weights, grad, tmp, forward);
  // Shrink the step until it improves on `loss` and halving stops helping.
  for (size_t iter = 0; iter < 30; ++iter) {
    T lr1 = lr0 * 0.5;
    T loss1 = CrossEntropyLossForwardPass(
        lr1, batch, weights, grad, tmp, forward);
    if (loss0 < loss && loss1 >= loss0) {
      break;
    }
    loss0 = loss1;
    lr0 = lr1;
  }
  // Grow the step while doubling it keeps helping.
  for (size_t iter = 0; iter < 30; ++iter) {
    T lr1 = lr0 * 2.0;
    T loss1 = CrossEntropyLossForwardPass(
        lr1, batch, weights, grad, tmp, forward);
    if (loss1 >= loss0) {
      break;
    }
    loss0 = loss1;
    lr0 = lr1;
  }
  const T scale = -lr0 / batch.size();
  MulByConstAndAddT(scale, grad.get(), weights.get());
  return lr0;
}
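
// Editor's sketch (illustrative, not part of the original commit): the
// bracketing search above first halves the learning rate until the smaller
// step stops improving the batch loss, then doubles it while doubling still
// helps. The Convergence test below uses it as one training step:
//
//   learning_rate = FindOptimalUpdate(grad, weights, tmp, forward, batch,
//                                     loss, learning_rate);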
TEST(BackPropTest, Convergence) {
  std::mt19937 gen(42);
  using T = float;
  using TC = std::complex<double>;
  WeightsWrapper<T, TestConfig> weights;
  WeightsWrapper<T, TestConfig> grad;
  WeightsWrapper<T, TestConfig> tmp;
  ForwardPass<T, TestConfig> forward;
  ForwardPass<T, TestConfig> backward;
  WeightsWrapper<TC, TestConfig> c_weights;
  ForwardPass<TC, TestConfig> c_forward;
  constexpr size_t kBatchSize = 5;
  ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
  T learning_rate = 0.01;

  RandInit(weights.get(), T(1.0), gen);

  printf("Sample batch:\n");
  for (size_t i = 0; i < 10; ++i) {
    ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
  }

  T prev_loss = std::numeric_limits<T>::max();
  bool stop = false;
  size_t step = 0;
  while (!stop) {
    T loss = 0.0;
    grad.clear();
    std::mt19937 sgen(42);
    std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
    for (const Prompt& prompt : batch) {
      loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
      CrossEntropyLossBackwardPass(
          prompt, weights.get(), forward, grad.get(), backward);
    }

    if (step % 250 == 0) {
      printf("Checking gradient...\n");
      Complexify(weights.get(), c_weights.get());
      auto func = [&]() {
        TC scale = batch.size();
        return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
      };

      TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
    }

    loss /= batch.size();
    EXPECT_LT(loss, prev_loss);
    stop = step >= 10000 || loss < 1e-2;
    if (step % 10 == 0 || stop) {
      printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
             step, loss, learning_rate);
    }
    if (!stop) {
      learning_rate = FindOptimalUpdate(
          grad, weights, tmp, forward, batch, loss, learning_rate);
      ++step;
    }
    prev_loss = loss;
  }
  EXPECT_LT(step, 1000);
}

}  // namespace gcpp

@ -0,0 +1,268 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif

#include <stddef.h>

#include <algorithm>
#include <array>
#include <complex>
#include <random>
#include <vector>

#include "backprop/backward_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward_test.cc"  // NOLINT
// clang-format on
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "backprop/forward-inl.h"
#include "gemma/ops.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

void TestMatMulVJP() {
  static const size_t kRows = 8;
  static const size_t kCols = 64;
  static const size_t kTokens = 5;
  hwy::ThreadPool pool(8);
  std::mt19937 gen(42);
  HWY_ALIGN std::array<float, kRows * kCols> weights;
  HWY_ALIGN std::array<float, kTokens * kCols> x;
  HWY_ALIGN std::array<float, kTokens * kRows> dy;
  HWY_ALIGN std::array<float, kRows * kCols> grad;
  HWY_ALIGN std::array<float, kTokens * kCols> dx;
  HWY_ALIGN std::array<float, kRows * kCols> grad_scalar;
  HWY_ALIGN std::array<float, kTokens * kCols> dx_scalar;
  using TC = std::complex<double>;
  std::array<TC, kRows * kCols> c_weights;
  std::array<TC, kTokens * kCols> c_x;
  std::array<TC, kTokens * kRows> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0f * (1 << iter), gen);
    RandInit(x, 1.0f * (1 << iter), gen);
    RandInit(dy, 1.0f, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    auto func = [&]() {
      MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
      return DotT(dy.data(), c_y.data(), kTokens * kRows);
    };

    hwy::ZeroBytes(&grad, sizeof(grad));
    MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
                            grad, dx.data(), pool);
    TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
    TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);

    hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
    MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
               dx_scalar.data(), kRows, kCols, kTokens);
    TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
    TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
  }
}

void TestMultiHeadMatMulVJP() {
  static const size_t kRows = 2;
  static const size_t kCols = 16;
  static const size_t kHeads = 4;
  static const size_t kTokens = 3;
  hwy::ThreadPool pool(8);
  std::mt19937 gen(42);
  HWY_ALIGN std::array<float, kRows * kCols * kHeads> weights;
  HWY_ALIGN std::array<float, kTokens * kCols * kHeads> x;
  HWY_ALIGN std::array<float, kRows * kCols * kHeads> grad;
  HWY_ALIGN std::array<float, kTokens * kCols * kHeads> dx;
  HWY_ALIGN std::array<float, kTokens * kRows> dy;
  HWY_ALIGN std::array<float, kRows * kCols * kHeads> grad_scalar;
  HWY_ALIGN std::array<float, kTokens * kCols * kHeads> dx_scalar;
  using TC = std::complex<double>;
  std::array<TC, kRows * kCols * kHeads> c_weights;
  std::array<TC, kTokens * kCols * kHeads> c_x;
  std::array<TC, kTokens * kRows> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0f * (1 << iter), gen);
    RandInit(x, 1.0f * (1 << iter), gen);
    RandInit(dy, 1.0f, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    auto func = [&]() {
      MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
                      kCols, kTokens);
      return DotT(dy.data(), c_y.data(), kTokens * kRows);
    };

    hwy::ZeroBytes(&grad, sizeof(grad));
    MultiHeadMatMulVJP<kHeads, kCols, kRows>(
        weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
    TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
    TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);

    hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
    MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
                        dx_scalar.data(), kHeads, kRows, kCols, kTokens);
    TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
    TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
  }
}

void TestRMSNormVJP() {
  static const size_t K = 2;
  static const size_t N = 64;
  hwy::ThreadPool pool(8);
  std::mt19937 gen(42);
  HWY_ALIGN std::array<float, N> weights;
  HWY_ALIGN std::array<float, K * N> x;
  HWY_ALIGN std::array<float, N> grad;
  HWY_ALIGN std::array<float, K * N> dx;
  HWY_ALIGN std::array<float, K * N> dy;
  HWY_ALIGN std::array<float, N> grad_scalar;
  HWY_ALIGN std::array<float, K * N> dx_scalar;
  using TC = std::complex<double>;
  std::array<TC, N> c_weights;
  std::array<TC, K * N> c_x;
  std::array<TC, K * N> c_y;

  for (int iter = 0; iter < 10; ++iter) {
    RandInit(weights, 1.0f * (1 << iter), gen);
    RandInit(x, 1.0f * (1 << iter), gen);
    RandInit(dy, 1.0f, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    auto func = [&]() {
      RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
      return DotT(dy.data(), c_y.data(), K * N);
    };

    hwy::ZeroBytes(&grad, sizeof(grad));
    RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
               dx.data(), pool);
    TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
    TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);

    hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
    RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
                dx_scalar.data(), N, K);
    TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
    TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
  }
}

struct TestConfig {
  static constexpr int kSeqLen = 24;
  static constexpr int kVocabSize = 16;
  static constexpr int kModelDim = 32;
  static constexpr int kHeads = 3;
  static constexpr int kQKVDim = 16;
  static constexpr int kFFHiddenDim = 64;
  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
      FixedLayerConfig<2>(LayerAttentionType::kGemma);
  static constexpr int kLayers = kLayerConfig.size();
  static constexpr bool kAbsolutePE = false;
  static constexpr bool kPostNormScale = false;

  static constexpr int kKVHeads = 1;
  static constexpr int kConv1dWidth = 0;
  static constexpr bool kFFBiases = false;
  static constexpr bool kSoftmaxAttnOutputBiases = false;
  static constexpr int kGemmaLayers = kLayers;
  static constexpr int kGriffinLayers = 0;
  static constexpr int kNumTensorScales = 0;
};

void TestEndToEnd() {
  std::mt19937 gen(42);
  hwy::ThreadPool pool(0);
  WeightsWrapper<float, TestConfig> weights;
  WeightsWrapper<float, TestConfig> grad;
  ActivationsWrapper<float, TestConfig> forward0;
  ActivationsWrapper<float, TestConfig> forward1;
  ActivationsWrapper<float, TestConfig> backward;
  using TC = std::complex<double>;
  WeightsWrapper<TC, TestConfig> c_weights;
  ForwardPass<TC, TestConfig> c_forward;

  ReverseSequenceSampler training_task({0, 0, 1, 1});
  std::vector<Prompt> batch = training_task.SampleBatch(10, gen);

  for (const Prompt& prompt : batch) {
    ReverseSequenceSampler::LogPrompt(prompt);
    RandInit(weights.get(), 1.0f, gen);

    float loss0 = CrossEntropyLossForwardPass(
        prompt, weights.get(), forward0.get());

    float loss1 = CrossEntropyLossForwardPass<TestConfig, WeightsF, LayerF>(
        prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
        pool);

    EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 1e-5);

    grad.clear();
    CrossEntropyLossBackwardPass(
        prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
        pool);

    Complexify(weights.get(), c_weights.get());
    auto func = [&]() {
      return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
    };

    TestGradient(grad.get(), c_weights.get(), func, 2e-3f);
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE

namespace gcpp {
HWY_BEFORE_TEST(BackwardTest);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
HWY_AFTER_TEST();

}  // namespace gcpp

#endif

@ -0,0 +1,51 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/common_scalar.cc"  // NOLINT
#include "hwy/foreach_target.h"  // IWYU pragma: keep

#include "hwy/highway.h"
// After highway.h
#include "gemma/ops.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

float EmbeddingScaling(int model_dim) {
  // Round to bf16 to match Gemma's Embedder, which casts before mul.
  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
      Sqrt(static_cast<float>(model_dim))));
}
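
// Editor's note (worked example, not part of the original commit): for
// kModelDim = 2048, sqrt(2048) = 45.2548..., which bfloat16 rounds to 45.25.
// Reproducing that rounding here keeps this training path bit-consistent
// with the inference-time Embedder.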

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace gcpp {

HWY_EXPORT(EmbeddingScaling);

float EmbeddingScaling(int model_dim) {
  return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim);
}

}  // namespace gcpp
#endif  // HWY_ONCE

@ -0,0 +1,119 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_

#include <stddef.h>

#include <array>
#include <cmath>
#include <complex>

namespace gcpp {

template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
  U sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += a[i] * b[i];
  }
  return sum;
}

// inline: a full specialization is an ordinary function for linkage purposes
// and would otherwise violate the ODR when this header is included in
// multiple translation units.
template<>
inline std::complex<double> DotT(const float* a, const std::complex<double>* b,
                                 size_t N) {
  std::complex<double> sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += static_cast<double>(a[i]) * b[i];
  }
  return sum;
}
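
// Editor's note (not part of the original commit): the complex overloads in
// this header support gradient checking via complex-step differentiation.
// For an analytic f, f(x + i*h) = f(x) + i*h*f'(x) + O(h^2), so
//   f'(x) ~= Im(f(x + i*h)) / h
// without the cancellation error of finite differences. The float/complex
// specialization above lets a real dy vector probe a complexified output.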

template<typename T>
void MulByConstT(T c, T* x, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    x[i] *= c;
  }
}

// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    out[i] += c * x[i];
  }
}

template<typename T, size_t N>
void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
  MulByConstAndAddT(c, x.data(), out.data(), N);
}

template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    out[i] += a[i];
  }
}

template<typename T>
T SquaredL2(const T* x, size_t N) {
  T sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += x[i] * x[i];
  }
  return sum;
}

template<typename T>
T Gelu(T x) {
  static const T kMul = 0.044715;
  static const T kSqrt2OverPi = 0.797884560804236;

  const T x3 = x * x * x;
  const T arg = kSqrt2OverPi * (kMul * x3 + x);
  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
  return x * cdf;
}
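
// Editor's note (not part of the original commit): this is the tanh
// approximation of GELU,
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with kSqrt2OverPi = sqrt(2/pi); the exact form is x * Phi(x), where Phi
// is the standard normal CDF.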

template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
  const size_t N2 = N / 2;
  for (size_t dim = 0; dim < N2; ++dim) {
    const T freq_exponents = T(2 * dim) / T(N);
    const T timescale = std::pow(base, freq_exponents);
    const T theta = T(i) / timescale;
    const T cos_val = std::cos(theta);
    const T sin_val = std::sin(theta);
    const T x0 = x[dim];
    const T x1 = x[dim + N2];
    x[dim] = x0 * cos_val - x1 * sin_val;
    x[dim + N2] = x0 * sin_val + x1 * cos_val;
  }
}
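
// Editor's note (not part of the original commit): RoPE treats the pair
// (x[dim], x[dim + N/2]) as a point in the plane and rotates it by
// theta = i / base^(2*dim / N), i.e. it multiplies the complex number
// x[dim] + i*x[dim + N/2] by e^(i*theta). Relative token positions then
// appear as phase differences in the query-key dot products.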

template<typename T>
void Rope(T* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}

template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}

float EmbeddingScaling(int model_dim);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_

@ -0,0 +1,292 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_

#include <stddef.h>
#include <stdint.h>

#include <array>
#include <cmath>

#include "gemma/activations.h"
#include "gemma/configs.h"

#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_

// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
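
// Editor's note (not part of the original commit): "foreach_target.h"
// re-includes the translation unit once per SIMD target, so this header can
// be included many times. Highway flips HWY_TARGET_TOGGLE between passes;
// mirroring it with the toggle above compiles the code below exactly once
// per target instead of once overall.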

#include "gemma/common-inl.h"
#include "gemma/ops.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
                    const float scaling, float* HWY_RESTRICT output,
                    size_t model_dim) {
  HWY_ASSERT(!prompt.empty());
  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
    int token = prompt[pos];
    Decompress(weights, token * model_dim, output + pos * model_dim, model_dim);
    MulByConst(scaling, output + pos * model_dim, model_dim);
  }
}

template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
                  size_t model_dim, size_t num_tokens,
                  OutT* HWY_RESTRICT output,
                  hwy::ThreadPool& pool) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t offset = pos * model_dim;
    RMSNorm(x + offset, weights, output + offset, model_dim);
  }
}

static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
                                           const std::vector<int>& prompt,
                                           size_t context_size,
                                           size_t vocab_size,
                                           hwy::ThreadPool& pool) {
  HWY_ASSERT(!prompt.empty());
  float loss = 0.0f;
  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
    if (pos + 1 < context_size) {
      continue;  // next token is part of context, don't try to predict it
    }
    const int next_token = prompt[pos + 1];
    loss += std::log(probs[pos * vocab_size + next_token]);
  }
  // Negate and convert from nats to bits (divide by log(2)).
  float scaling = -1.0 / std::log(2.0);
  return loss * scaling;
}

template <typename TConfig, template<typename> typename LayerT>
void ApplyForwardLayer(const LayerT<TConfig>& weights,
                       ForwardLayer<float, TConfig>& activations,
                       size_t num_tokens,
                       float* HWY_RESTRICT output,
                       hwy::ThreadPool& pool) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kHeads = TConfig::kHeads;
  static const float kQueryScale =
      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
  HWY_ASSERT(num_tokens <= kSeqLen);

  ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
               activations.input.data(), kModelDim, num_tokens,
               activations.pre_att_rms_out.data(), pool);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
        weights.qkv_einsum_w, 0,
        activations.pre_att_rms_out.data() + pos * kModelDim, nullptr,
        activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
  }
  const size_t num_tasks = kHeads * num_tokens;

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    float* HWY_RESTRICT k =
        activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
    Rope(k, kQKVDim, pos);
  }
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT q =
        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
    Rope(q, kQKVDim, pos);
    MulByConst(kQueryScale, q, kQKVDim);
  });

  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT q =
        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
    float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      const float* HWY_RESTRICT k2 =
          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
      const float score = Dot(q, k2, kQKVDim);
      head_att[pos2] = score;
    }
  });

  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    Softmax(head_att, pos + 1);
  });

  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    float* HWY_RESTRICT att_out =
        activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      float* HWY_RESTRICT v2 =
          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
    }
  });

  hwy::ZeroBytes(activations.attention_out.data(),
                 num_tokens * kModelDim * sizeof(activations.attention_out[0]));
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      MatVec<kModelDim, kQKVDim>(
          weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
          activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
          nullptr, activations.att_post1.data() + pos * kModelDim, pool);
      AddFrom(activations.att_post1.data() + pos * kModelDim,
              activations.attention_out.data() + pos * kModelDim, kModelDim);
    }
  }

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.input.data() + pos * kModelDim,
            activations.attention_out.data() + pos * kModelDim, kModelDim);
  }

  ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
               activations.attention_out.data(), kModelDim, num_tokens,
               activations.bf_pre_ffw_rms_out.data(), pool);
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec<kFFHiddenDim * 2, kModelDim>(
        weights.gating_einsum_w, 0,
        activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr,
        activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
  }
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t hidden_offset = pos * kFFHiddenDim * 2;
    const float* HWY_RESTRICT out =
        activations.ffw_hidden.data() + hidden_offset;
    const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
    float* HWY_RESTRICT out_gated =
        activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
    namespace hn = hwy::HWY_NAMESPACE;
    using DF = hn::ScalableTag<float>;
    using VF = hn::Vec<DF>;
    DF df;
    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
      const auto y = Load(df, out + i);
      const auto x = Load(df, out_mul + i);
      hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
    }
  }
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec<kModelDim, kFFHiddenDim>(
        weights.linear_w, 0,
        activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
        nullptr, output + pos * kModelDim, pool);
  }
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.attention_out.data() + pos * kModelDim,
            output + pos * kModelDim, kModelDim);
  }
}
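
// Editor's summary (not part of the original commit): ApplyForwardLayer is a
// pre-norm transformer block,
//   x = x + MultiHeadAttention(RMSNorm(x))
//   x = x + Linear(GatedGelu(RMSNorm(x))),
// where activations.attention_out holds the first residual sum and `output`
// receives the second.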

template <typename TConfig, template<typename> typename WeightsT,
          template<typename> typename LayerT>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
                                  size_t context_size,
                                  const WeightsT<TConfig>& weights,
                                  ForwardPass<float, TConfig>& forward,
                                  hwy::ThreadPool& pool) {
  static constexpr size_t kVocabSize = TConfig::kVocabSize;
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kLayers = TConfig::kLayers;
  const float kEmbScaling = EmbeddingScaling<TConfig>();
  static_assert(!TConfig::kAbsolutePE);
  static_assert(!TConfig::kPostNormScale);
  static_assert(TConfig::kKVHeads == 1);

  HWY_DASSERT(context_size > 0);
  HWY_DASSERT(context_size < prompt.size());
  const size_t num_tokens = prompt.size() - 1;

  InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling,
                 forward.layers[0].input.data(), kModelDim);

  for (size_t layer = 0; layer < kLayers; ++layer) {
    auto type = TConfig::kLayerConfig[layer];
    // TODO(szabadka) Implement Griffin layer.
    HWY_ASSERT(type == LayerAttentionType::kGemma);
    float* HWY_RESTRICT output = layer + 1 < kLayers ?
                                 forward.layers[layer + 1].input.data() :
                                 forward.final_layer_output.data();
    ApplyForwardLayer<TConfig, LayerT>(
        *weights.GetLayer(layer), forward.layers[layer],
        num_tokens, output, pool);
  }

  ApplyRMSNorm(weights.final_norm_scale.data(),
               forward.final_layer_output.data(),
               kModelDim, num_tokens, forward.final_norm_output.data(), pool);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec<kVocabSize, kModelDim>(
        weights.embedder_input_embedding, 0,
        forward.final_norm_output.data() + pos * kModelDim, nullptr,
        forward.logits.data() + pos * kVocabSize, pool);
  }

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize);
  }

  hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
                 num_tokens * kVocabSize * sizeof(forward.logits[0]));

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);
  }

  return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
                          kVocabSize, pool);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // NOLINT

@ -0,0 +1,80 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "backprop/forward.h"

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/forward.cc"  // NOLINT
#include "hwy/foreach_target.h"  // IWYU pragma: keep

#include "hwy/highway.h"
// After highway.h
#include "backprop/forward-inl.h"
#include "gemma/weights.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <typename TConfig>
float CrossEntropyLossForwardPass(const Prompt& prompt,
                                  const ByteStorageT& weights_u8,
                                  ByteStorageT& forward_u8,
                                  hwy::ThreadPool& pool) {
  const auto& weights =
      *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
  auto& forward =
      *reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
  return CrossEntropyLossForwardPass<TConfig, WeightsF, LayerF>(
      prompt.tokens, prompt.context_size, weights, forward, pool);
}

float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
                                   const ByteStorageT& weights,
                                   ByteStorageT& forward,
                                   hwy::ThreadPool& pool) {
  switch (model) {
    case Model::GEMMA_2B:
      return CrossEntropyLossForwardPass<ConfigGemma2B>(
          prompt, weights, forward, pool);
    case Model::GEMMA_TINY:
      return CrossEntropyLossForwardPass<ConfigGemmaTiny>(
          prompt, weights, forward, pool);
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace gcpp {

HWY_EXPORT(CrossEntropyLossForwardPassT);

float CrossEntropyLossForwardPass(
    const Model& model, const Prompt& prompt, const ByteStorageT& weights,
    ByteStorageT& forward, hwy::ThreadPool& pool) {
  return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
      model, prompt, weights, forward, pool);
}

}  // namespace gcpp
#endif  // HWY_ONCE

@ -0,0 +1,33 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

#include <vector>

#include "backprop/prompt.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

float CrossEntropyLossForwardPass(
    const Model& model, const Prompt& prompt, const ByteStorageT& weights,
    ByteStorageT& forward, hwy::ThreadPool& pool);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

@ -0,0 +1,288 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_

#include <stddef.h>
#include <string.h>

#include <cmath>
#include <complex>
#include <vector>

#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/weights.h"

namespace gcpp {

// w is an N x M matrix in row-major order, x is an M x K matrix in
// column-major order, and y = w * x is an N x K matrix in column-major order.
template<typename T>
void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    for (size_t j = 0; j < N; ++j) {
      y[i * N + j] = DotT(&w[j * M], &x[i * M], M);
    }
  }
}
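
// Editor's note (worked example, not part of the original commit): with
// N = 2, M = 3, K = 1, w stores the rows [w00 w01 w02][w10 w11 w12]
// contiguously and x stores one column [x0 x1 x2], so
//   y[0] = w00*x0 + w01*x1 + w02*x2,  y[1] = w10*x0 + w11*x1 + w12*x2.
// Since x and y are column-major, each token occupies one contiguous column.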

// w is H concatenated N x M matrices in row-major order, x is an HM x K
// matrix in column-major order and y = w' * x is an N x K matrix in
// column-major order, where w' is the rearrangement of w into an N x HM
// matrix.
template<typename T>
void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
                     size_t M, size_t K) {
  memset(y, 0, N * K * sizeof(y[0]));
  for (size_t i = 0; i < K; ++i) {
    for (size_t h = 0; h < H; ++h) {
      for (size_t j = 0; j < N; ++j) {
        y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M);
      }
    }
  }
}

template<typename T>
void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
  constexpr T eps(1e-6);
  for (size_t i = 0; i < K; ++i) {
    T ss = SquaredL2(x + i * N, N);
    ss = T(1.0) / std::sqrt(ss / T(N) + eps);
    for (size_t j = 0; j < N; j++) {
      out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]);
    }
  }
}
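
// Editor's note (not part of the original commit): for each of the K rows,
// RMSNormT computes
//   out[j] = (1 + w[j]) * x[j] / sqrt(mean(x^2) + eps),
// i.e. RMSNorm with the learned scale stored as an offset from 1, the Gemma
// convention for norm weights.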

template<typename T>
void Softmax(T* x, size_t N) {
  T sum = {};
  auto maxreal = std::real(x[0]);
  for (size_t i = 1; i < N; ++i) {
    if (std::real(x[i]) > maxreal) {
      maxreal = std::real(x[i]);
    }
  }
  for (size_t i = 0; i < N; ++i) {
    x[i] = std::exp(x[i] - maxreal);
    sum += x[i];
  }
  T scale = T(1.0) / sum;
  for (size_t i = 0; i < N; ++i) {
    x[i] *= scale;
  }
}

template<typename T>
void Softmax(T* x, size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    Softmax(x + i * N, N);
  }
}

template<typename T>
void Softcap(T* x, size_t N) {
  T cap = 30.0;
  T inv_cap = T(1.0) / cap;
  for (size_t i = 0; i < N; ++i) {
    x[i] = cap * std::tanh(x[i] * inv_cap);
  }
}
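
// Editor's note (not part of the original commit): Softcap squashes logits
// smoothly into (-30, 30). It is near-identity for small inputs and
// saturates for large ones, e.g. x = 30 maps to 30 * tanh(1) ~= 22.85. This
// mirrors the LogitsSoftCap(30.0f, ...) call in the SIMD forward pass.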

template<typename T>
void GatedGelu(const T* in, T* out, size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    const T* x1 = in + i * 2 * N;
    const T* x2 = x1 + N;
    T* y = out + i * N;
    for (size_t j = 0; j < N; ++j) {
      y[j] = x2[j] * Gelu(x1[j]);
    }
  }
}

template<typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
                    T* y, size_t N) {
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  for (size_t i = 0; i < num_tokens; ++i) {
    int token = tokens[i];
    memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
    MulByConstT(scaling, y + i * N, N);
  }
}

template<typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens,
                     size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      const size_t qoffset = pos * (kHeads + 2) * kQKVDim;
      const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen;
      const T* q = qkv + qoffset + head * kQKVDim;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
        output[aoffset + pos2] = DotT(q, k, kQKVDim);
      }
    }
  }
}
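
// Editor's note (not part of the original commit): the qkv buffer packs, per
// token, kHeads query vectors followed by one shared key and one shared
// value vector, (kHeads + 2) * kQKVDim values in total. This is multi-query
// attention (kKVHeads == 1): every head attends with the same K and V. The
// pos2 <= pos loop bound implements the causal mask.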

template<typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
      Softmax(x + offset, pos + 1);
      memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
    }
  }
}

template<typename T>
void MixByAttention(const T* qkv, const T* attention, T* output,
                    size_t num_tokens, size_t kHeads, size_t kQKVDim,
                    size_t kSeqLen) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen];
      T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim];
      memset(out, 0, kQKVDim * sizeof(out[0]));
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
        const T* v = &qkv[v_offset];
        MulByConstAndAddT(att[pos2], v, out, kQKVDim);
      }
    }
  }
}

template<typename T, typename TConfig>
void ApplyLayer(const Layer<T, TConfig>& weights,
                ForwardLayer<T, TConfig>& activations,
                size_t num_tokens, T* output) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kHeads = TConfig::kHeads;
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim));

  RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
           activations.pre_att_rms_out.data(), kModelDim, num_tokens);

  MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
          activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim,
          num_tokens);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
    // h <= kHeads applies Rope to all queries and the key, but not the value.
    for (size_t h = 0; h <= kHeads; ++h) {
      Rope(qkv + h * kQKVDim, kQKVDim, pos);
    }
  }

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
    MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
  }

  MaskedAttention(activations.qkv.data(), activations.att.data(),
                  num_tokens, kHeads, kQKVDim, kSeqLen);

  MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen);

  MixByAttention(activations.qkv.data(), activations.att.data(),
                 activations.att_out.data(), num_tokens, kHeads, kQKVDim,
                 kSeqLen);

  MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
                  activations.attention_out.data(), kHeads, kModelDim, kQKVDim,
                  num_tokens);

  AddFromT(activations.input.data(), activations.attention_out.data(),
           num_tokens * kModelDim);

  RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
           activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens);

  MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
          activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim,
          num_tokens);

  GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
            kFFHiddenDim, num_tokens);

  MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(),
          output, kModelDim, kFFHiddenDim, num_tokens);

  AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim);
}

template<typename T>
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
  T loss = {};
  const std::vector<int> tokens = prompt.tokens;
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  for (size_t i = 0; i < num_tokens; ++i) {
    if (i + 1 < prompt.context_size) {
      continue;  // next token is part of context, don't try to predict it
    }
    const int next_token = tokens[i + 1];
    loss += std::log(x[i * V + next_token]);
  }
  // Negate and convert from nats to bits (divide by log(2)).
  T scaling = -1.0 / std::log(2.0);
  return loss * scaling;
}

template<typename T, typename TConfig>
T CrossEntropyLossForwardPass(const Prompt& prompt,
                              const Weights<T, TConfig>& weights,
                              ForwardPass<T, TConfig>& forward) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kVocabSize = TConfig::kVocabSize;
  static constexpr size_t kLayers = TConfig::kLayers;
  const std::vector<int> tokens = prompt.tokens;
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;

  const T kEmbScaling = EmbeddingScaling(kModelDim);
  InputEmbedding(weights.embedder_input_embedding.data(), tokens,
                 kEmbScaling, forward.layers[0].input.data(), kModelDim);

  for (size_t layer = 0; layer < kLayers; ++layer) {
    T* output = layer + 1 < kLayers ?
                forward.layers[layer + 1].input.data() :
                forward.final_layer_output.data();
    ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
               output);
  }

  RMSNormT(weights.final_norm_scale.data(),
           forward.final_layer_output.data(),
           forward.final_norm_output.data(), kModelDim, num_tokens);

  MatMulT(weights.embedder_input_embedding.data(),
          forward.final_norm_output.data(),
          forward.logits.data(), kVocabSize, kModelDim, num_tokens);

  Softcap(forward.logits.data(), num_tokens * kVocabSize);

  memcpy(forward.probs.data(), forward.logits.data(),
         num_tokens * kVocabSize * sizeof(forward.logits[0]));
  Softmax(forward.probs.data(), kVocabSize, num_tokens);

  return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize);
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_

@ -0,0 +1,127 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <string>

#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
#include "backprop/sampler.h"
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "gtest/gtest.h"

namespace gcpp {

TEST(OptimizeTest, GradientDescent) {
  hwy::ThreadPool pool(0);
  std::mt19937 gen(42);

  Model model_type = Model::GEMMA_TINY;
  ByteStorageT weights = AllocateWeights(model_type, pool);
  ByteStorageT grad = AllocateWeights(model_type, pool);
  ByteStorageT grad_m = AllocateWeights(model_type, pool);
  ByteStorageT grad_v = AllocateWeights(model_type, pool);
  ByteStorageT forward = AllocateForwardPass(model_type);
  ByteStorageT backward = AllocateForwardPass(model_type);
  ByteStorageT inference = AllocateInferenceState(model_type);
  auto kv_cache = CreateKVCache(model_type);
  size_t max_tokens = 32;
  size_t max_generated_tokens = 16;
  float temperature = 1.0f;
  int verbosity = 0;
  const auto accept_token = [](int) { return true; };

  const auto generate = [&](const std::vector<int>& prompt) {
    std::vector<int> reply;
    auto stream_token = [&reply](int token, float) {
      reply.push_back(token);
      return token != ReverseSequenceSampler::kEndToken;
    };
    RuntimeConfig runtime = {
        max_tokens, max_generated_tokens, temperature, verbosity, &gen,
        stream_token, accept_token, ReverseSequenceSampler::kEndToken,
    };
    TimingInfo timing_info;
    GenerateGemma(model_type, weights, inference, runtime, prompt, 0,
                  kv_cache, pool, timing_info);
    return reply;
  };

  auto verify = [&](const Prompt& prompt) {
    auto context = prompt.context();
    std::vector<int> reply = generate(context);
    bool ok = true;
    for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) {
      if (i >= reply.size() || reply[i] != prompt.tokens[i]) {
        ok = false;
      }
    }
    return ok;
  };

  RandInitWeights(model_type, weights, pool, gen);
  ZeroInitWeights(model_type, grad_m, pool);
  ZeroInitWeights(model_type, grad_v, pool);

  printf("Initial weights:\n");
  LogWeightStats(model_type, weights);

  constexpr size_t kBatchSize = 8;
  float learning_rate = 0.0005f;

  ReverseSequenceSampler training_task({
      0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
  size_t steps = 0;
  float prev_loss = std::numeric_limits<float>::max();
  size_t num_ok = 0;
  for (; steps < 1000000; ++steps) {
    std::mt19937 sgen(42);
    ZeroInitWeights(model_type, grad, pool);
    float total_loss = 0.0f;
    num_ok = 0;
    for (size_t i = 0; i < kBatchSize; ++i) {
      Prompt prompt = training_task.Sample(sgen);
      total_loss += CrossEntropyLossForwardPass(
          model_type, prompt, weights, forward, pool);
      CrossEntropyLossBackwardPass(
          model_type, prompt, weights, forward, grad, backward, pool);
      num_ok += verify(prompt) ? 1 : 0;
    }
    total_loss /= kBatchSize;

    const float scale = -learning_rate / kBatchSize;
    UpdateWeights(model_type, grad, scale, weights, pool);
    printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
           steps, total_loss, num_ok, kBatchSize);
    if (steps % 100 == 0) {
      printf("Batch gradient:\n");
      LogWeightStats(model_type, grad);
    }
    if (total_loss < 0.5f) {
      break;
    }
    prev_loss = total_loss;
  }
  printf("Num steps: %zu\n", steps);
  printf("Final weights:\n");
  LogWeightStats(model_type, weights);
  EXPECT_LT(steps, 3000);
  EXPECT_EQ(num_ok, kBatchSize);
}

}  // namespace gcpp

@ -0,0 +1,121 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "backprop/optimizer.h"

#include <random>

#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

namespace {
class WeightInitializer {
 public:
  WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}

  template <size_t N>
  void operator()(const char* name, std::array<float, N>& tensor) {
    for (size_t i = 0; i < N; ++i) {
      tensor[i] = dist_(gen_);
    }
  }

 private:
  std::normal_distribution<float> dist_;
  std::mt19937& gen_;
};

template <typename TConfig>
void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
                     std::mt19937& gen) {
  auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
  // TODO(szabadka) Use the same weight initialization method as in the python
  // version.
  WeightInitializer init(gen);
  ForEachTensor1<float, TConfig>(init, weights);
}

class WeightUpdater {
 public:
  explicit WeightUpdater(float lr) : lr_(lr) {}

  template <size_t kCapacity>
  void operator()(const char* name, const std::array<float, kCapacity>& grad,
                  std::array<float, kCapacity>& weights) {
    for (size_t i = 0; i < kCapacity; ++i) {
      weights[i] += lr_ * grad[i];
    }
  }

 private:
  float lr_;
};

template <typename TConfig>
void UpdateWeights(const ByteStorageT& grad_u8, float scale,
                   ByteStorageT& weights_u8, hwy::ThreadPool& pool) {
  const auto& grad =
      *reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
  auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
  WeightUpdater updater(scale);
  ForEachTensor2<float, TConfig>(updater, grad, weights);
}

}  // namespace

void RandInitWeights(Model model, ByteStorageT& weights_u8,
                     hwy::ThreadPool& pool, std::mt19937& gen) {
  switch (model) {
    case Model::GEMMA_2B:
      RandInitWeights<ConfigGemma2B>(weights_u8, pool, gen);
      break;
    case Model::GEMMA_7B:
      RandInitWeights<ConfigGemma7B>(weights_u8, pool, gen);
      break;
    case Model::GRIFFIN_2B:
      RandInitWeights<ConfigGriffin2B>(weights_u8, pool, gen);
      break;
    case Model::GEMMA_TINY:
      RandInitWeights<ConfigGemmaTiny>(weights_u8, pool, gen);
      break;
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
                   ByteStorageT& weights, hwy::ThreadPool& pool) {
  switch (model) {
    case Model::GEMMA_2B:
      UpdateWeights<ConfigGemma2B>(grad, scale, weights, pool);
      break;
    case Model::GEMMA_7B:
      UpdateWeights<ConfigGemma7B>(grad, scale, weights, pool);
      break;
    case Model::GRIFFIN_2B:
      UpdateWeights<ConfigGriffin2B>(grad, scale, weights, pool);
      break;
    case Model::GEMMA_TINY:
      UpdateWeights<ConfigGemmaTiny>(grad, scale, weights, pool);
      break;
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

}  // namespace gcpp

@ -0,0 +1,34 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool,
|
||||
std::mt19937& gen);
|
||||
|
||||
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
|
||||
ByteStorageT& weights, hwy::ThreadPool& pool);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
|
|
@@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_

#include <stddef.h>

#include <vector>

namespace gcpp {

struct Prompt {
  std::vector<int> tokens;
  size_t context_size = 0;

  // Returns the first `context_size` tokens of the prompt.
  std::vector<int> context() const {
    return std::vector<int>(tokens.begin(), tokens.begin() + context_size);
  }
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
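For concreteness, a small sketch of how `Prompt` splits into context and
continuation; the token values are illustrative (they use the toy vocabulary of
the sampler below, where 10 is "-->" and 11 is "|"):

#include <cstdio>

#include "backprop/prompt.h"

int main() {
  gcpp::Prompt prompt;
  prompt.tokens = {3, 1, 4, 10, 4, 1, 3, 11};  // "314-->413|" in the toy vocab
  prompt.context_size = 4;                     // everything up to and incl. "-->"
  for (int t : prompt.context()) printf("%d ", t);  // prints: 3 1 4 10
  printf("\n");
  return 0;
}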
@@ -0,0 +1,84 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_

#include <stddef.h>
#include <stdio.h>

#include <random>
#include <vector>

#include "backprop/prompt.h"

namespace gcpp {

class PromptSampler {
 public:
  virtual ~PromptSampler() = default;

  virtual Prompt Sample(std::mt19937& gen) = 0;

  std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
    std::vector<Prompt> batch;
    batch.reserve(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
      batch.emplace_back(Sample(gen));
    }
    return batch;
  }
};

// Generates prompts of the form "<digits>--><digits reversed>|", e.g.
// "314-->413|". Prompt lengths are drawn from the given histogram of counts.
class ReverseSequenceSampler : public PromptSampler {
 public:
  explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
      : token_dist_(0, 9) {
    for (size_t i = 0; i < length_histo.size(); ++i) {
      const int count = length_histo[i];
      for (int j = 0; j < count; ++j) {
        length_lut_.push_back(static_cast<int>(i) + 1);
      }
    }
    length_dist_ = std::uniform_int_distribution<>(
        0, static_cast<int>(length_lut_.size()) - 1);
  }

  static constexpr int kReverseToken = 10;
  static constexpr int kEndToken = 11;

  Prompt Sample(std::mt19937& gen) override {
    Prompt prompt;
    const size_t len = length_lut_[length_dist_(gen)];
    prompt.tokens.resize(2 * len + 2);
    prompt.tokens[len] = kReverseToken;
    prompt.tokens[2 * len + 1] = kEndToken;
    for (size_t i = 0; i < len; ++i) {
      prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
    }
    prompt.context_size = len + 1;
    return prompt;
  }

  static void LogPrompt(const Prompt& prompt) {
    static const char* kVocab[] = {
        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
    };
    for (int token : prompt.tokens) printf("%s", kVocab[token]);
    printf(" [context_size: %zu]\n", prompt.context_size);
  }

 private:
  std::uniform_int_distribution<> token_dist_;
  std::uniform_int_distribution<> length_dist_;
  std::vector<int> length_lut_;
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
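Usage sketch: the histogram {0, 1, 2} puts length 2 once and length 3 twice
into the lookup table, so sampled digit runs have length 2 or 3 (the include
path assumes the file lives in backprop/, next to prompt.h):

#include <random>

#include "backprop/sampler.h"

int main() {
  std::mt19937 gen(1234);
  gcpp::ReverseSequenceSampler sampler({0, 1, 2});
  for (const gcpp::Prompt& prompt : sampler.SampleBatch(4, gen)) {
    // Prints e.g. "41-->14| [context_size: 3]".
    gcpp::ReverseSequenceSampler::LogPrompt(prompt);
  }
  return 0;
}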
@@ -0,0 +1,195 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_

#include <algorithm>
#include <array>
#include <cmath>
#include <complex>
#include <random>

#include "gemma/weights.h"
#include "gtest/gtest.h"

namespace gcpp {

template <typename T, size_t kLen>
void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
  std::normal_distribution<T> dist(0.0, stddev);
  for (size_t i = 0; i < kLen; ++i) {
    x[i] = dist(gen);
  }
}

template <typename T, typename TConfig>
void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
  RandInit(w.pre_attention_norm_scale, stddev, gen);
  RandInit(w.attn_vec_einsum_w, stddev, gen);
  RandInit(w.qkv_einsum_w, stddev, gen);
  RandInit(w.pre_ffw_norm_scale, stddev, gen);
  RandInit(w.gating_einsum_w, stddev, gen);
  RandInit(w.linear_w, stddev, gen);
}

template <typename T, typename TConfig>
void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
  static constexpr size_t kLayers = TConfig::kLayers;
  RandInit(w.embedder_input_embedding, stddev, gen);
  RandInit(w.final_norm_scale, stddev, gen);
  for (size_t i = 0; i < kLayers; ++i) {
    RandInit(*w.GetLayer(i), stddev, gen);
  }
}

template <typename T, typename U, size_t kLen>
void Complexify(const std::array<T, kLen>& x,
                std::array<std::complex<U>, kLen>& c_x) {
  for (size_t i = 0; i < kLen; ++i) {
    c_x[i] = std::complex<U>(x[i], 0.0);
  }
}

template <typename T, typename U, typename TConfig>
void Complexify(const Layer<T, TConfig>& w,
                Layer<std::complex<U>, TConfig>& c_w) {
  Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
  Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
  Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
  Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
  Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
  Complexify(w.linear_w, c_w.linear_w);
}

template <typename T, typename U, typename TConfig>
void Complexify(const Weights<T, TConfig>& w,
                Weights<std::complex<U>, TConfig>& c_w) {
  static constexpr size_t kLayers = TConfig::kLayers;
  Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
  Complexify(w.final_norm_scale, c_w.final_norm_scale);
  for (size_t i = 0; i < kLayers; ++i) {
    Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
  }
}

template <typename T, typename U, size_t N>
void TestNear(const std::array<T, N>& actual, const std::array<U, N>& expected,
              double max_abs_err, double max_rel_err, int line) {
  double sum0 = 0;
  double sum1 = 0;
  double sum01 = 0;
  for (size_t i = 0; i < N; ++i) {
    sum0 += actual[i] * actual[i];
    sum1 += expected[i] * expected[i];
    sum01 += actual[i] * expected[i];
    ASSERT_NEAR(actual[i], expected[i],
                std::max(max_abs_err, std::abs(expected[i]) * max_rel_err))
        << "line: " << line << " dim=" << N << " i=" << i;
  }
  if (sum0 > 1e-40) {
    // Also require the two vectors to be nearly parallel (cosine ~= 1).
    double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
    ASSERT_NEAR(norm_dot, 1.0, 1e-7)
        << "line: " << line << " sum0: " << sum0 << " sum1: " << sum1
        << " sum01: " << sum01;
  }
}

// Computes the gradient with the finite difference method in the complex
// plane ("complex-step differentiation"). If f : R->R is the tested function
// and F : C->C is its extension to the complex plane such that F is complex
// differentiable at x, then
//
//   F(x + ih) = F(x) + ih F'(x) - (h^2/2) F''(x) + O(h^3)
//
// and therefore
//
//   F'(x) ~= Imag(F(x + ih)) / h
//
// with an O(h^2) error. This method is more numerically stable than the
// real-valued finite difference method since we don't need to subtract
// floating point numbers that are close to each other.
template <typename T, typename U, size_t N, typename FUNC>
void TestGradient(const std::array<T, N>& grad,
                  std::array<std::complex<U>, N>& x, FUNC func,
                  U step, T max_abs_err, T max_rel_err, int line) {
  std::array<T, N> exp_grad;
  const U inv_step = 1.0 / step;
  for (size_t i = 0; i < N; ++i) {
    const U x0 = std::real(x[i]);
    const std::complex<U> x1 = std::complex<U>(x0, step);
    x[i] = x1;
    const std::complex<U> f1 = func();
    exp_grad[i] = std::imag(f1) * inv_step;
    x[i] = x0;
  }
  TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
}

template <size_t N, typename FUNC>
void TestGradient(const std::array<float, N>& grad,
                  std::array<std::complex<float>, N>& x, FUNC func,
                  float max_abs_err, float max_rel_error, int line) {
  TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
}

template <size_t N, typename FUNC>
void TestGradient(const std::array<float, N>& grad,
                  std::array<std::complex<double>, N>& x, FUNC func,
                  float max_abs_err, float max_rel_error, int line) {
  TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}

template <size_t N, typename FUNC>
void TestGradient(const std::array<double, N>& grad,
                  std::array<std::complex<double>, N>& x, FUNC func,
                  double max_abs_err, double max_rel_error, int line) {
  TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}

template <typename T, typename U, typename TConfig, typename FUNC>
void TestGradient(const Layer<T, TConfig>& grad,
                  Layer<std::complex<U>, TConfig>& c_weights,
                  FUNC func, T max_err) {
  TestGradient(grad.pre_attention_norm_scale,
               c_weights.pre_attention_norm_scale,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.linear_w, c_weights.linear_w,
               func, max_err, max_err, __LINE__);
}

template <typename T, typename U, typename TConfig, typename FUNC>
void TestGradient(const Weights<T, TConfig>& grad,
                  Weights<std::complex<U>, TConfig>& c_weights,
                  FUNC func, T max_err) {
  TestGradient(grad.embedder_input_embedding,
               c_weights.embedder_input_embedding,
               func, 2 * max_err, max_err, __LINE__);
  TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
               func, max_err, max_err, __LINE__);
  for (int i = 0; i < TConfig::kLayers; ++i) {
    TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
  }
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
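To make the complex-step trick above concrete, here is a minimal,
self-contained sketch, independent of the gemma.cpp types, that differentiates
f(x) = exp(x) * sin(x). A single function evaluation with a tiny imaginary
perturbation yields the derivative to near machine precision, with no
subtractive cancellation:

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const double x = 0.7;
  const double h = 1e-50;  // far below sqrt(eps); impossible for real-valued FD
  // f extended to the complex plane: F(z) = exp(z) * sin(z).
  const std::complex<double> z(x, h);
  const std::complex<double> fz = std::exp(z) * std::sin(z);
  const double complex_step = std::imag(fz) / h;
  // Analytic derivative: exp(x) * (sin(x) + cos(x)).
  const double exact = std::exp(x) * (std::sin(x) + std::cos(x));
  printf("complex-step: %.17g\nexact:        %.17g\n", complex_step, exact);
  return 0;
}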
@@ -484,9 +484,10 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
    const hn::ScalableTag<OutT> d;

    const size_t ofs = idx_batch * kBatch;
-   const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
+   const size_t batch =
+       idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
    Traits::Decompress(d, compressed.size(), compressed.data(),
-                      compressed_ofs + ofs, out + ofs, num);
+                      compressed_ofs + ofs, out + ofs, batch);
  });

  const double t1 = hwy::platform::Now();

@@ -495,6 +496,16 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
    fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
  }

+// Returns dot product with `vec_aligned` of length `num`.
+template <bool kVecEO, class DF, size_t kCapacity, typename VecT>
+HWY_INLINE float Dot(DF df, const std::array<float, kCapacity>& w, size_t ofs,
+                     const VecT* x, size_t num) {
+  HWY_DASSERT(ofs + num <= kCapacity);
+  HWY_DASSERT(hn::IsAligned(df, x));
+  using Traits = CompressTraits<float>;
+  return Traits::Dot(df, w.size(), w.data(), ofs, x, num);
+}
+
// Returns dot product with `vec_aligned` of length `num`.
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
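The new overload lets callers take dot products against plain float tensors
through the same interface as the compressed path. A hedged sketch of a call
site; this must live inside the usual HWY_NAMESPACE block, `RowDotSketch`,
`kRows` and `kCols` are illustrative, and kVecEO=false is assumed to mean the
vector is in plain (non-interleaved) order:

template <size_t kRows, size_t kCols>
float RowDotSketch(const std::array<float, kRows * kCols>& w, size_t row,
                   const float* HWY_RESTRICT vec_aligned) {
  const hn::ScalableTag<float> df;
  // Dot of row `row` of the kRows x kCols matrix `w` with `vec_aligned`.
  return Dot</*kVecEO=*/false>(df, w, row * kCols, vec_aligned, kCols);
}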
@@ -0,0 +1,38 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/activations.h"

#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/base.h"  // HWY_ABORT

namespace gcpp {

ByteStorageT AllocateForwardPass(Model model) {
  switch (model) {
    case Model::GEMMA_2B:
      return ForwardPass<float, ConfigGemma2B>::Allocate();
    case Model::GEMMA_7B:
      return ForwardPass<float, ConfigGemma7B>::Allocate();
    case Model::GRIFFIN_2B:
      return ForwardPass<float, ConfigGriffin2B>::Allocate();
    case Model::GEMMA_TINY:
      return ForwardPass<float, ConfigGemmaTiny>::Allocate();
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

}  // namespace gcpp
@@ -0,0 +1,85 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_

#include <stddef.h>
#include <stdint.h>

#include <array>

#include "gemma/common.h"
#include "hwy/aligned_allocator.h"

namespace gcpp {

template <typename T, typename TConfig>
struct ForwardLayer {
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kHeads = TConfig::kHeads;
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  std::array<T, kSeqLen * kModelDim> input;
  std::array<T, kSeqLen * kModelDim> pre_att_rms_out;
  std::array<T, kSeqLen * (kHeads + 2) * kQKVDim> qkv;
  std::array<T, kSeqLen * kHeads * kSeqLen> att;
  std::array<T, kSeqLen * kHeads * kQKVDim> att_out;
  std::array<T, kSeqLen * kModelDim> att_post1;
  std::array<T, kSeqLen * kModelDim> attention_out;
  std::array<T, kSeqLen * kModelDim> bf_pre_ffw_rms_out;
  std::array<T, kSeqLen * kFFHiddenDim * 2> ffw_hidden;
  std::array<T, kSeqLen * kFFHiddenDim> ffw_hidden_gated;
};

template <typename T, typename TConfig>
struct ForwardPass {
  // Prevents placement new from zero-initializing (memset) the large arrays.
  ForwardPass() {}

  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kVocabSize = TConfig::kVocabSize;
  static constexpr size_t kLayers = TConfig::kLayers;

  std::array<ForwardLayer<T, TConfig>, kLayers> layers;
  std::array<T, kSeqLen * kModelDim> final_layer_output;
  std::array<T, kSeqLen * kModelDim> final_norm_output;
  std::array<T, kSeqLen * kVocabSize> logits;
  std::array<T, kSeqLen * kVocabSize> probs;

  static ByteStorageT Allocate() {
    return hwy::AllocateAligned<uint8_t>(sizeof(ForwardPass<T, TConfig>));
  }
};

// Owns activations and undoes the type erasure of AllocateAligned.
template <typename T, typename TConfig>
class ActivationsWrapper {
  using WrappedT = ForwardPass<T, TConfig>;

 public:
  ActivationsWrapper()
      : data_(WrappedT::Allocate()),
        activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}

  const WrappedT& get() const { return activations_; }
  WrappedT& get() { return activations_; }

 private:
  ByteStorageT data_;
  WrappedT& activations_;
};

ByteStorageT AllocateForwardPass(Model model);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
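A small sketch of how the type-erased storage round-trips: the wrapper
allocates sizeof(ForwardPass) aligned bytes and re-types them, so callers get
direct field access. With ConfigGemmaTiny (kSeqLen=32, kModelDim=64,
kLayers=3), the whole forward pass fits in well under a megabyte of floats.
`ActivationsSketch` is illustrative:

#include <cstdio>

#include "gemma/activations.h"
#include "gemma/configs.h"

void ActivationsSketch() {
  gcpp::ActivationsWrapper<float, gcpp::ConfigGemmaTiny> wrapper;
  auto& fwd = wrapper.get();
  fwd.layers[0].input[0] = 1.0f;  // directly addressable, no reinterpret_cast
  printf("ForwardPass bytes: %zu\n",
         sizeof(gcpp::ForwardPass<float, gcpp::ConfigGemmaTiny>));
}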
@@ -0,0 +1,68 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_

#include <stddef.h>
#include <stdint.h>

#include <array>
#include <cmath>

#include "gemma/activations.h"

#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_

// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#endif

#include "gemma/ops.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// `EmbeddingScaling` can be constexpr only if `Sqrt` and
// `hwy::ConvertScalarTo` are both constexpr.
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif

template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
  // Round to bf16 to match Gemma's Embedder, which casts before mul.
  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
      Sqrt(static_cast<float>(TConfig::kModelDim))));
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // NOLINT
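A worked example of the bf16 rounding above (a fragment, assuming <cmath> and
"hwy/base.h"; kModelDim = 64 is ConfigGemmaTiny, kModelDim = 2048 is chosen
only to show a non-exact case). sqrt(64) = 8 is exactly representable in
bf16, so the tiny config scales by exactly 8; sqrt(2048) ~= 45.254834 rounds
to 45.25, the nearest bf16 value:

const float tiny = hwy::ConvertScalarTo<float>(
    hwy::ConvertScalarTo<hwy::bfloat16_t>(std::sqrt(64.0f)));    // 8.0f
const float big = hwy::ConvertScalarTo<float>(
    hwy::ConvertScalarTo<hwy::bfloat16_t>(std::sqrt(2048.0f)));  // 45.25f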
@@ -0,0 +1,30 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

#include <stdint.h>

#include "hwy/aligned_allocator.h"

namespace gcpp {

using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;

// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
@@ -142,6 +142,38 @@ struct ConfigGemma2B {
  using WeightT = GEMMA_WEIGHT_T;
};

+struct ConfigGemmaTiny {
+  static constexpr int kSeqLen = 32;
+  static constexpr int kVocabSize = 16;
+  static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
+      FixedLayerConfig<3>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers =
+      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
+  static constexpr int kGriffinLayers =
+      NumLayersOfTypeBefore(kLayerConfig,
+                            LayerAttentionType::kGriffinRecurrentBlock,
+                            kLayers);
+  static constexpr int kModelDim = 64;
+  static constexpr int kFFHiddenDim = 128;
+  static constexpr int kHeads = 4;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 16;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr bool kPostNormScale = false;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
+  static constexpr int kNumTensorScales = 0;
+  using WeightT = GEMMA_WEIGHT_T;
+};
+
struct ConfigGriffin2B {
  // Griffin uses local attention, so kSeqLen is actually the local attention
  // window.
gemma/gemma.cc
@@ -22,6 +22,7 @@
#include "hwy/foreach_target.h"  // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h"
+#include "gemma/common-inl.h"
#include "gemma/ops.h"
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h"

@@ -53,6 +54,7 @@
#include "compression/io.h"  // Path
#include "gemma/configs.h"
#include "gemma/gemma.h"
+#include "gemma/weights.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

@@ -72,63 +74,6 @@ constexpr bool kShowTokenization = false;

namespace gcpp {

-template <class TConfig>
-struct Layer {
-  Layer() = default;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kKVHeads = TConfig::kKVHeads;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
-  static constexpr size_t kQKVEinsumWSize =
-      (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
-  // 2x for (gelu gating vector, gated vector)
-  static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
-  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
-  static constexpr bool kFFBiases = TConfig::kFFBiases;
-  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
-  static constexpr size_t kAOBiasDim =
-      TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
-  static constexpr size_t kGriffinDim =
-      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
-
-  template <class T, size_t N>
-  using ArrayT = std::array<T, N>;
-
-  union {
-    struct {
-      ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
-      ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
-      ArrayT<float, kAOBiasDim> attention_output_biases;
-    };
-
-    struct {
-      ArrayT<float, kGriffinDim * kGriffinDim> linear_x_w;
-      ArrayT<float, kGriffinDim> linear_x_biases;
-      ArrayT<float, kGriffinDim * kGriffinDim> linear_y_w;
-      ArrayT<float, kGriffinDim> linear_y_biases;
-      ArrayT<float, kGriffinDim * kGriffinDim> linear_out_w;
-      ArrayT<float, kGriffinDim> linear_out_biases;
-      ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
-      ArrayT<float, kGriffinDim> conv_biases;
-      ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
-      ArrayT<float, kGriffinDim * 2> gate_biases;
-      ArrayT<float, kGriffinDim> a;
-    } griffin;
-  };
-
-  ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
-  ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
-  ArrayT<float, kModelDim> pre_attention_norm_scale;
-  ArrayT<float, kModelDim> pre_ffw_norm_scale;
-  ArrayT<float, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
-  ArrayT<float, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
-
-  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
-  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
-};
-
float ScaleWeights(float* data, size_t len) {
  float maxabs = 0.0;
  for (size_t i = 0; i < len; ++i) {

@@ -146,41 +91,6 @@ float ScaleWeights(float* data, size_t len) {
  return scale;
}

-// Array instead of single large allocation for parallel mem init. Split out of
-// Weights so that only these pointers are initialized.
-template <class TConfig>
-struct LayerPointers {
-  explicit LayerPointers(hwy::ThreadPool& pool) {
-    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
-      this->layers[task] = hwy::AllocateAligned<Layer<TConfig>>(1);
-    });
-  }
-
-  using TLayer = Layer<TConfig>;
-  std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
-};
-
-template <class TConfig>
-struct Weights {
-  // No ctor/dtor, allocated via AllocateAligned.
-
-  std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
-      embedder_input_embedding;
-
-  std::array<float, TConfig::kModelDim> final_norm_scale;
-
-  LayerPointers<TConfig> layer_ptrs;
-
-  std::array<float, TConfig::kNumTensorScales> scales;
-
-  const Layer<TConfig>* GetLayer(size_t layer) const {
-    return layer_ptrs.layers[layer].get();
-  }
-  Layer<TConfig>* GetLayer(size_t layer) {
-    return layer_ptrs.layers[layer].get();
-  }
-};
-
template <typename TConfig>
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
    const Path& checkpoint, hwy::ThreadPool& pool,

@@ -191,11 +101,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
              checkpoint.path.c_str());
  }

-  using TWeights = Weights<TConfig>;
-  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
-      hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
-  TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
-  new (&weights->layer_ptrs) LayerPointers<TConfig>(pool);
+  ByteStorageT weights_u8 = AllocateWeights<float, TConfig>(pool);
+  auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());

  size_t scale_pos = 0;
  FILE* fptr;

@@ -228,7 +135,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
            sizeof(weights->final_norm_scale));
  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
    auto type = TConfig::kLayerConfig[layer];
-   Layer<TConfig>* layer_view = weights->GetLayer(layer);
+   LayerF<TConfig>* layer_view = weights->GetLayer(layer);

#define READ_WEIGHTS(name) \
  do { \

@@ -305,7 +212,7 @@ template <class TConfig>
struct CompressedLayer {
  // No ctor/dtor, allocated via AllocateAligned.

- using TLayer = gcpp::Layer<TConfig>;
+ using TLayer = gcpp::LayerF<TConfig>;
  using WeightT = typename TConfig::WeightT;

  static constexpr size_t kHeads = TLayer::kHeads;

@@ -399,13 +306,13 @@ struct CompressedWeights {

template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
-                        Weights<TConfig>>;
+                        WeightsF<TConfig>>;

// Aligned.
template <class TConfig, size_t TBatchSize>
struct Activations {
  static constexpr size_t kBatchSize = TBatchSize;
- using LayerConfig = Layer<TConfig>;
+ using LayerConfig = LayerF<TConfig>;
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kHeads = TConfig::kHeads;

@@ -446,6 +353,16 @@ struct Activations {
  std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
};

+template <typename TConfig>
+struct InferenceState {
+  Activations<TConfig, kPrefillBatchSize> prefill;
+  HWY_ALIGN Activations<TConfig, 1> state;
+
+  static ByteStorageT Allocate() {
+    return hwy::AllocateAligned<uint8_t>(sizeof(InferenceState<TConfig>));
+  }
+};
+
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
// define an abstract base class.
struct GemmaInterface {

@@ -484,6 +401,8 @@ KVCache CreateKVCache(Model type) {
      return CreateKVCacheT<ConfigGemma7B>();
    case Model::GRIFFIN_2B:
      return CreateKVCacheT<ConfigGriffin2B>();
+   case Model::GEMMA_TINY:
+     return CreateKVCacheT<ConfigGemmaTiny>();
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
  }

@@ -526,8 +445,8 @@ void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
  c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
}
template <class Config>
-void DeleteLayersPtrs(Weights<Config>* weights) {
-  weights->layer_ptrs.~LayerPointers<Config>();
+void DeleteLayersPtrs(WeightsF<Config>* weights) {
+  weights->layer_ptrs.~LayerPointers<float, Config>();
}
}  // namespace

@@ -866,8 +785,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
  // Same matrix, first and second half of rows. Could fuse into one MatVec.
  MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
      layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
-     layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
-     out_mul, pool);
+     TConfig::kFFBiases ?
+     layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr,
+     even_odd, out_mul, pool);
  // Gate, will go through the nonlinearity.
  MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
      layer_weights->gating_einsum_w, 0, vec,

@@ -892,21 +812,6 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
  }
}

-// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
-// are both constexpr
-#if HWY_COMPILER_GCC_ACTUAL
-#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
-#else
-#define GEMMA_CONSTEXPR_EMBSCALING
-#endif
-
-template <typename TConfig>
-GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
-  // Round to bf16 to match Gemma's Embedder, which casts before mul.
-  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
-      Sqrt(static_cast<float>(TConfig::kModelDim))));
-}
-
template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                          const WeightArrayT& weights,

@@ -1076,20 +981,15 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
  }
}

-template <class TConfig>
-void GenerateImpl(GemmaImpl<TConfig>& gemma,
+template <class TConfig, template<typename> typename WeightsType>
+void GenerateImpl(const WeightsType<TConfig>& weights,
+                  Activations<TConfig, kPrefillBatchSize>& prefill_activations,
+                  Activations<TConfig, 1>& activations,
                   const RuntimeConfig& runtime_config,
                   const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, TimingInfo& timing_info,
                   LayersOutputT* layers_output) {
  static constexpr size_t kVocabSize = TConfig::kVocabSize;
- Activations<TConfig, 1>& activations = *gemma.state.get();
- Activations<TConfig, kPrefillBatchSize>& prefill_activations =
-     *gemma.prefill.get();
-
- const WeightsT<TConfig>& weights =
-     *reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
-
  size_t prompt_size = prompt.size();
  size_t max_tokens = runtime_config.max_tokens;
  size_t max_generated_tokens = runtime_config.max_generated_tokens;

@@ -1167,7 +1067,7 @@ void GenerateImpl(...)
        activations.logits.data(), kVocabSize, *runtime_config.gen,
        runtime_config.temperature, runtime_config.accept_token);
    if (!runtime_config.stream_token(token, activations.logits[token])) {
-     token = EOS_ID;
+     token = runtime_config.eos_id;
    }
    if (generate_pos == 0) {
      timing_info.time_to_first_token = hwy::platform::Now() - gen_start;

@@ -1177,10 +1077,10 @@ void GenerateImpl(...)
      // process the tokens of the prompt one at a time.
      token = prompt.at(pos_offset + 1);
      if (!runtime_config.stream_token(token, 0)) {
-       token = EOS_ID;
+       token = runtime_config.eos_id;
      }
    }
-   if (token == EOS_ID) {
+   if (token == runtime_config.eos_id) {
      if (runtime_config.verbosity >= 2) {
        const double gen_end = hwy::platform::Now();
        timing_info.gen_tok_sec =

@@ -1192,6 +1092,35 @@ void GenerateImpl(...)
  }
}

+template <class TConfig>
+void GenerateImpl(GemmaImpl<TConfig>& gemma,
+                  const RuntimeConfig& runtime_config,
+                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
+                  hwy::ThreadPool& pool, TimingInfo& timing_info,
+                  LayersOutputT* layers_output) {
+  const WeightsT<TConfig>& weights =
+      *reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
+  GenerateImpl<TConfig, WeightsT>(
+      weights, *gemma.prefill.get(), *gemma.state.get(), runtime_config, prompt,
+      pos, kv_cache, pool, timing_info, layers_output);
+}
+
+template <class TConfig>
+void GenerateGemma(const ByteStorageT& weights_u8,
+                   ByteStorageT& inference_state_u8,
+                   const RuntimeConfig& runtime_config,
+                   const std::vector<int>& prompt, size_t pos,
+                   KVCache& kv_cache, hwy::ThreadPool& pool,
+                   TimingInfo& timing_info, LayersOutputT* layers_output) {
+  const WeightsF<TConfig>& weights =
+      *reinterpret_cast<const WeightsF<TConfig>*>(weights_u8.get());
+  InferenceState<TConfig>& inference_state =
+      *reinterpret_cast<InferenceState<TConfig>*>(inference_state_u8.get());
+  GenerateImpl<TConfig, WeightsF>(
+      weights, inference_state.prefill, inference_state.state, runtime_config,
+      prompt, pos, kv_cache, pool, timing_info, layers_output);
+}
+
#define TOKEN(token_id) TokenString(gemma, token_id).c_str()

template <class TConfig>

@@ -1263,7 +1192,7 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
//
// This avoids repeating the list of tensors between loading and compressing.
template <class TConfig, class Func>
-void ForEachTensor(const Weights<TConfig>* weights,
+void ForEachTensor(const WeightsF<TConfig>* weights,
                   CompressedWeights<TConfig>& c_weights, Func& func) {
  func("c_embedding",
       weights ? weights->embedder_input_embedding.data() : nullptr,

@@ -1275,7 +1204,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
-   const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
+   const LayerF<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
    CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);

#define CALL_FUNC(name, member) \

@@ -1386,34 +1315,17 @@ void CompressWeights(const Path& weights_path,
  const bool scale_for_compression = TConfig::kNumTensorScales > 0;
  const hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
      LoadWeights<TConfig>(weights_path, pool, scale_for_compression);
- Weights<TConfig>* weights =
-     reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
+ WeightsF<TConfig>* weights =
+     reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
  Compressor compressor(pool);
  ForEachTensor<TConfig>(weights, *c_weights, compressor);
  compressor.AddScales(weights->scales.data(), weights->scales.size());
  compressor.WriteAll(pool, compressed_weights_path);

- weights->layer_ptrs.~LayerPointers<TConfig>();
+ weights->layer_ptrs.~LayerPointers<float, TConfig>();
  c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
}

-void CompressWeightsT(gcpp::Model model, const Path& weights,
-                      const Path& compressed_weights, hwy::ThreadPool& pool) {
-  switch (model) {
-    case Model::GEMMA_2B:
-      CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
-      break;
-    case Model::GEMMA_7B:
-      CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
-      break;
-    case Model::GRIFFIN_2B:
-      CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
-      break;
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
-}
-
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

@@ -1421,8 +1333,6 @@
#if HWY_ONCE
namespace gcpp {

-HWY_EXPORT(CompressWeightsT);
-
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                      size_t conv1d_cache_size, size_t rglru_cache_size) {
  KVCache kv_cache = {};

@@ -1528,10 +1438,87 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}

+void GenerateGemma(Model model, const ByteStorageT& weights,
+                   ByteStorageT& inference_state,
+                   RuntimeConfig runtime_config,
+                   const std::vector<int>& prompt, size_t start_pos,
+                   KVCache& kv_cache, hwy::ThreadPool& pool,
+                   TimingInfo& timing_info) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma2B>)(
+          weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
+          pool, timing_info, /*layers_output=*/nullptr);
+      break;
+    case Model::GEMMA_7B:
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemma7B>)(
+          weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
+          pool, timing_info, /*layers_output=*/nullptr);
+      break;
+    case Model::GRIFFIN_2B:
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGriffin2B>)(
+          weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
+          pool, timing_info, /*layers_output=*/nullptr);
+      break;
+    case Model::GEMMA_TINY:
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma<ConfigGemmaTiny>)(
+          weights, inference_state, runtime_config, prompt, start_pos, kv_cache,
+          pool, timing_info, /*layers_output=*/nullptr);
+      break;
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+
+ByteStorageT LoadWeights(const Path& weights, Model model,
+                         hwy::ThreadPool& pool) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      return LoadWeights<ConfigGemma2B>(weights, pool);
+    case Model::GEMMA_7B:
+      return LoadWeights<ConfigGemma7B>(weights, pool);
+    case Model::GRIFFIN_2B:
+      return LoadWeights<ConfigGriffin2B>(weights, pool);
+    case Model::GEMMA_TINY:
+      return LoadWeights<ConfigGemmaTiny>(weights, pool);
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+
+ByteStorageT AllocateInferenceState(Model model) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      return InferenceState<ConfigGemma2B>::Allocate();
+    case Model::GEMMA_7B:
+      return InferenceState<ConfigGemma7B>::Allocate();
+    case Model::GRIFFIN_2B:
+      return InferenceState<ConfigGriffin2B>::Allocate();
+    case Model::GEMMA_TINY:
+      return InferenceState<ConfigGemmaTiny>::Allocate();
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+
void CompressWeights(gcpp::Model model, const Path& weights,
                     const Path& compressed_weights, hwy::ThreadPool& pool) {
- HWY_DYNAMIC_DISPATCH(CompressWeightsT)
- (model, weights, compressed_weights, pool);
+ switch (model) {
+   case Model::GEMMA_2B:
+     HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma2B>)(
+         weights, compressed_weights, pool);
+     break;
+   case Model::GEMMA_7B:
+     HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGemma7B>)(
+         weights, compressed_weights, pool);
+     break;
+   case Model::GRIFFIN_2B:
+     HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<ConfigGriffin2B>)(
+         weights, compressed_weights, pool);
+     break;
+   default:
+     HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+ }
}

float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
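The new free functions make it possible to drive generation without a Gemma
object. A hedged sketch of the intended flow; `weights_path`, `runtime_config`
and `prompt` are assumed to be defined by the caller, and RuntimeConfig
construction is omitted because it holds references (see gemma.h):

const gcpp::Model model = gcpp::Model::GEMMA_TINY;
hwy::ThreadPool pool(4);
gcpp::ByteStorageT weights = gcpp::LoadWeights(weights_path, model, pool);
gcpp::ByteStorageT inference = gcpp::AllocateInferenceState(model);
gcpp::KVCache kv_cache = gcpp::CreateKVCache(model);
gcpp::TimingInfo timing_info;
gcpp::GenerateGemma(model, weights, inference, runtime_config, prompt,
                    /*start_pos=*/0, kv_cache, pool, timing_info);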
@@ -1546,13 +1533,16 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,

namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt",
-                                      "2b-it", "7b-it", "gr2b-it"};
+                                      "2b-it", "7b-it", "gr2b-it",
+                                      "tiny"};
constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B,
                                 Model::GRIFFIN_2B, Model::GEMMA_2B,
-                                Model::GEMMA_7B, Model::GRIFFIN_2B};
+                                Model::GEMMA_7B, Model::GRIFFIN_2B,
+                                Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = {
    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
-   ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT};
+   ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
+   ModelTraining::GEMMA_IT};
}  // namespace

const char* ParseModelTypeAndTraining(const std::string& model_flag,
@@ -23,6 +23,7 @@
#include <vector>

#include "compression/io.h"  // Path
+#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"  // hwy::bfloat16_t

@@ -51,8 +52,6 @@ struct KVCache {
      rglru_cache;  // kModelDim * kGriffinLayers
};

-// Model variants: see configs.h for details.
-enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };

// Returns error string or nullptr if OK.

@@ -68,6 +67,8 @@ using StreamFunc = std::function<bool(int, float)>;
// want to generate and True for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;

+constexpr int EOS_ID = 1;
+
struct RuntimeConfig {
  size_t max_tokens;
  size_t max_generated_tokens;

@@ -76,6 +77,7 @@ struct RuntimeConfig {
  std::mt19937* gen;
  const StreamFunc& stream_token;
  const AcceptFunc& accept_token;
+ int eos_id = EOS_ID;
};

struct GemmaInterface;

@@ -118,6 +120,18 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
                   TimingInfo& timing_info,
                   LayersOutputT* layers_output = nullptr);

+void GenerateGemma(Model model, const ByteStorageT& weights,
+                   ByteStorageT& inference_state,
+                   RuntimeConfig runtime_config,
+                   const std::vector<int>& prompt, size_t start_pos,
+                   KVCache& kv_cache, hwy::ThreadPool& pool,
+                   TimingInfo& timing_info);
+
+ByteStorageT LoadWeights(const Path& weights, Model model,
+                         hwy::ThreadPool& pool);
+
+ByteStorageT AllocateInferenceState(Model model);
+
+void CompressWeights(gcpp::Model model, const Path& weights,
+                     const Path& compressed_weights, hwy::ThreadPool& pool);
+

@@ -125,8 +139,6 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                          const std::vector<int>& prompt, KVCache& kv_cache,
                          hwy::ThreadPool& pool, int verbosity);

-constexpr int EOS_ID = 1;
-
}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
@@ -98,7 +98,7 @@ class GemmaTest : public ::testing::Test {
  gcpp::Gemma model;
};

-TEST_F(GemmaTest, Geography) {
+TEST_F(GemmaTest, DISABLED_Geography) {
  static const char* kQA[][2] = {
      {"What is the capital of Hungary?", "Budapest"},
      {"How many states does the US have?", "50"},

@@ -108,7 +108,7 @@ TEST_F(GemmaTest, Geography) {
  TestQuestions(kQA, kNum);
}

-TEST_F(GemmaTest, History) {
+TEST_F(GemmaTest, DISABLED_History) {
  static const char* kQA[][2] = {
      {"When was the Battle of Hastings?", "1066"},
      {"Who fought at the Battle of Marathon?", "Greek"},

@@ -117,7 +117,7 @@ TEST_F(GemmaTest, History) {
  TestQuestions(kQA, kNum);
}

-TEST_F(GemmaTest, Arithmetic) {
+TEST_F(GemmaTest, DISABLED_Arithmetic) {
  static const char* kQA[][2] = {
      {"what is 13 + 14?", "27"},
      {"what is 7 * 8", "56"},

@@ -280,7 +280,7 @@ static const char kDeclaration[] = {
    "reliance on the protection of divine Providence, we mutually pledge to "
    "each other our Lives, our Fortunes and our sacred Honor.\n"};

-TEST_F(GemmaTest, CrossEntropySmall) {
+TEST_F(GemmaTest, DISABLED_CrossEntropySmall) {
  static const char kSmall[] =
      "The capital of Hungary is Budapest which is located in Europe.";
  float entropy = GemmaCrossEntropy(kSmall);

@@ -288,19 +288,19 @@ TEST_F(GemmaTest, CrossEntropySmall) {
  EXPECT_LT(entropy, 1.6f);
}

-TEST_F(GemmaTest, CrossEntropyJingleBells) {
+TEST_F(GemmaTest, DISABLED_CrossEntropyJingleBells) {
  float entropy = GemmaCrossEntropy(kJingleBells);
  std::cout << "per-byte entropy: " << entropy << "\n";
  EXPECT_LT(entropy, 2.3f);
}

-TEST_F(GemmaTest, CrossEntropyGettysburg) {
+TEST_F(GemmaTest, DISABLED_CrossEntropyGettysburg) {
  float entropy = GemmaCrossEntropy(kGettysburg);
  std::cout << "per-byte entropy: " << entropy << "\n";
  EXPECT_LT(entropy, 1.2f);
}

-TEST_F(GemmaTest, CrossEntropyDeclaration) {
+TEST_F(GemmaTest, DISABLED_CrossEntropyDeclaration) {
  float entropy = GemmaCrossEntropy(kDeclaration);
  std::cout << "per-byte entropy: " << entropy << "\n";
  EXPECT_LT(entropy, 1.0f);
@@ -874,10 +874,14 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
   consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
   m*theta_i. However, in the Gemma implementation we choose to rotate
   the pairs of dimensions v_{i} and v_{i + d//2} instead.

+  The `pos` parameter is deliberately an int because in the backward pass we
+  call this with negative values: for the VJP calculation we need the
+  transpose of this rotation matrix, which is simply the same matrix with a
+  negated `pos` parameter.
*/

static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
-                                              size_t dim_qkv, size_t pos) {
+                                              size_t dim_qkv, int pos) {
  HWY_DASSERT(dim_qkv % 2 == 0);
  const size_t half_dim_qkv = dim_qkv / 2;
  for (size_t dim = 0; dim < half_dim_qkv; ++dim) {

@@ -898,7 +902,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
                                                       float* HWY_RESTRICT x,
                                                       size_t dim_qkv,
-                                                      size_t pos) {
+                                                      int pos) {
  HWY_DASSERT(dim_qkv % 2 == 0);
  const size_t half_dim_qkv = dim_qkv / 2;
  for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
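A scalar sketch of the rotation (not the SIMD code above). The 10000^(-2i/d)
frequency schedule is the standard RoPE choice and is assumed here, since the
loop body is not shown in this hunk; the point is that rotating by -pos
exactly inverts rotating by +pos, which is why the backward pass can reuse
Rope with a negated position:

#include <cmath>
#include <cstdio>
#include <vector>

void RopeSketch(std::vector<float>& x, int pos) {
  const size_t half = x.size() / 2;
  for (size_t i = 0; i < half; ++i) {
    const float freq = std::pow(10000.0f, -2.0f * i / x.size());
    const float theta = pos * freq;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];  // rotate the (x_i, x_{i+d/2}) pair
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  RopeSketch(x, 7);   // forward rotation at position 7
  RopeSketch(x, -7);  // transpose/inverse: restores the original x
  printf("%g %g %g %g\n", x[0], x[1], x[2], x[3]);  // ~1 2 3 4
  return 0;
}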
@@ -0,0 +1,112 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/weights.h"

#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"

namespace gcpp {

ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool) {
  switch (model) {
    case Model::GEMMA_2B:
      return AllocateWeights<float, ConfigGemma2B>(pool);
    case Model::GEMMA_7B:
      return AllocateWeights<float, ConfigGemma7B>(pool);
    case Model::GRIFFIN_2B:
      return AllocateWeights<float, ConfigGriffin2B>(pool);
    case Model::GEMMA_TINY:
      return AllocateWeights<float, ConfigGemmaTiny>(pool);
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

namespace {
template <typename TConfig>
void ZeroInitWeightsT(ByteStorageT& weights, hwy::ThreadPool& pool) {
  ZeroInit<float, TConfig>(
      *reinterpret_cast<Weights<float, TConfig>*>(weights.get()));
}
}  // namespace

void ZeroInitWeights(Model model, ByteStorageT& weights,
                     hwy::ThreadPool& pool) {
  switch (model) {
    case Model::GEMMA_2B:
      ZeroInitWeightsT<ConfigGemma2B>(weights, pool);
      break;
    case Model::GEMMA_7B:
      ZeroInitWeightsT<ConfigGemma7B>(weights, pool);
      break;
    case Model::GRIFFIN_2B:
      ZeroInitWeightsT<ConfigGriffin2B>(weights, pool);
      break;
    case Model::GEMMA_TINY:
      ZeroInitWeightsT<ConfigGemmaTiny>(weights, pool);
      break;
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

namespace {
void LogVec(const char* name, const float* data, size_t len) {
  hwy::Stats stats;
  for (size_t i = 0; i < len; ++i) {
    stats.Notify(data[i]);
  }
  printf("%-20s %12zu %13.10f %8.5f %13.10f\n",
         name, len, stats.Min(), stats.Mean(), stats.Max());
}

class WeightLogger {
 public:
  template <size_t N>
  void operator()(const char* name, const std::array<float, N>& tensor) {
    LogVec(name, tensor.data(), N);
    total_weights += N;
  }
  size_t total_weights = 0;
};

template <typename TConfig>
void LogWeightStats(const ByteStorageT& weights_u8) {
  const auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
  WeightLogger logger;
  ForEachTensor1<float, TConfig>(logger, weights);
  printf("%-20s %12zu\n", "Total", logger.total_weights);
}
}  // namespace

void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
  switch (model) {
    case Model::GEMMA_2B:
      return LogWeightStats<ConfigGemma2B>(weights);
    case Model::GEMMA_7B:
      return LogWeightStats<ConfigGemma7B>(weights);
    case Model::GRIFFIN_2B:
      return LogWeightStats<ConfigGriffin2B>(weights);
    case Model::GEMMA_TINY:
      return LogWeightStats<ConfigGemmaTiny>(weights);
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

}  // namespace gcpp
@ -0,0 +1,290 @@
|
|||
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

#include <stddef.h>
#include <stdint.h>

#include <array>
#include <new>

#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

template <typename T, class TConfig>
struct Layer {
  Layer() {}
  static constexpr size_t kHeads = TConfig::kHeads;
  static constexpr size_t kKVHeads = TConfig::kKVHeads;
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
  static constexpr size_t kQKVEinsumWSize =
      (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
  // 2x for (gelu gating vector, gated vector)
  static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
  static constexpr bool kFFBiases = TConfig::kFFBiases;
  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
  static constexpr size_t kAOBiasDim =
      TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
  static constexpr size_t kGriffinDim =
      TConfig::kGriffinLayers > 0 ? kModelDim : 0;

  union {
    struct {
      std::array<T, kAttVecEinsumWSize> attn_vec_einsum_w;
      std::array<T, kQKVEinsumWSize> qkv_einsum_w;
      std::array<T, kAOBiasDim> attention_output_biases;
    };

    struct {
      std::array<T, kGriffinDim * kGriffinDim> linear_x_w;
      std::array<T, kGriffinDim> linear_x_biases;
      std::array<T, kGriffinDim * kGriffinDim> linear_y_w;
      std::array<T, kGriffinDim> linear_y_biases;
      std::array<T, kGriffinDim * kGriffinDim> linear_out_w;
      std::array<T, kGriffinDim> linear_out_biases;
      std::array<T, kConv1dWidth * kGriffinDim> conv_w;
      std::array<T, kGriffinDim> conv_biases;
      std::array<T, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
      std::array<T, kGriffinDim * 2> gate_biases;
      std::array<T, kGriffinDim> a;
    } griffin;
  };

  std::array<T, kGatingEinsumWSize> gating_einsum_w;
  std::array<T, kModelDim * kFFHiddenDim> linear_w;
  std::array<T, kModelDim> pre_attention_norm_scale;
  std::array<T, kModelDim> pre_ffw_norm_scale;
  std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
  std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;

  std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
  std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
};

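// Note: the anonymous struct (Gemma attention tensors) and `griffin`
// (recurrent-block tensors) overlay the same storage. A given layer uses
// exactly one of the two, selected per layer by TConfig::kLayerConfig, so the
// union costs only the size of its larger member.
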
template <class TConfig>
using LayerF = Layer<float, TConfig>;

// Array instead of single large allocation for parallel mem init. Split out
// of Weights so that only these pointers are initialized.
template <typename T, class TConfig>
struct LayerPointers {
  explicit LayerPointers(hwy::ThreadPool& pool) {
    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
      this->layers[task] = hwy::AllocateAligned<Layer<T, TConfig>>(1);
    });
  }

  using TLayer = Layer<T, TConfig>;
  std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
};

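// Note: pool.Run partitions the tasks [0, kLayers) across worker threads, so
// each per-layer allocation (and the page zeroing the OS performs when the
// memory is first touched) proceeds in parallel rather than serially.
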
template <typename T, class TConfig>
struct Weights {
  // No ctor/dtor, allocated via AllocateAligned.

  std::array<T, TConfig::kVocabSize * TConfig::kModelDim>
      embedder_input_embedding;

  std::array<T, TConfig::kModelDim> final_norm_scale;

  LayerPointers<T, TConfig> layer_ptrs;

  std::array<T, TConfig::kNumTensorScales> scales;

  const Layer<T, TConfig>* GetLayer(size_t layer) const {
    return layer_ptrs.layers[layer].get();
  }
  Layer<T, TConfig>* GetLayer(size_t layer) {
    return layer_ptrs.layers[layer].get();
  }
};

template <class TConfig>
using WeightsF = Weights<float, TConfig>;

template <typename T, typename TConfig>
ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
  using TWeights = Weights<T, TConfig>;
  ByteStorageT weights_u8 = hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
  TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
  // Only layer_ptrs requires construction (placement new); the remaining
  // members are plain arrays and are deliberately left uninitialized.
  new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
  return weights_u8;
}

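// Illustrative usage sketch (any config works in place of ConfigGemmaTiny):
//   hwy::ThreadPool pool(4);
//   ByteStorageT u8 = AllocateWeights<float, ConfigGemmaTiny>(pool);
//   auto* w = reinterpret_cast<WeightsF<ConfigGemmaTiny>*>(u8.get());
//   w->GetLayer(0);  // per-layer tensors, e.g. pre_attention_norm_scale
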
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
#define GEMMA_CALL_TOP_FUNC2(name, member) \
  func(name, weights1.member, weights2.member)
#define GEMMA_CALL_TOP_FUNC3(name, member) \
  func(name, weights1.member, weights2.member, weights3.member)
#define GEMMA_CALL_TOP_FUNC4(name, member)     \
  func(name, weights1.member, weights2.member, \
       weights3.member, weights4.member)

#define GEMMA_CALL_LAYER_FUNC1(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member)

#define GEMMA_CALL_LAYER_FUNC2(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member, layer2.member)

#define GEMMA_CALL_LAYER_FUNC3(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member, layer2.member, layer3.member)

#define GEMMA_CALL_LAYER_FUNC4(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)

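// Expansion example: at a call site that defines name_buf, layer_idx and
// layer1, GEMMA_CALL_LAYER_FUNC1("att_ein", attn_vec_einsum_w) expands to:
//   snprintf(name_buf, sizeof(name_buf), "att_ein" "_%d", layer_idx);
//   func(name_buf, layer1.attn_vec_einsum_w)
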
#define GEMMA_CALL_ALL_LAYER_FUNC(N)                                       \
  if (type == LayerAttentionType::kGemma) {                                \
    GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w);              \
    GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w);                   \
  } else {                                                                 \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w);          \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases);     \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w);          \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases);     \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w);      \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
    GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w);               \
    GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases);          \
    GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w);               \
    GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases);          \
    GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a);                         \
  }                                                                        \
  GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w);               \
  GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w);                        \
  GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale);      \
  if (TConfig::kPostNormScale) {                                           \
    GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale);  \
    GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale);         \
  }                                                                        \
  GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale);             \
  if (TConfig::kFFBiases) {                                                \
    GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases);            \
    GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases);            \
  }                                                                        \
  if (TConfig::kSoftmaxAttnOutputBiases &&                                 \
      type == LayerAttentionType::kGemma) {                                \
    GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases);        \
  }

template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
  GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
    GEMMA_CALL_ALL_LAYER_FUNC(1)
  }
}

// Non-const overload, for functors that mutate the tensors in place.
template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
  GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
    GEMMA_CALL_ALL_LAYER_FUNC(1)
  }
}

template <typename T, typename TConfig, class Func>
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
                    Weights<T, TConfig>& weights2) {
  GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
    LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
    GEMMA_CALL_ALL_LAYER_FUNC(2)
  }
}

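// Note: ForEachTensor2 visits two weight sets in lockstep, the shape needed
// to, e.g., add one set of tensors into another. A minimal functor sketch
// (illustrative, not part of this header):
//   struct AddScaled {
//     float scale;
//     template <size_t N>
//     void operator()(const char* /*name*/, const std::array<float, N>& src,
//                     std::array<float, N>& dst) {
//       for (size_t i = 0; i < N; ++i) dst[i] += scale * src[i];
//     }
//   };
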
#undef GEMMA_CALL_TOP_FUNC1
#undef GEMMA_CALL_TOP_FUNC2
#undef GEMMA_CALL_TOP_FUNC3
#undef GEMMA_CALL_TOP_FUNC4
#undef GEMMA_CALL_LAYER_FUNC1
#undef GEMMA_CALL_LAYER_FUNC2
#undef GEMMA_CALL_LAYER_FUNC3
#undef GEMMA_CALL_LAYER_FUNC4
#undef GEMMA_CALL_ALL_LAYER_FUNC

template <typename T, typename TConfig>
void ZeroInit(Weights<T, TConfig>& w) {
  hwy::ZeroBytes(&w.embedder_input_embedding,
                 sizeof(w.embedder_input_embedding));
  hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
  for (int i = 0; i < TConfig::kLayers; ++i) {
    hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
  }
}

template <typename T, typename TConfig>
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
  hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
                 sizeof(src.embedder_input_embedding));
  hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
                 sizeof(src.final_norm_scale));
  for (int i = 0; i < TConfig::kLayers; ++i) {
    hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i)));
  }
}

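// Note: byte-wise ZeroInit/Copy are valid because the tensors are plain
// std::array of arithmetic types (trivially copyable); layer_ptrs and scales
// are deliberately not touched by either function.
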
// Owns weights and undoes the type erasure of AllocateWeights.
template <typename T, typename TConfig>
class WeightsWrapper {
 public:
  WeightsWrapper()
      : pool_(0),
        data_(AllocateWeights<T, TConfig>(pool_)),
        weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}

  const Weights<T, TConfig>& get() const { return *weights_; }
  Weights<T, TConfig>& get() { return *weights_; }
  void clear() { ZeroInit(get()); }
  void copy(const WeightsWrapper<T, TConfig>& other) {
    Copy(get(), other.get());
  }

 private:
  hwy::ThreadPool pool_;
  ByteStorageT data_;
  Weights<T, TConfig>* weights_;
};

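// Illustrative usage sketch (hypothetical training bookkeeping):
//   WeightsWrapper<float, ConfigGemmaTiny> weights, best;
//   weights.clear();
//   // ... training updates weights.get() ...
//   best.copy(weights);  // snapshot the current parameters
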
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool);

void ZeroInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool);

void LogWeightStats(Model model, const ByteStorageT& weights);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_