Remove backprop/

Also remove MatPtrT::Packed(); use PackedScale1 instead where const, or Row(0).

PiperOrigin-RevId: 764243198
Jan Wassenberg 2025-05-28 07:00:44 -07:00 committed by Copybara-Service
parent 627cc04db9
commit 3890eb5412
40 changed files with 62 additions and 3078 deletions
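A minimal sketch of the call-site migration described in the message above, assuming the MatPtrT accessors that appear in the code removed below (Packed(), PackedScale1(), Row(), Rows(), Cols()); the linear_w operand is illustrative only:

// Before this change, read-only access to the packed storage:
//   const float* w = weights.linear_w.Packed();
// After, where the matrix is const (scale known to be 1.0f):
//   const float* w = weights.linear_w.PackedScale1();
// Otherwise, access the data row by row; Row(0) is the first row:
//   for (size_t r = 0; r < weights.linear_w.Rows(); ++r) {
//     const float* row = weights.linear_w.Row(r);  // row[0 .. Cols())
//   }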

View File

@@ -685,131 +685,3 @@ cc_binary(
"@nlohmann_json//:json",
],
)
cc_library(
name = "prompt",
hdrs = ["backprop/prompt.h"],
deps = [],
)
cc_library(
name = "sampler",
hdrs = ["backprop/sampler.h"],
deps = [
":prompt",
],
)
cc_library(
name = "backprop",
srcs = [
"backprop/backward.cc",
"backprop/forward.cc",
],
hdrs = [
"backprop/activations.h",
"backprop/backward.h",
"backprop/forward.h",
],
textual_hdrs = [
"backprop/backward-inl.h",
"backprop/forward-inl.h",
],
deps = [
":allocator",
":common",
":configs",
":mat",
":ops",
":prompt",
":weights",
"@highway//:dot",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
],
)
cc_library(
name = "backprop_scalar",
hdrs = [
"backprop/activations.h",
"backprop/backward_scalar.h",
"backprop/common_scalar.h",
"backprop/forward_scalar.h",
],
deps = [
":common",
":configs",
":mat",
":prompt",
":weights",
"@highway//:hwy",
],
)
cc_test(
name = "backward_test",
size = "large",
srcs = [
"backprop/backward_test.cc",
"backprop/test_util.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":backprop",
":backprop_scalar",
":configs",
":mat",
":ops",
":prompt",
":sampler",
":threading_context",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
cc_library(
name = "optimizer",
srcs = ["backprop/optimizer.cc"],
hdrs = ["backprop/optimizer.h"],
deps = [
":allocator",
":mat",
":weights",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
cc_test(
name = "optimize_test",
srcs = ["backprop/optimize_test.cc"],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":allocator",
":backprop",
":basics",
":configs",
":gemma_lib",
":ops",
":optimizer",
":prompt",
":sampler",
":threading",
":tokenizer",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:types",
"@highway//:thread_pool",
],
)

View File

@@ -41,18 +41,6 @@ FetchContent_MakeAvailable(benchmark)
# Base source files
set(SOURCES
backprop/activations.h
backprop/backward_scalar.h
backprop/backward-inl.h
backprop/backward.cc
backprop/backward.h
backprop/common_scalar.h
backprop/forward_scalar.h
backprop/forward-inl.h
backprop/forward.cc
backprop/forward.h
backprop/optimizer.cc
backprop/optimizer.h
compression/compress-inl.h
compression/compress.cc
compression/compress.h
@@ -202,8 +190,6 @@ enable_testing()
include(GoogleTest)
set(GEMMA_TEST_FILES
backprop/backward_test.cc
backprop/optimize_test.cc
compression/compress_test.cc
compression/distortion_test.cc
compression/nuq_test.cc

View File

@@ -173,9 +173,6 @@ more custom you can call transformer which performs a single inference operation
on a single token and mutates the Activations and the KVCache through the neural
network computation.
Note that an experimental backward pass is available in backprop/, which may be
useful for fine tuning.
### For low level operations, defining new architectures, call `ops.h` functions directly
You use `ops.h` if you're writing other NN architectures or modifying the

View File

@@ -1,87 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
#include <stddef.h>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "util/mat.h" // MatStorageT
namespace gcpp {
template <typename T>
struct ForwardLayer {
ForwardLayer(const LayerConfig& config, size_t seq_len)
: input(MakePacked<T>("input", seq_len, config.model_dim)),
pre_att_rms_out(
MakePacked<T>("pre_att_rms_out", seq_len, config.model_dim)),
qkv(MakePacked<T>("qkv", seq_len * (config.heads + 2), config.qkv_dim)),
att(MakePacked<T>("att", seq_len * config.heads, seq_len)),
att_out(
MakePacked<T>("att_out", seq_len * config.heads, config.qkv_dim)),
att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
attention_out(
MakePacked<T>("attention_out", seq_len, config.model_dim)),
pre_ffw_rms_out(
MakePacked<T>("preFF_rms_out", seq_len, config.model_dim)),
ffw_hidden(
MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
ffw_hidden_gated(
MakePacked<T>("ffw_hidden_gated", seq_len, config.ff_hidden_dim)),
layer_config(config) {}
MatStorageT<T> input;
MatStorageT<T> pre_att_rms_out;
MatStorageT<T> qkv;
MatStorageT<T> att;
MatStorageT<T> att_out;
MatStorageT<T> att_post1;
MatStorageT<T> attention_out;
MatStorageT<T> pre_ffw_rms_out;
MatStorageT<T> ffw_hidden;
MatStorageT<T> ffw_hidden_gated;
const LayerConfig& layer_config;
};
template <typename T>
struct ForwardPass {
ForwardPass(const ModelConfig& config)
: final_layer_output(
MakePacked<T>("fin_layer_out", config.seq_len, config.model_dim)),
final_norm_output(
MakePacked<T>("fin_norm_out", config.seq_len, config.model_dim)),
logits(MakePacked<T>("logits", config.seq_len, config.vocab_size)),
probs(MakePacked<T>("probs", config.seq_len, config.vocab_size)),
weights_config(config) {
for (const auto& layer_config : config.layer_configs) {
layers.emplace_back(layer_config, config.seq_len);
}
}
std::vector<ForwardLayer<T>> layers;
MatStorageT<T> final_layer_output;
MatStorageT<T> final_norm_output;
MatStorageT<T> logits;
MatStorageT<T> probs;
const ModelConfig& weights_config;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_

View File

@@ -1,407 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Implementation of the Vector-Jacobian Products (VJP) of the individual
// operations of the forward pass.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#include <stddef.h>
#include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/configs.h" // LayerConfig, ModelConfig
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t cols, size_t rows, size_t num_tokens,
float* HWY_RESTRICT grad_w, // kRows * kCols,
float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t voffs = pos * rows;
const size_t xoffs = pos * cols;
for (size_t j = 0; j < rows; ++j) {
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
}
}
}
HWY_INLINE void MultiHeadMatMulVJP(
const float* HWY_RESTRICT weights, // heads * kRows * kCols
const float* HWY_RESTRICT x, // num_tokens * heads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t heads, size_t cols, size_t rows, size_t num_tokens,
float* HWY_RESTRICT grad_w, // heads * kRows * kCols
float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t j = 0; j < rows; ++j) {
for (size_t h = 0; h < heads; ++h) {
MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols],
&grad_w[h * rows * cols + j * cols], cols);
MulByConstAndAdd(v[pos * rows + j],
&weights[h * rows * cols + j * cols],
&grad_x[pos * heads * cols + h * cols], cols);
}
}
}
}
template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
const hn::Vec<D> kOne = hn::Set(d, 1.0f);
// kSqrtOverPi*3*kMul
const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);
const hn::Vec<D> v2 = hn::Mul(v, v);
const hn::Vec<D> v3 = hn::Mul(v2, v);
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
const hn::Vec<D> tanh = hn::Tanh(d, arg);
const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
}
static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto offset =
hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
hn::Transform1(
d, backward, size, forward,
[&offset](const auto d, const auto v, const auto y)
HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP(
const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
const float ss = detail::RMSNormMul(x + offset, model_dim);
for (size_t i = 0; i < model_dim; ++i) {
grad_w[i] += v[offset + i] * x[offset + i] * ss;
}
const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
float tmp = 0.0f;
for (size_t i = 0; i < model_dim; ++i) {
tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
}
tmp *= ss3;
for (size_t i = 0; i < model_dim; ++i) {
grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
tmp * x[offset + i];
}
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP(
const float* weights, const std::vector<int>& prompt, const float scaling,
const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) {
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
MulByConstAndAdd(scaling, v + pos * model_dim,
grad + token * model_dim, model_dim);
}
}
template <typename T>
void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<float>& forward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
const MatStorageT<float>& inv_timescale, hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim;
const size_t qkv_dim = config.qkv_dim;
const size_t heads = config.heads;
const size_t seq_len = forward.input.Rows();
const size_t ff_hidden_dim = config.ff_hidden_dim;
const float query_scale =
static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
HWY_ASSERT(num_tokens <= seq_len);
MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(),
next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * ff_hidden_dim * 2;
const float* HWY_RESTRICT f_out =
forward.ffw_hidden.Packed() + hidden_offset;
const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
const float* HWY_RESTRICT b_out_gated =
backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim;
float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset;
float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
DF df;
for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
const auto y = Load(df, f_out + i);
const auto x = Load(df, f_out_mul + i);
const auto v = Load(df, b_out_gated + i);
hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
}
}
MatMulVJP(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2,
num_tokens, grad.gating_einsum_w.Packed(),
backward.pre_ffw_rms_out.Packed(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.Packed(),
forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(),
backward.attention_out.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * model_dim,
backward.attention_out.Packed() + pos * model_dim, model_dim);
}
ZeroInit(backward.qkv);
MultiHeadMatMulVJP(
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens,
grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool);
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
const float* HWY_RESTRICT b_att_out =
backward.att_out.Packed() + (pos * heads + head) * qkv_dim;
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs;
float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs;
b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
}
}
}
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
}
}
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
const size_t aoffs = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs;
const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs;
float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs;
float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs;
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
}
}
}
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv =
backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
Rope(b_kv, qkv_dim, inv_timescale.PackedScale1(), -pos);
}
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT b_q =
backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
MulByConst(query_scale, b_q, qkv_dim);
Rope(b_q, qkv_dim, inv_timescale.PackedScale1(), -pos);
}
}
MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens,
grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(),
pool);
RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
backward.pre_att_rms_out.Packed(), model_dim, num_tokens,
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(backward.attention_out.Packed() + pos * model_dim,
backward.input.Packed() + pos * model_dim, model_dim);
}
}
static HWY_NOINLINE void SoftcapVJP(const float cap,
const float* HWY_RESTRICT forward,
float* HWY_RESTRICT backward,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
const auto one = hn::Set(d, 1.0f);
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
hn::Transform1(d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y); // = tanh
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
});
}
static HWY_NOINLINE void CrossEntropyLossGrad(
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
const Prompt& prompt, size_t vocab_size) {
HWY_ASSERT(!prompt.tokens.empty());
const float scaling = -1.0 / std::log(2.0);
size_t num_tokens = prompt.tokens.size() - 1;
hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
if (pos + 1 < prompt.context_size) {
continue;
}
const int next_token = prompt.tokens[pos + 1];
grad[pos * vocab_size + next_token] =
scaling / x[pos * vocab_size + next_token];
}
}
template <typename T>
void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
const ModelWeightsPtrs<T>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<T>& grad,
ForwardPass<float>& backward,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config;
const size_t kVocabSize = config.vocab_size;
const size_t model_dim = config.model_dim;
const size_t kLayers = config.layer_configs.size();
const float kEmbScaling = EmbeddingScaling(model_dim);
HWY_ASSERT(!config.absolute_pe);
HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
HWY_DASSERT(prompt.context_size > 0);
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
const size_t num_tokens = prompt.tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
kVocabSize);
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize,
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
}
if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize,
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
}
}
MatMulVJP(weights.embedder_input_embedding.Packed(),
forward.final_norm_output.Packed(), backward.logits.Packed(),
model_dim, kVocabSize, num_tokens,
grad.embedder_input_embedding.Packed(),
backward.final_norm_output.Packed(), pool);
RMSNormVJP(weights.final_norm_scale.Packed(),
forward.final_layer_output.Packed(),
backward.final_norm_output.Packed(), model_dim, num_tokens,
grad.final_norm_scale.Packed(),
backward.final_layer_output.Packed(), pool);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
auto layer_config = config.layer_configs[layer];
// TODO(szabadka) Implement Griffin layer vjp.
HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.Packed()
: backward.final_layer_output.Packed();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer],
inv_timescale, pool);
}
InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens,
kEmbScaling, backward.layers[0].input.Packed(),
grad.embedder_input_embedding.Packed(), model_dim);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT
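For reference, the two identities at the core of the file above, written out. MatMulVJP is the vector-Jacobian product of y = W x for each token position, with upstream gradient v = dL/dy, and SoftmaxVJP back-propagates through y = softmax(x):

\[
\frac{\partial L}{\partial W} \mathrel{+}= v\,x^{\top}, \qquad
\frac{\partial L}{\partial x} = W^{\top} v, \qquad
\frac{\partial L}{\partial x_i} = y_i \Bigl( v_i - \sum_j y_j v_j \Bigr).
\]

The weight gradient (grad_w) is accumulated across token positions, the input gradient (grad_x) is written per position, and the sum in the softmax expression is the offset term computed with hn::Dot.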

View File

@@ -1,72 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/backward.h"
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "backprop/backward-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void CrossEntropyLossBackwardPassT(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
inv_timescale, pool);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossBackwardPassT);
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
prompt, weights, forward, grad, backward, inv_timescale, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@@ -1,37 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_

View File

@@ -1,352 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
namespace gcpp {
template<typename T>
void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t M, size_t K) {
memset(dx, 0, M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M);
}
}
}
template<typename T>
void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t H, size_t N, size_t M, size_t K) {
memset(dx, 0, H * M * K * sizeof(dx[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
for (size_t h = 0; h < H; ++h) {
MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M],
&dw[h * N * M + j * M], M);
MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M],
&dx[i * H * M + h * M], M);
}
}
}
}
template<typename T>
void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
constexpr T eps(1e-6);
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; ++j) {
dw[j] += dy[i * N + j] * x[i * N + j] * ss;
}
const T ss3 = ss * ss * ss / T(N);
T tmp = 0.0;
for (size_t j = 0; j < N; ++j) {
tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j];
}
tmp *= ss3;
for (size_t j = 0; j < N; ++j) {
dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j];
}
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += y[i] * dy[i];
}
for (size_t i = 0; i < N; ++i) {
dy[i] = y[i] * (dy[i] - sum);
}
}
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
SoftmaxVJPT(y + i * N, dy + i * N, N);
}
}
template<typename T>
T GeluDerivative(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;
const T x2 = x * x;
const T x3 = x2 * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T tanh = std::tanh(arg);
const T cdf = T(0.5) * (T(1.0) + tanh);
const T dtanh = T(1.0) - tanh * tanh;
const T darg = kMul2 * x2 + kSqrt2OverPi;
return T(0.5) * x * dtanh * darg + cdf;
}
template<typename T>
void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
const T* v = d_out + i * N;
T* dx1 = d_in + i * 2 * N;
T* dx2 = dx1 + N;
for (size_t j = 0; j < N; ++j) {
dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]);
dx2[j] = v[j] * Gelu(x1[j]);
}
}
}
template <typename T>
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
size_t num_tokens, size_t kHeads, size_t qkv_dim,
size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * (kHeads + 2) * qkv_dim;
memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim;
const size_t aoffs = head * seq_len + pos * kHeads * seq_len;
const T* q = qkv + qoffs;
const T* dout = doutput + aoffs;
T* dq = dqkv + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim;
const T* k = qkv + koffs;
T* dk = dqkv + koffs;
MulByConstAndAddT(dout[pos2], k, dq, qkv_dim);
MulByConstAndAddT(dout[pos2], q, dk, qkv_dim);
}
}
}
}
template <typename T>
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads,
size_t seq_len) {
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
size_t offset = pos * kHeads * seq_len + head * seq_len;
SoftmaxVJPT(y + offset, dy + offset, pos + 1);
memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
}
}
}
template <typename T>
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
T* dqkv, T* dattention, size_t num_tokens, size_t kHeads,
size_t qkv_dim, size_t seq_len) {
auto v_offset = [&](size_t pos) {
return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim;
};
for (size_t pos = 0; pos < num_tokens; ++pos) {
memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0]));
}
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
const T* att = &attention[aoffset];
const T* dout = &doutput[offset];
T* datt = &dattention[aoffset];
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
}
}
}
}
template<typename T>
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
const T* dy, T* dw, size_t N) {
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i];
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
}
}
template <typename T>
void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<T>& forward, const T* dy,
LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
size_t num_tokens) {
const LayerConfig& layer_config = weights.layer_config;
const size_t model_dim = layer_config.model_dim;
const size_t seq_len = forward.input.Rows();
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kHeads = layer_config.heads;
const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
MatMulVJPT(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), dy,
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(),
model_dim, kFFHiddenDim, num_tokens);
GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
backward.ffw_hidden.Packed(), grad.gating_einsum_w.Packed(),
backward.pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
num_tokens);
RMSNormVJPT(weights.pre_ffw_norm_scale.Packed(),
forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(),
model_dim, num_tokens);
AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);
MultiHeadMatMulVJPT(
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
backward.attention_out.Packed(), grad.attn_vec_einsum_w.Packed(),
backward.att_out.Packed(), kHeads, model_dim, qkv_dim, num_tokens);
MixByAttentionVJP(forward.qkv.Packed(), forward.att.Packed(),
backward.att_out.Packed(), backward.qkv.Packed(),
backward.att.Packed(), num_tokens, kHeads, qkv_dim,
seq_len);
MaskedSoftmaxVJPT(forward.att.Packed(), backward.att.Packed(), num_tokens,
kHeads, seq_len);
MaskedAttentionVJP(forward.qkv.Packed(), backward.att.Packed(),
backward.qkv.Packed(), num_tokens, kHeads, qkv_dim,
seq_len);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
}
for (int pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * qkv_dim, qkv_dim, -pos);
}
}
MatMulVJPT(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
backward.qkv.Packed(), grad.qkv_einsum_w.Packed(),
backward.pre_att_rms_out.Packed(), (kHeads + 2) * qkv_dim,
model_dim, num_tokens);
RMSNormVJPT(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
backward.pre_att_rms_out.Packed(),
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
model_dim, num_tokens);
AddFromT(backward.attention_out.Packed(), backward.input.Packed(),
num_tokens * model_dim);
}
template <typename T>
void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
const T inv_cap = T{1.0} / static_cast<T>(cap);
for (size_t i = 0; i < N; ++i) {
T scaled = y[i] * inv_cap; // tanh
dy[i] *= (T{1.0} - scaled * scaled);
}
}
template<typename T>
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
T scaling = -1.0 / std::log(2.0);
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
memset(dx, 0, V * num_tokens * sizeof(x[0]));
for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) {
continue;
}
const int next_token = tokens[i + 1];
dx[i * V + next_token] = scaling / x[i * V + next_token];
}
}
template <typename T>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ModelWeightsPtrs<T>& weights,
const ForwardPass<T>& forward,
ModelWeightsPtrs<T>& grad,
ForwardPass<T>& backward) {
const ModelConfig& config = weights.weights_config;
const size_t model_dim = config.model_dim;
const size_t vocab_size = config.vocab_size;
const size_t layers = config.layer_configs.size();
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
vocab_size);
SoftmaxVJPT(forward.probs.Packed(), backward.logits.Packed(), vocab_size,
num_tokens);
if (config.final_cap > 0.0f) {
for (size_t i = 0; i < num_tokens; ++i) {
SoftcapVJPT(config.final_cap, forward.logits.Packed() + i * vocab_size,
backward.logits.Packed() + i * vocab_size, vocab_size);
}
}
MatMulVJPT(weights.embedder_input_embedding.Packed(),
forward.final_norm_output.Packed(), backward.logits.Packed(),
grad.embedder_input_embedding.Packed(),
backward.final_norm_output.Packed(), vocab_size, model_dim,
num_tokens);
RMSNormVJPT(
weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(),
backward.final_norm_output.Packed(), grad.final_norm_scale.Packed(),
backward.final_layer_output.Packed(), model_dim, num_tokens);
for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
T* next_layer_grad = layer + 1 < layers
? backward.layers[layer + 1].input.Packed()
: backward.final_layer_output.Packed();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
}
const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbeddingVJPT(weights.embedder_input_embedding.Packed(), tokens,
kEmbScaling, backward.layers[0].input.Packed(),
grad.embedder_input_embedding.Packed(), model_dim);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
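GeluDerivative above, like DGelu in backprop/backward-inl.h, differentiates the tanh approximation of GELU implemented by Gelu():

\[
\mathrm{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh g(x)\bigr), \qquad
g(x) = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr),
\]
\[
\mathrm{gelu}'(x) = \tfrac{1}{2}\bigl(1 + \tanh g(x)\bigr)
  + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} g(x)\bigr)\,g'(x), \qquad
g'(x) = \sqrt{2/\pi}\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr),
\]

which is cdf + 0.5 * x * dtanh * darg in the code; the constant kMul2 (0.1070322244) is 3 * 0.044715 * sqrt(2/pi).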

View File

@@ -1,250 +0,0 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <complex>
#include <cstdlib> // std::abs
#include <random>
#include <vector>
#include "backprop/activations.h"
#include "backprop/common_scalar.h" // DotT
#include "backprop/forward_scalar.h" // MatMulT
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gemma/configs.h"
#include "ops/ops.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "backprop/forward-inl.h"
#include "ops/ops-inl.h"
// 'include guard' so we only define this once. Note that HWY_ONCE is only
// defined during the last pass, but this is used in each pass.
#ifndef BACKWARD_TEST_ONCE
#define BACKWARD_TEST_ONCE
// TestEndToEnd is slow, so only run it for the best-available target.
static int run_once;
#endif
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
hwy::ThreadPool& ThreadHostileGetPool() {
// Assume this is only called at the top level, i.e. not in a thread. Then we
// can safely call `SetArgs` only once, because it would assert otherwise.
// This is preferable to calling `ThreadHostileInvalidate`, because we would
// repeat the topology initialization for every test.
if (!ThreadingContext::IsInitialized()) {
gcpp::ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 8;
threading_args.pin = Tristate::kFalse;
ThreadingContext::SetArgs(threading_args);
}
return ThreadingContext::Get().pools.Pool();
}
void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
hwy::ThreadPool& pool = ThreadHostileGetPool();
std::mt19937 gen(42);
auto weights = MakePacked<float>("weights", kRows, kCols);
auto x = MakePacked<float>("x", kTokens, kCols);
auto dy = MakePacked<float>("dy", kTokens, kRows);
auto grad = MakePacked<float>("grad", kRows, kCols);
auto dx = MakePacked<float>("dx", kTokens, kCols);
using TC = std::complex<double>;
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
kTokens);
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
};
ZeroInit(grad);
MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens,
grad.Packed(), dx.Packed(), pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
}
}
void TestMultiHeadMatMulVJP() {
static const size_t kRows = 2;
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
hwy::ThreadPool& pool = ThreadHostileGetPool();
std::mt19937 gen(42);
auto weights = MakePacked<float>("weights", kRows, kCols * kHeads);
auto x = MakePacked<float>("x", kTokens, kCols * kHeads);
auto grad = MakePacked<float>("grad", kRows, kCols * kHeads);
auto dx = MakePacked<float>("dx", kTokens, kCols * kHeads);
auto dy = MakePacked<float>("dy", kTokens, kRows);
using TC = std::complex<double>;
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
kRows, kCols, kTokens);
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
};
ZeroInit(grad);
MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols,
kRows, kTokens, grad.Packed(), dx.Packed(), pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
}
}
void TestRMSNormVJP() {
static const size_t K = 2;
static const size_t N = 64;
hwy::ThreadPool& pool = ThreadHostileGetPool();
std::mt19937 gen(42);
auto weights = MakePacked<float>("weights", N, 1);
auto x = MakePacked<float>("x", K, N);
auto grad = MakePacked<float>("grad", N, 1);
auto dx = MakePacked<float>("dx", K, N);
auto dy = MakePacked<float>("dy", K, N);
using TC = std::complex<double>;
auto c_weights = MakePacked<TC>("c_weights", N, 1);
auto c_x = MakePacked<TC>("c_x", K, N);
auto c_y = MakePacked<TC>("c_y", K, N);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
RandInit(x, 1.0f * (1 << iter), gen);
RandInit(dy, 1.0f, gen);
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
return DotT(dy.Packed(), c_y.Packed(), K * N);
};
ZeroInit(grad);
RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(),
dx.Packed(), pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
}
}
void TestEndToEnd() {
if (++run_once > 1) return; // ~3 min on SKX, only run best available target
std::mt19937 gen(42);
hwy::ThreadPool& pool = ThreadHostileGetPool();
ModelConfig config(Model::GEMMA_TINY, Type::kF32, PromptWrapping::GEMMA_IT);
WeightsWrapper<float> weights(config);
WeightsWrapper<float> grad(config);
ForwardPass<float> forward0(config);
ForwardPass<float> forward1(config);
ForwardPass<float> backward(config);
using TC = std::complex<double>;
WeightsWrapper<TC> c_weights(config);
ForwardPass<TC> c_forward(config);
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
MatStorageT<float> inv_timescale = CreateInvTimescale(
ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
weights.get().RandInit(1.0f, gen);
float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
float loss1 = CrossEntropyLossForwardPass(
prompt.tokens, prompt.context_size, weights.get(), forward1,
inv_timescale, pool);
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.get().ZeroInit();
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
backward, inv_timescale, pool);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
};
TestGradient(grad.get(), c_weights.get(), func, 2e-3f, __LINE__);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(BackwardTest);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
HWY_AFTER_TEST();
} // namespace gcpp
#endif
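The tests above Complexify the weights and inputs, evaluate the loss over std::complex<double>, and compare analytic gradients via TestGradient (defined in backprop/test_util.h, not shown in this diff). That setup is consistent with complex-step differentiation. A self-contained C++ sketch of the technique, independent of gemma.cpp:

#include <cmath>
#include <complex>
#include <cstdio>

// Complex-step differentiation: for an analytic f, f'(x) ~= Im(f(x + i*h)) / h.
// Unlike finite differences there is no subtractive cancellation, so h can be
// extremely small and the result is accurate to roundoff.
template <typename F>
double ComplexStepDerivative(F f, double x, double h = 1e-30) {
  return f(std::complex<double>(x, h)).imag() / h;
}

int main() {
  // d/dx [x^2 sin x] = 2x sin x + x^2 cos x.
  auto f = [](std::complex<double> x) { return x * x * std::sin(x); };
  const double x = 0.7;
  const double analytic = 2.0 * x * std::sin(x) + x * x * std::cos(x);
  std::printf("complex-step %.12f vs analytic %.12f\n",
              ComplexStepDerivative(f, x), analytic);
  return 0;
}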

View File

@@ -1,123 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#include <stddef.h>
#include <complex>
#include "util/mat.h"
namespace gcpp {
template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
U sum = {};
for (size_t i = 0; i < N; ++i) {
sum += a[i] * b[i];
}
return sum;
}
template<>
inline std::complex<double> DotT(const float* a, const std::complex<double>* b,
size_t N) {
std::complex<double> sum = {};
for (size_t i = 0; i < N; ++i) {
sum += static_cast<double>(a[i]) * b[i];
}
return sum;
}
template<typename T>
void MulByConstT(T c, T* x, size_t N) {
for (size_t i = 0; i < N; ++i) {
x[i] *= c;
}
}
// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += c * x[i];
}
}
template <typename T>
void MulByConstAndAddT(T c, const MatPtrT<T>& x, MatPtrT<T>& out) {
for (size_t r = 0; r < x.Rows(); ++r) {
MulByConstAndAddT(c, x.Row(r), out.Row(r), x.Cols());
}
}
template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += a[i];
}
}
template<typename T>
T SquaredL2(const T* x, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += x[i] * x[i];
}
return sum;
}
template<typename T>
T Gelu(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
const T x3 = x * x * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
return x * cdf;
}
template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
const size_t N2 = N / 2;
for (size_t dim = 0; dim < N2; ++dim) {
const T freq_exponents = T(2 * dim) / T(N);
const T timescale = std::pow(base, freq_exponents);
const T theta = T(i) / timescale;
const T cos_val = std::cos(theta);
const T sin_val = std::sin(theta);
const T x0 = x[dim];
const T x1 = x[dim + N2];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + N2] = x0 * sin_val + x1 * cos_val;
}
}
template<typename T>
void Rope(T* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
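The scalar Rope() above applies rotary position embedding: each pair (x_d, x_{d+N/2}) is rotated by an angle that depends on the position i, the dimension d, and the base b (10000 by default here):

\[
\theta_d = \frac{i}{b^{\,2d/N}}, \qquad
\begin{pmatrix} x_d' \\ x_{d+N/2}' \end{pmatrix}
=
\begin{pmatrix} \cos\theta_d & -\sin\theta_d \\ \sin\theta_d & \cos\theta_d \end{pmatrix}
\begin{pmatrix} x_d \\ x_{d+N/2} \end{pmatrix},
\qquad d = 0, \dots, N/2 - 1.
\]

The backward passes call Rope with -pos, rotating by the opposite angle to undo the forward rotation.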

View File

@@ -1,297 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <typename T>
void InputEmbedding(const MatPtrT<T>& weights, const std::vector<int>& prompt,
const float scaling, float* HWY_RESTRICT output,
size_t model_dim, size_t vocab_size) {
const hn::ScalableTag<float> df;
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
const auto span = weights.Span();
HWY_ASSERT(span.num == model_dim * vocab_size);
DecompressAndZeroPad(df, span, token * model_dim, output + pos * model_dim,
model_dim);
MulByConst(scaling, output + pos * model_dim, model_dim);
}
}
template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
size_t model_dim, size_t num_tokens,
OutT* HWY_RESTRICT output,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
RMSNorm(x + offset, weights, 0, output + offset, model_dim);
}
}
static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
const std::vector<int>& prompt,
size_t context_size,
size_t vocab_size,
hwy::ThreadPool& pool) {
HWY_ASSERT(!prompt.empty());
float loss = 0.0f;
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
if (pos + 1 < context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = prompt[pos + 1];
loss += std::log(probs[pos * vocab_size + next_token]);
}
float scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template <typename T>
void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<float>& activations, size_t num_tokens,
float* HWY_RESTRICT output,
const MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim;
const size_t kSeqLen = activations.input.Rows();
const size_t kQKVDim = config.qkv_dim;
const size_t kHeads = config.heads;
static const float query_scale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen);
ApplyRMSNorm(weights.pre_attention_norm_scale.Packed(),
activations.input.Packed(), model_dim, num_tokens,
activations.pre_att_rms_out.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
activations.pre_att_rms_out.Packed() + pos * model_dim,
activations.qkv.Packed() + pos * (kHeads + 2) * kQKVDim, pool);
}
const size_t num_tasks = kHeads * num_tokens;
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k =
activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, inv_timescale.PackedScale1(), pos);
}
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT q =
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, inv_timescale.PackedScale1(), pos);
MulByConst(query_scale, q, kQKVDim);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT q =
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
float* HWY_RESTRICT head_att =
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const float* HWY_RESTRICT k2 =
activations.qkv.Packed() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
}
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT head_att =
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
Softmax(head_att, pos + 1);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT head_att =
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
float* HWY_RESTRICT att_out =
activations.att_out.Packed() + (pos * kHeads + head) * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
float* HWY_RESTRICT v2 = activations.qkv.Packed() +
(pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
}
});
ZeroInit(activations.attention_out);
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
MatVec(weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
kQKVDim,
activations.att_out.Packed() + pos * kHeads * kQKVDim +
head * kQKVDim,
activations.att_post1.Packed() + pos * model_dim, pool);
AddFrom(activations.att_post1.Packed() + pos * model_dim,
activations.attention_out.Packed() + pos * model_dim, model_dim);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.input.Packed() + pos * model_dim,
activations.attention_out.Packed() + pos * model_dim, model_dim);
}
ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(), model_dim, num_tokens,
activations.pre_ffw_rms_out.Packed(), pool);
const size_t kFFHiddenDim = config.ff_hidden_dim;
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
activations.pre_ffw_rms_out.Packed() + pos * model_dim,
activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2;
const float* HWY_RESTRICT out =
activations.ffw_hidden.Packed() + hidden_offset;
const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
float* HWY_RESTRICT out_gated =
activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
const auto y = hn::Load(df, out + i);
const auto x = hn::Load(df, out_mul + i);
hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim,
output + pos * model_dim, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.attention_out.Packed() + pos * model_dim,
output + pos * model_dim, model_dim);
}
}
template <typename T>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,
const ModelWeightsPtrs<T>& weights,
ForwardPass<float>& forward,
const MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config;
const size_t vocab_size = config.vocab_size;
const size_t model_dim = config.model_dim;
const size_t layers = config.layer_configs.size();
const float emb_scaling = EmbeddingScaling(model_dim);
HWY_ASSERT(!config.absolute_pe);
HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
HWY_DASSERT(context_size > 0);
HWY_DASSERT(context_size < prompt.size());
const size_t num_tokens = prompt.size() - 1;
InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
forward.layers[0].input.Packed(), model_dim, vocab_size);
for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
auto type = config.layer_configs[layer].type;
// TODO(szabadka) Implement Griffin layer.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* HWY_RESTRICT output = layer + 1 < layers
? forward.layers[layer + 1].input.Packed()
: forward.final_layer_output.Packed();
ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
num_tokens, output, inv_timescale, pool);
}
ApplyRMSNorm(weights.final_norm_scale.Packed(),
forward.final_layer_output.Packed(), model_dim, num_tokens,
forward.final_norm_output.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
forward.final_norm_output.Packed() + pos * model_dim,
forward.logits.Packed() + pos * vocab_size, pool);
}
if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(config.final_cap,
forward.logits.Packed() + pos * vocab_size, vocab_size);
}
}
CopyMat(forward.logits, forward.probs);
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.Packed() + pos * vocab_size, vocab_size);
}
return CrossEntropyLoss(forward.probs.Packed(), prompt, context_size,
vocab_size, pool);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

View File

@ -1,66 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/forward.h"
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/forward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "backprop/forward-inl.h"
#include "gemma/weights.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
float CrossEntropyLossForwardPassT(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
weights, forward, inv_timescale, pool);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
prompt, weights, forward, inv_timescale, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -1,35 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

View File

@ -1,298 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <complex>
#include <vector>
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
#include "hwy/base.h"
namespace gcpp {
// w is N x M matrix in row-major order, x is M x K matrix in column-major order
// y = w * x is N x K matrix in column-major order.
template<typename T>
void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
for (size_t i = 0; i < K; ++i) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] = DotT(&w[j * M], &x[i * M], M);
}
}
}
// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in
// column-major order and y = w' * x is N x K matrix in column-major order,
// where w' is the rearrangement of w into an N x HM matrix.
template<typename T>
void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
size_t M, size_t K) {
memset(y, 0, N * K * sizeof(y[0]));
for (size_t i = 0; i < K; ++i) {
for (size_t h = 0; h < H; ++h) {
for (size_t j = 0; j < N; ++j) {
y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M);
}
}
}
}
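// RMSNorm with learned scale: out[j] = (1 + w[j]) * x[j] / sqrt(mean(x^2) + eps),
// applied independently to each of the K input vectors of length N.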
template<typename T>
void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
constexpr T eps(1e-6);
for (size_t i = 0; i < K; ++i) {
T ss = SquaredL2(x + i * N, N);
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
for (size_t j = 0; j < N; j++) {
out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]);
}
}
}
template<typename T>
void Softmax(T* x, size_t N) {
T sum = {};
auto maxreal = std::real(x[0]);
for (size_t i = 1; i < N; ++i) {
if (std::real(x[i]) > maxreal) {
maxreal = std::real(x[i]);
}
}
for (size_t i = 0; i < N; ++i) {
x[i] = std::exp(x[i] - maxreal);
sum += x[i];
}
T scale = T(1.0) / sum;
for (size_t i = 0; i < N; ++i) {
x[i] *= scale;
}
}
template<typename T>
void Softmax(T* x, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
Softmax(x + i * N, N);
}
}
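// Soft-caps values to the open interval (-cap, cap) via cap * tanh(x / cap),
// which is close to the identity for |x| much smaller than cap.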
template <typename T>
void Softcap(float cap, T* x, size_t N) {
const T inv_cap = T{1.0} / static_cast<T>(cap);
for (size_t i = 0; i < N; ++i) {
x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
}
}
template<typename T>
void GatedGelu(const T* in, T* out, size_t N, size_t K) {
for (size_t i = 0; i < K; ++i) {
const T* x1 = in + i * 2 * N;
const T* x2 = x1 + N;
T* y = out + i * N;
for (size_t j = 0; j < N; ++j) {
y[j] = x2[j] * Gelu(x1[j]);
}
}
}
template<typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
T* y, size_t N) {
HWY_ASSERT(w != nullptr);
HWY_ASSERT(y != nullptr);
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
int token = tokens[i];
memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
MulByConstT(scaling, y + i * N, N);
}
}
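// Causal attention logits. Per position, `qkv` stores `heads` query vectors
// followed by one key and one value vector of length qkv_dim (a single KV
// head). For every query position and head, the dot product with the keys of
// all positions <= pos is written to `output`.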
template <typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads,
size_t qkv_dim, size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < heads; ++head) {
const size_t qoffset = pos * (heads + 2) * qkv_dim;
const size_t aoffset = pos * heads * seq_len + head * seq_len;
const T* q = qkv + qoffset + head * qkv_dim;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim;
output[aoffset + pos2] = DotT(q, k, qkv_dim);
}
}
}
}
template <typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < heads; ++head) {
size_t offset = pos * heads * seq_len + head * seq_len;
Softmax(x + offset, pos + 1);
memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
}
}
}
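// Weighted sum over value vectors: for each query position and head, the
// softmaxed attention weights from above scale the value vectors of positions
// <= pos, accumulated into `output`.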
template <typename T>
void MixByAttention(const T* qkv, const T* attention, T* output,
size_t num_tokens, size_t heads, size_t qkv_dim,
size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < heads; ++head) {
const T* att = &attention[pos * heads * seq_len + head * seq_len];
T* out = &output[head * qkv_dim + pos * heads * qkv_dim];
memset(out, 0, qkv_dim * sizeof(out[0]));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
const T* v = &qkv[v_offset];
MulByConstAndAddT(att[pos2], v, out, qkv_dim);
}
}
}
}
template <typename T>
void ApplyLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<T>& activations, size_t num_tokens, T* output) {
const LayerConfig& layer_config = weights.layer_config;
const size_t model_dim = layer_config.model_dim;
const size_t seq_len = activations.input.Rows();
const size_t qkv_dim = layer_config.qkv_dim;
const size_t heads = layer_config.heads;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));
RMSNormT(weights.pre_attention_norm_scale.Packed(),
activations.input.Packed(), activations.pre_att_rms_out.Packed(),
model_dim, num_tokens);
MatMulT(weights.qkv_einsum_w.Packed(), activations.pre_att_rms_out.Packed(),
activations.qkv.Packed(), (heads + 2) * qkv_dim, model_dim,
num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
for (size_t h = 0; h <= heads; ++h) {
Rope(qkv + h * qkv_dim, qkv_dim, pos);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
MulByConstT(query_scale, qkv, heads * qkv_dim);
}
MaskedAttention(activations.qkv.Packed(), activations.att.Packed(),
num_tokens, heads, qkv_dim, seq_len);
MaskedSoftmax(activations.att.Packed(), num_tokens, heads, seq_len);
MixByAttention(activations.qkv.Packed(), activations.att.Packed(),
activations.att_out.Packed(), num_tokens, heads, qkv_dim,
seq_len);
MultiHeadMatMul(weights.attn_vec_einsum_w.Packed(),
activations.att_out.Packed(),
activations.attention_out.Packed(), heads, model_dim, qkv_dim,
num_tokens);
AddFromT(activations.input.Packed(), activations.attention_out.Packed(),
num_tokens * model_dim);
RMSNormT(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(),
activations.pre_ffw_rms_out.Packed(), model_dim, num_tokens);
MatMulT(weights.gating_einsum_w.Packed(),
activations.pre_ffw_rms_out.Packed(), activations.ffw_hidden.Packed(),
ff_hidden_dim * 2, model_dim, num_tokens);
GatedGelu(activations.ffw_hidden.Packed(),
activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);
MatMulT(weights.linear_w.Packed(), activations.ffw_hidden_gated.Packed(),
output, model_dim, ff_hidden_dim, num_tokens);
AddFromT(activations.attention_out.Packed(), output, num_tokens * model_dim);
}
template<typename T>
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
T loss = {};
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
for (size_t i = 0; i < num_tokens; ++i) {
if (i + 1 < prompt.context_size) {
continue; // next token is part of context, don't try to predict it
}
const int next_token = tokens[i + 1];
loss += std::log(x[i * V + next_token]);
}
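// The accumulated log-probabilities are in nats; negating and dividing by
// ln(2) converts the result to bits.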
T scaling = -1.0 / std::log(2.0);
return loss * scaling;
}
template <typename T>
T CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<T>& weights,
ForwardPass<T>& forward) {
const ModelConfig& config = weights.weights_config;
const size_t model_dim = config.model_dim;
const size_t vocab_size = config.vocab_size;
const size_t layers = config.layer_configs.size();
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbedding(weights.embedder_input_embedding.Packed(), tokens, kEmbScaling,
forward.layers[0].input.Packed(), model_dim);
for (size_t layer = 0; layer < layers; ++layer) {
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.Packed()
: forward.final_layer_output.Packed();
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
output);
}
RMSNormT(weights.final_norm_scale.Packed(),
forward.final_layer_output.Packed(),
forward.final_norm_output.Packed(), model_dim, num_tokens);
MatMulT(weights.embedder_input_embedding.Packed(),
forward.final_norm_output.Packed(), forward.logits.Packed(),
vocab_size, model_dim, num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
if (config.final_cap > 0.0f) {
Softcap(config.final_cap, forward.logits.Packed() + pos * vocab_size,
vocab_size);
}
}
CopyMat(forward.logits, forward.probs);
Softmax(forward.probs.Packed(), vocab_size, num_tokens);
return CrossEntropyLoss(forward.probs.Packed(), prompt, vocab_size);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_

View File

@ -1,154 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "backprop/activations.h"
#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "compression/types.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "ops/ops.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
TEST(OptimizeTest, GradientDescent) {
gcpp::ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.pin = Tristate::kFalse;
ThreadingContext::SetArgs(threading_args);
MatMulEnv env(ThreadingContext::Get());
const Allocator& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool();
std::mt19937 gen(42);
ModelConfig config(Model::GEMMA_TINY, Type::kF32,
ChooseWrapping(Model::GEMMA_TINY));
config.eos_id = ReverseSequenceSampler::kEndToken;
WeightsOwner grad(Type::kF32), grad_m(Type::kF32), grad_v(Type::kF32);
grad.AllocateForTest(config, pool);
grad_m.AllocateForTest(config, pool);
grad_v.AllocateForTest(config, pool);
grad_m.ZeroInit();
grad_v.ZeroInit();
ForwardPass<float> forward(config), backward(config);
KVCache kv_cache(config, /*prefill_tbatch_size=*/16);
MatStorageT<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
Gemma gemma(config, GemmaTokenizer(kMockTokenizer), env);
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
auto stream_token = [&reply](int token, float) {
reply.push_back(token);
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
.max_generated_tokens = 16,
.temperature = 1.0f,
.gen = &gen,
.verbosity = 0,
.stream_token = stream_token,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
return reply;
};
// Sanity checks of the reply tokens:
// 1) The reply should be longer than the prompt.
// 2) The prompt should be a prefix of the reply.
auto verify = [&](const Prompt& prompt) {
const std::vector<int>& context = prompt.context();
std::vector<int> reply = generate(context);
if (reply.size() <= context.size()) return false;
return std::equal(context.begin(), context.end(), reply.begin(),
reply.begin() + context.size());
};
gemma.MutableWeights().RandInit(1.0f, gen);
gemma.MutableWeights().Fixup(pool);
printf("Initial weights:\n");
gemma.MutableWeights().LogWeightStatsF32();
constexpr size_t kBatchSize = 8;
constexpr float kAlpha = 0.001f;
constexpr float kBeta1 = 0.9f;
constexpr float kBeta2 = 0.999f;
constexpr float kEpsilon = 1e-8f;
constexpr float kMaxLoss = 20.0f;
ReverseSequenceSampler training_task({
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
size_t steps = 0;
size_t num_ok;
for (; steps < 1000; ++steps) {
std::mt19937 sgen(42);
grad.ZeroInit();
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass(
prompt, *gemma.Weights().GetF32(), forward, inv_timescale, pool);
CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward,
*grad.GetF32(), backward, inv_timescale,
pool);
gemma.MutableWeights().Fixup(pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;
AdamUpdate(grad, kAlpha, kBeta1, kBeta2, kEpsilon, steps + 1,
gemma.Weights(), grad_m, grad_v, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) {
printf("Batch gradient:\n");
grad.LogWeightStatsF32();
}
if (total_loss < kMaxLoss) break; // Done
}
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
gemma.MutableWeights().LogWeightStatsF32();
EXPECT_LT(steps, 80);
EXPECT_EQ(num_ok, kBatchSize);
}
} // namespace gcpp

View File

@ -1,130 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "backprop/optimizer.h"
#include <cmath>
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
using MatPtrF = MatPtrT<float>;
// Split into two classes so that ForEachTensor only requires two "other"
// arguments. This is anyway useful for locality, because `grad` only feeds
// into `grad_m` and `grad_v` here.
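// Together these implement the standard Adam update:
//   m <- beta1 * m + (1 - beta1) * g,  v <- beta2 * v + (1 - beta2) * g^2
//   w <- w - alpha * m_hat / (sqrt(v_hat) + epsilon)
// with bias corrections m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t).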
class AdamUpdateMV {
public:
AdamUpdateMV(float beta1, float beta2, size_t t)
: beta1_(beta1),
beta2_(beta2),
cbeta1_(1.0f - beta1),
cbeta2_(1.0f - beta2),
norm1_(1.0 / (1.0 - std::pow(beta1, t))),
norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}
void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) {
for (size_t r = 0; r < grad.Rows(); ++r) {
const float* HWY_RESTRICT g = grad.Row(r);
float* HWY_RESTRICT m = grad_m.Row(r);
float* HWY_RESTRICT v = grad_v.Row(r);
for (size_t c = 0; c < grad.Cols(); ++c) {
m[c] *= beta1_;
m[c] += cbeta1_ * g[c];
v[c] *= beta2_;
v[c] += cbeta2_ * g[c] * g[c];
}
}
}
private:
float beta1_;
float beta2_;
float cbeta1_;
float cbeta2_;
float norm1_;
float norm2_;
};
// Updates `weights` based on the updated `grad_m` and `grad_v` from above.
class AdamUpdateW {
public:
AdamUpdateW(float alpha, float beta1, float beta2, float epsilon, size_t t)
: alpha_(alpha),
norm1_(1.0 / (1.0 - std::pow(beta1, t))),
norm2_(1.0 / (1.0 - std::pow(beta2, t))),
epsilon_(epsilon) {}
void operator()(MatPtrF& weights, const MatPtrF& grad_m,
const MatPtrF& grad_v) {
for (size_t r = 0; r < weights.Rows(); ++r) {
float* HWY_RESTRICT w = weights.Row(r);
const float* HWY_RESTRICT m = grad_m.Row(r);
const float* HWY_RESTRICT v = grad_v.Row(r);
for (size_t c = 0; c < weights.Cols(); ++c) {
const float mhat = m[c] * norm1_;
const float vhat = v[c] * norm2_;
w[c] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
}
}
}
private:
float alpha_;
float norm1_;
float norm2_;
float epsilon_;
};
void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
float beta2, float epsilon, size_t t,
ModelWeightsPtrs<float>* weights,
ModelWeightsPtrs<float>* grad_m,
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
AdamUpdateMV update_mv(beta1, beta2, t);
grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
const MatPtrF grad_f(t.mat);
MatPtrF grad_m_f(*t.other_mat1);
MatPtrF grad_v_f(*t.other_mat2);
update_mv(grad_f, grad_m_f, grad_v_f);
});
AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
MatPtrF weights_f(t.mat);
const MatPtrF grad_m_f(*t.other_mat1);
const MatPtrF grad_v_f(*t.other_mat2);
update_w(weights_f, grad_m_f, grad_v_f);
});
}
} // namespace
void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2,
float epsilon, size_t t, const WeightsOwner& weights,
const WeightsOwner& grad_m, const WeightsOwner& grad_v,
hwy::ThreadPool& pool) {
AdamUpdate(grad.GetF32(), alpha, beta1, beta2, epsilon, t, weights.GetF32(),
grad_m.GetF32(), grad_v.GetF32(), pool);
}
} // namespace gcpp

View File

@ -1,33 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#include <stddef.h>
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2,
float epsilon, size_t t, const WeightsOwner& weights,
const WeightsOwner& grad_m, const WeightsOwner& grad_v,
hwy::ThreadPool& pool);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_

View File

@ -1,34 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#include <stddef.h>
#include <vector>
namespace gcpp {
struct Prompt {
std::vector<int> tokens;
size_t context_size;
std::vector<int> context() const {
return std::vector<int>(tokens.begin(), tokens.begin() + context_size);
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_

View File

@ -1,90 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#include <stddef.h>
#include <stdio.h>
#include <random>
#include <vector>
#include "backprop/prompt.h"
namespace gcpp {
class PromptSampler {
public:
virtual Prompt Sample(std::mt19937& gen) = 0;
virtual ~PromptSampler() = default;
std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
std::vector<Prompt> batch;
batch.reserve(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
batch.emplace_back(Sample(gen));
}
return batch;
}
};
class ReverseSequenceSampler : public PromptSampler {
public:
explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
: token_dist_(0, 9) {
for (int i = 0; i < length_histo.size(); ++i) {
const int count = length_histo[i];
for (int j = 0; j < count; ++j) {
length_lut_.push_back(i + 1);
}
}
length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
}
virtual ~ReverseSequenceSampler() = default;
static constexpr int kReverseToken = 10;
static constexpr int kEndToken = 11;
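// Example: a sampled prompt with len == 3 might be 4 7 2 --> 2 7 4 |, i.e.
// random digits, the reverse marker, the digits reversed, and the end token,
// with context_size == 4 so only the reversed half must be predicted.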
Prompt Sample(std::mt19937& gen) override {
Prompt prompt;
int len = length_lut_[length_dist_(gen)];
prompt.tokens.resize(2 * len + 2);
prompt.tokens[len] = kReverseToken;
prompt.tokens[2 * len + 1] = kEndToken;
for (size_t i = 0; i < len; ++i) {
prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
}
prompt.context_size = len + 1;
return prompt;
}
static void LogPrompt(const Prompt& prompt) {
static const char* kVocab[] = {
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
};
for (int token : prompt.tokens) printf("%s", kVocab[token]);
printf(" [context_size: %zu]\n", prompt.context_size);
}
private:
std::uniform_int_distribution<> token_dist_;
std::uniform_int_distribution<> length_dist_;
std::vector<int> length_lut_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_

View File

@ -1,199 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#include <stddef.h>
#include <cmath>
#include <complex>
#include <vector>
#include "gtest/gtest.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
template <typename T, typename U>
void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
for (size_t r = 0; r < x.Rows(); ++r) {
const T* row = x.Row(r);
std::complex<U>* c_row = c_x.Row(r);
for (size_t c = 0; c < x.Cols(); ++c) {
c_row[c] = std::complex<U>(row[c], 0.0);
}
}
}
template <typename T, typename U>
void Complexify(const LayerWeightsPtrs<T>& w, LayerWeightsPtrs<U>& c_w) {
Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
Complexify(w.linear_w, c_w.linear_w);
}
template <typename T, typename U>
void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
const size_t kLayers = w.c_layers.size();
Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
Complexify(w.final_norm_scale, c_w.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) {
Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
}
}
// Somewhat duplicates `WeightsOwner`, but that has neither double nor
// complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper {
public:
explicit WeightsWrapper(const ModelConfig& config) : weights_(config) {
hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
weights_.AllocateForTest(owners_, pool);
}
const ModelWeightsPtrs<T>& get() const { return weights_; }
ModelWeightsPtrs<T>& get() { return weights_; }
private:
std::vector<MatOwner> owners_;
ModelWeightsPtrs<T> weights_;
};
template <typename T, typename U>
void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
double max_abs_err, double max_rel_err, int line_test,
int line_util) {
// TODO: consider compensated sum.
double sum0 = 0;
double sum1 = 0;
double sum01 = 0;
for (size_t r = 0; r < actual.Rows(); ++r) {
const T* actual_row = actual.Row(r);
const U* expected_row = expected.Row(r);
for (size_t c = 0; c < actual.Cols(); ++c) {
sum0 += actual_row[c] * actual_row[c];
sum1 += expected_row[c] * expected_row[c];
sum01 += actual_row[c] * expected_row[c];
ASSERT_NEAR(
actual_row[c], expected_row[c],
std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err))
<< "test line " << line_test << " test_util.h line " << line_util
<< " r " << r << " c " << c;
}
}
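// Also require the cosine similarity sum01 / (sqrt(sum0) * sqrt(sum1)) to be
// close to 1, i.e. the two tensors must point in nearly the same direction.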
if (sum0 > 1e-16) {
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
ASSERT_NEAR(norm_dot, 1.0, 3e-6)
<< "test line " << line_test << " test_util.h line " << line_util
<< " sum0: " << sum0 << " sum1: " << sum1 << " sum01: " << sum01;
}
}
// Compute gradient with the finite difference method in the complex plane.
// If f : R->R is the tested function and F : C->C is its extension on the
// complex plane so that F is complex differentiable in x, then
//
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
//
// which means that
//
// F'(x) ~= Imag(F(x + ih)) / h
//
// This method is more numerically stable than the real-valued finite difference
// method since we don't need to subtract floating point numbers that are near
// to each other.
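// Worked example: for f(x) = x^2, F(x + ih) = x^2 - h^2 + 2ixh, so
// Imag(F(x + ih)) / h = 2x exactly; there is no subtraction of nearly equal
// values as in the real-valued quotient (f(x + h) - f(x)) / h.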
template <typename FUNC, typename T, typename U>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
FUNC func, U step, T max_abs_err, T max_rel_err,
int line_test, int line_util) {
MatStorageT<T> exp_grad = MakePacked<T>("exp_grad", x.Rows(), x.Cols());
const U inv_step = 1.0 / step;
for (size_t r = 0; r < x.Rows(); ++r) {
std::complex<U>* x_row = x.Row(r);
T* exp_row = exp_grad.Row(r);
for (size_t c = 0; c < x.Cols(); ++c) {
const U x0 = std::real(x_row[c]);
const std::complex<U> x1 = std::complex<U>(x0, step);
x_row[c] = x1;
const std::complex<U> f1 = func();
exp_row[c] = std::imag(f1) * inv_step;
x_row[c] = x0;
}
}
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line_test, line_util);
}
template <typename FUNC>
void TestGradient(const MatPtrT<float>& grad, MatPtrT<std::complex<float>>& x,
FUNC func, float max_abs_err, float max_rel_error,
int line_test, int line_util) {
TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line_test,
line_util);
}
template <typename FUNC, typename T>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
FUNC func, T max_abs_err, T max_rel_error, int line_test,
int line_util) {
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line_test,
line_util);
}
template <typename T, typename U, typename FUNC>
void TestGradient(const LayerWeightsPtrs<T>& grad,
LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err,
int line_test) {
TestGradient(grad.pre_attention_norm_scale,
c_weights.pre_attention_norm_scale, func, max_err, max_err,
line_test, __LINE__);
TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w, func,
max_err, max_err, line_test, __LINE__);
TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w, func, max_err,
max_err, line_test, __LINE__);
TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale, func,
max_err, max_err, line_test, __LINE__);
TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w, func, max_err,
max_err, line_test, __LINE__);
TestGradient(grad.linear_w, c_weights.linear_w, func, max_err, max_err,
line_test, __LINE__);
}
template <typename T, typename U, typename FUNC>
void TestGradient(const ModelWeightsPtrs<T>& grad,
ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err,
int line_test) {
TestGradient(grad.embedder_input_embedding,
c_weights.embedder_input_embedding, func, 2 * max_err, max_err,
line_test, __LINE__);
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, func, max_err,
max_err, line_test, __LINE__);
for (size_t i = 0; i < grad.c_layers.size(); ++i) {
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err,
line_test);
}
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_

View File

@ -65,8 +65,8 @@ static constexpr bool kIsTest = false;
template <typename T> // primary, must specialize
struct CompressTraits {};
// Used by backprop/, where weights are currently f32; also MatMul for f32
// weights or activations, if native `ReorderWidenMulAccumulate` is available.
// Used by MatMul for f32 weights or activations, if native
// `ReorderWidenMulAccumulate` is available.
template <>
struct CompressTraits<float> {
using Packed = float;

View File

@ -81,7 +81,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) {
}
});
Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
Compress(raw.PackedScale1(), raw.Extents().Area(), ws, compressed.Span(),
/*packed_ofs=*/0, pool);
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
return compressed;
@ -104,7 +104,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
}
});
Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
Compress(raw.PackedScale1(), raw.Extents().Area(), ws, compressed.Span(),
/*packed_ofs=*/0, pool);
// Arbitrary value, different from 1, must match `GenerateMat`.
compressed.SetScale(0.6f);

View File

@ -21,9 +21,6 @@
#include <stddef.h>
#include <stdint.h>
#include <complex>
#include <cstdio>
// IWYU pragma: begin_exports
#include "util/basics.h" // BF16
#include "hwy/aligned_allocator.h"
@ -160,24 +157,22 @@ constexpr bool IsNuqStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}
// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for
// `WeightsPtrs`. When adding a new type that is supported, also
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64 };
// Tensor types for loading weights.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "c64"};
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
"sfp", "nuq", "f64"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {0,
8 * sizeof(float),
8 * sizeof(BF16),
8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */,
8 * sizeof(double),
8 * sizeof(std::complex<double>)};
static constexpr size_t kTypeBits[] = {
0,
8 * sizeof(float),
8 * sizeof(BF16),
8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */,
8 * sizeof(double),
};
static inline bool EnumValid(Type type) {
return static_cast<size_t>(type) < kNumTypes;
@ -197,8 +192,6 @@ Type TypeEnum() {
return Type::kNUQ;
} else if constexpr (hwy::IsSame<Packed, double>()) {
return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
return Type::kC64;
} else {
HWY_DASSERT(false);
return Type::kUnknown;

View File

@ -29,7 +29,6 @@ namespace gcpp {
void Wrap(const ModelConfig& config, size_t pos, std::string& prompt);
// Returns the scale value to use for the embedding (basically sqrt model_dim).
// Also used by backprop/.
float EmbeddingScaling(size_t model_dim);
// Returns the scale value to use for the query in the attention computation.

View File

@ -152,13 +152,12 @@ static ModelConfig ConfigGemmaTiny() {
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 32;
config.vocab_size = 32; // at least two f32 vectors
config.seq_len = 32; // optimize_test requires more than 24
config.seq_len = 32;
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
config.num_layers = 2;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
// This is required for optimize_test to pass.
config.att_cap = 50.0f;
config.final_cap = 30.0f;
config.eos_id = 11;
@ -203,7 +202,6 @@ static ModelConfig ConfigGriffin2B() {
}
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
config.use_local_attention = true;
// This is required for optimize_test to pass.
config.final_cap = 0.0f;
return config;
}

View File

@ -162,7 +162,7 @@ enum class Model {
GEMMA2_9B = 3,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY, // for backprop/ only
GEMMA_TINY, // for testing only
GEMMA2_2B,
// 8 and 9 are obsolete.
PALIGEMMA2_3B_224 = 10,
@ -330,9 +330,8 @@ struct ModelConfig : public IFields {
// from a blob. Also used by `config_converter.py`, which sets sufficient
// fields for `TestEqual` and then calls `OverwriteWithCanonical()`.
ModelConfig() = default;
// For use by `backprop/`, and `model_store.cc` for pre-2025 format after
// deducing the model from tensors plus a user-specified `wrapping` override
// (see `ChooseWrapping`).
// For use by `model_store.cc` for pre-2025 format after deducing the model
// from tensors plus a user-specified `wrapping` override (`ChooseWrapping`).
ModelConfig(Model model, Type weight, PromptWrapping wrapping);
// Parses a string returned by `Specifier()`. Used by the exporter to select
// the model from command line arguments. Do not use this elsewhere - the

View File

@ -230,13 +230,13 @@ class GemmaAttention {
const float mul) {
// qk is either q or k, so qkv_dim is the length we operate on.
const size_t qkv_dim = layer_config_.qkv_dim;
const float* inv_timescale = activations_.inv_timescale.Packed();
const float* inv_timescale = activations_.inv_timescale.PackedScale1();
bool is_global_layer =
activations_.weights_config.attention_window_sizes[layer] ==
activations_.seq_len;
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer && IsVLM(activations_.weights_config.model)) {
inv_timescale = activations_.inv_timescale_global.Packed();
inv_timescale = activations_.inv_timescale_global.PackedScale1();
}
// PostQKType::Rope
(void)layer;

View File

@ -24,7 +24,6 @@
#include <string.h>
#include <memory>
#include <utility> // std::move
#include <vector>
// Placeholder for internal header, do not modify.
@ -61,16 +60,6 @@ Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
reader_.reset();
}
Gemma::Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer,
MatMulEnv& env)
: env_(env),
model_(config, std::move(tokenizer)),
weights_(config.weight),
chat_template_(model_.Tokenizer(), model_.Config().model) {
HWY_ASSERT(config.weight == Type::kF32);
weights_.AllocateForTest(config, env_.ctx.pools.Pool(0));
}
Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {

View File

@ -107,11 +107,6 @@ class Gemma {
// `env` must remain valid for the lifetime of this Gemma.
Gemma(const LoaderArgs& loader, MatMulEnv& env);
// Only allocates weights, caller is responsible for filling them. Only used
// by `optimize_test.cc`.
// `env` must remain valid for the lifetime of this Gemma.
Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer, MatMulEnv& env);
~Gemma();
MatMulEnv& Env() const { return env_; }

View File

@ -53,10 +53,7 @@ class ModelStore {
// Reads from file(s) or aborts on error. The latter two arguments are only
// used for pre-2025 files.
ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
Tristate wrapping = Tristate::kDefault);
// For optimize_test.cc.
ModelStore(const ModelConfig& config, GemmaTokenizer&& tokenizer)
: config_(config), tokenizer_(std::move(tokenizer)) {}
Tristate wrapping = Tristate::kDefault);
~ModelStore();
const ModelConfig& Config() const {

View File

@ -60,8 +60,7 @@ class GemmaTokenizer {
class GemmaChatTemplate {
public:
// No effect if `tokenizer` is unavailable (as happens in optimize_test.cc),
// but then any other method may abort.
// No effect if `tokenizer` is unavailable, but any other method may abort.
GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model);
// Given prompt tokens, this returns the wrapped prompt including BOS and

View File

@ -21,7 +21,6 @@
#include <stdlib.h>
#include <memory>
#include <random>
#include <string>
#include <vector>
@ -37,7 +36,6 @@
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/stats.h"
// TODO: move into foreach_target; this is only used for NUQ Fixup.
#include "compression/compress-inl.h"
@ -297,9 +295,6 @@ void WeightsOwner::AllocatePointer(const ModelConfig& config) {
case Type::kNUQ:
nuq_weights_.reset(new ModelWeightsPtrs<NuqStream>(config));
break;
case Type::kF32:
float_weights_.reset(new ModelWeightsPtrs<float>(config));
break;
case Type::kBF16:
bf16_weights_.reset(new ModelWeightsPtrs<BF16>(config));
break;
@ -308,55 +303,6 @@ void WeightsOwner::AllocatePointer(const ModelConfig& config) {
}
}
// Gemma calls `WeightsOwner::ReadOrAllocate`, but test code instead calls
// `WeightsPtrs::AllocateForTest`, so the implementation is there, and here
// we only type-dispatch.
void WeightsOwner::AllocateForTest(const ModelConfig& config,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.AllocateWeights");
AllocatePointer(config);
CallT([&](const auto& weights) {
weights->AllocateForTest(mat_owners_, pool);
});
}
void WeightsOwner::ZeroInit() {
PROFILER_FUNC;
CallT([](const auto& weights) { weights->ZeroInit(); });
}
void WeightsOwner::RandInit(float stddev, std::mt19937& gen) {
PROFILER_FUNC;
float_weights_->RandInit(stddev, gen);
}
void WeightsOwner::LogWeightStatsF32() {
size_t total_weights = 0;
HWY_ASSERT(weight_type_ == Type::kF32); // Only for float weights.
float_weights_->ForEachTensor(
nullptr, nullptr, [&total_weights](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
if (t.mat.Scale() != 1.0f) {
printf("[scale=%f] ", t.mat.Scale());
}
hwy::Stats stats;
const MatPtrT<float> mat_f(t.mat);
for (size_t r = 0; r < t.mat.Rows(); ++r) {
const float* HWY_RESTRICT row = mat_f.Row(r);
for (size_t c = 0; c < t.mat.Cols(); ++c) {
stats.Notify(row[c]);
}
}
printf("%-20s %12zu %13.10f %8.5f %13.10f\n", t.mat.Name(),
t.mat.Rows() * t.mat.Cols(), stats.Min(), stats.Mean(),
stats.Max());
total_weights += t.mat.Rows() * t.mat.Cols();
});
printf("%-20s %12zu\n", "Total", total_weights);
}
void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.Fixup");
CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });

View File

@ -22,7 +22,6 @@
#include <complex>
#include <memory>
#include <mutex> // NOLINT
#include <random>
#include <string>
#include <vector>
@ -298,23 +297,6 @@ struct LayerWeightsPtrs {
});
}
void RandInit(float stddev, std::mt19937& gen) {
ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
gcpp::RandInit(t.mat, stddev, gen);
});
}
// Allocates memory for all the tensors in the layer. Note that this is slow
// (non-parallel) and only used for a stand-alone layer.
void AllocateForTest(std::vector<MatOwner>& mat_owners) {
ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
// `backprop/` does not use row accessors and hence requires kPacked.
mat_owners.push_back(MatOwner());
mat_owners.back().AllocateFor(t.mat, MatPadding::kPacked);
});
}
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
@ -330,8 +312,8 @@ struct LayerWeightsPtrs {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Files must have one or the other, and backprop/ allocates both.
HWY_ASSERT(attn_vec_einsum_w.HasPtr() || att_weights.HasPtr());
// Files must have one or the other.
HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr());
// Done if we already read the transposed tensor.
if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;
@ -373,11 +355,10 @@ struct LayerWeightsPtrs {
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
if (layer_config.type == LayerAttentionType::kVit) return;
// Files have both or neither of w1 and w2, and backprop/ allocates both.
// Files have both or neither of w1 and w2.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file, but backprop/
// allocates both, so we can only rule out both being null.
HWY_ASSERT(gating_einsum_w.HasPtr() || gating_einsum_w1.HasPtr());
// w is mutually exclusive with w1 and w2 in the file.
HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that they are not
// necessarily the same type.
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
@ -397,7 +378,7 @@ struct LayerWeightsPtrs {
gating_einsum_w2.SetType(gating_einsum_w.GetType());
gating_einsum_w1.SetScale(gating_einsum_w.Scale());
gating_einsum_w2.SetScale(gating_einsum_w.Scale());
// Do not invalidate gating_einsum_w: backprop/ calls this repeatedly.
gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
}
// For attention, which might not have a w2. Fast, only updates pointers.
@ -405,9 +386,8 @@ struct LayerWeightsPtrs {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// w is mutually exclusive with w1 in the file, but backprop/ allocates
// both, so we can only rule out both being null.
HWY_ASSERT(qkv_einsum_w.HasPtr() || qkv_einsum_w1.HasPtr());
// w is mutually exclusive with w1 in the file.
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that w2 does not exist for
// MHA, and otherwise might not be the same type.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
@ -436,7 +416,7 @@ struct LayerWeightsPtrs {
qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
// Do not invalidate qkv_einsum_w: backprop/ calls this repeatedly.
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}
};
@ -567,13 +547,6 @@ struct ModelWeightsPtrs {
});
}
void RandInit(float stddev, std::mt19937& gen) {
ForEachTensor(nullptr, nullptr, [stddev, &gen](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
gcpp::RandInit(t.mat, stddev, gen);
});
}
// Copies only the allocated tensors in `*this` from tensors in `other`.
void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
@ -584,27 +557,6 @@ struct ModelWeightsPtrs {
});
}
// Instead of reading, only allocates memory for all tensors. Used by
// `optimizer.cc` via the `Gemma` constructor without weights.
void AllocateForTest(std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
// First get a list of all the tensors.
std::vector<MatPtr*> all_mat;
all_mat.reserve(10 * c_layers.size());
ForEachTensor(nullptr, nullptr, [&all_mat](const TensorArgs& t) {
all_mat.push_back(&t.mat);
});
const size_t start = mat_owners.size();
mat_owners.resize(start + all_mat.size());
// Allocate in parallel because faulting in large tensors is slow.
pool.Run(0, all_mat.size(), [&](uint64_t task, size_t /*thread*/) {
// `backprop/` does not use row accessors and hence requires kPacked.
mat_owners[start + task].AllocateFor(*all_mat[task], MatPadding::kPacked);
});
}
// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Must be called after reading and
// updating the attention weights.
@ -640,8 +592,6 @@ class WeightsOwner {
return func(sfp_weights_, std::forward<TArgs>(args)...);
} else if (weight_type_ == Type::kNUQ) {
return func(nuq_weights_, std::forward<TArgs>(args)...);
} else if (weight_type_ == Type::kF32) {
return func(float_weights_, std::forward<TArgs>(args)...);
} else if (weight_type_ == Type::kBF16) {
return func(bf16_weights_, std::forward<TArgs>(args)...);
}
@ -653,23 +603,6 @@ class WeightsOwner {
// Adds one blob for each tensor's data and returns all serialized MatPtr.
std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;
// For backprop/:
// Only allocates; must subsequently call `ZeroInit` or `RandInit`.
void AllocateForTest(const ModelConfig& config, hwy::ThreadPool& pool);
void ZeroInit();
void RandInit(float stddev, std::mt19937& gen); // F32 or F64 only.
void LogWeightStatsF32();
ModelWeightsPtrs<float>* GetF32() const {
HWY_ASSERT(weight_type_ == Type::kF32);
return float_weights_.get();
}
// Usually taken care of by `ReadFromBlobs`, but must also be called by
// `optimize_test`, which updates the attention weights from which this copies.
void Fixup(hwy::ThreadPool& pool);
private:
Type weight_type_;
@ -677,8 +610,10 @@ class WeightsOwner {
// of `CallT` so that can be const.
void AllocatePointer(const ModelConfig& config);
// Called by `ReadFromBlobs`.
void Fixup(hwy::ThreadPool& pool);
// Only one is non-null, determined by `weight_type_`.
std::unique_ptr<ModelWeightsPtrs<float>> float_weights_;
std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;

View File

@ -1018,14 +1018,14 @@ struct TestShortDotsT {
const PackedSpan<T> v = vectors.Span();
MatStorageT<double> bufs("bufs", num);
double* HWY_RESTRICT buf = bufs.Packed();
double* HWY_RESTRICT buf = bufs.Row(0);
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
GenerateWellConditionedInputs(num, raw_w.Packed(), rng, w, work);
GenerateWellConditionedInputs(num, raw_v.Packed(), rng, v, work);
GenerateWellConditionedInputs(num, raw_w.Row(0), rng, w, work);
GenerateWellConditionedInputs(num, raw_v.Row(0), rng, v, work);
const float dot_exact =
ExactDot(raw_w.Packed(), raw_v.Packed(), num, buf);
ExactDot(raw_w.PackedScale1(), raw_v.PackedScale1(), num, buf);
float dots[kVariants];
for (size_t variant = 0; variant < kVariants; ++variant) {
// Here Packed is not always float, so we must not call kDouble.

View File

@ -387,7 +387,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
*/
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
// This overload is called from backprop/ and if kUseHalfRope.
// This overload is called if kUseHalfRope.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos) {

View File

@ -35,7 +35,7 @@ static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
// Replacing with expf(ln(1E4) * freq_exponents) changes results
// noticeably.
inv_timescale.Packed()[dim] =
inv_timescale.Row(0)[dim] =
static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
}
return inv_timescale;

View File

@ -399,7 +399,7 @@ void TestRopeAndMulBy() {
auto random_float = [&r, &gen] { return r(gen); };
for (int i = 0; i < dim_qkv; ++i) {
x.Packed()[i] = random_float();
x.Row(0)[i] = random_float();
}
const float qmul = ChooseQueryScale(config);
@ -417,25 +417,25 @@ void TestRopeAndMulBy() {
// Rope'd Q embeddings
CopyMat(x, qactual);
CopyMat(x, qexpected);
ScalarRopeAndMulBy(qmul, qexpected.Packed(), dim_qkv,
inv_timescale.Packed(), pos);
RopeAndMulBy(qmul, qactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qactual.Packed()[i], qexpected.Packed()[i], 1e-4)
<< "qIndex:" << i << "qInput:" << qactual.Packed()[i];
EXPECT_NEAR(qactual.Row(0)[i], qexpected.Row(0)[i], 1e-4)
<< "qIndex:" << i << "qInput:" << qactual.Row(0)[i];
}
// Rope'd K embeddings
CopyMat(x, kactual);
CopyMat(x, kexpected);
ScalarRopeAndMulBy(kmul, kexpected.Packed(), dim_qkv,
inv_timescale.Packed(), pos);
RopeAndMulBy(kmul, kactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kactual.Packed()[i], kexpected.Packed()[i], 1e-4)
<< "kIndex:" << i << "kInput:" << kactual.Packed()[i];
EXPECT_NEAR(kactual.Row(0)[i], kexpected.Row(0)[i], 1e-4)
<< "kIndex:" << i << "kInput:" << kactual.Row(0)[i];
}
}
}

View File

@ -53,9 +53,7 @@ PYBIND11_MODULE(configs, py_module) {
.value("kF32", Type::kF32)
.value("kBF16", Type::kBF16)
.value("kSFP", Type::kSFP)
.value("kNUQ", Type::kNUQ)
.value("kF64", Type::kF64)
.value("kC64", Type::kC64);
.value("kNUQ", Type::kNUQ);
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
.value("kGemma", LayerAttentionType::kGemma)

View File

@ -18,8 +18,6 @@
#include <stddef.h>
#include <stdint.h>
#include <random>
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/per_target.h" // VectorBytes
@ -62,35 +60,6 @@ void ZeroInit(MatPtr& mat) {
}
}
void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
PROFILER_FUNC;
HWY_ASSERT_M(mat.HasPtr(), mat.Name());
// Only generates float/double for use by backprop/.
HWY_ASSERT(mat.GetType() == Type::kF32 || mat.GetType() == Type::kF64);
mat.SetScale(1.0f);
std::normal_distribution<float> dist(0.0, stddev);
if (mat.GetType() == Type::kF32) {
MatPtrT<float> mat_f(mat);
for (size_t r = 0; r < mat.Rows(); ++r) {
float* HWY_RESTRICT row = mat_f.Row(r);
for (size_t c = 0; c < mat.Cols(); ++c) {
row[c] = dist(gen);
}
}
} else {
MatPtrT<double> mat_d(mat);
for (size_t r = 0; r < mat.Rows(); ++r) {
double* HWY_RESTRICT row = mat_d.Row(r);
for (size_t c = 0; c < mat.Cols(); ++c) {
row[c] = dist(gen);
}
}
}
}
size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
size_t line_bytes) {
switch (padding) {

View File

@ -20,7 +20,6 @@
#include <stddef.h>
#include <stdint.h>
#include <random>
#include <string>
// IWYU pragma: begin_exports
@ -229,21 +228,11 @@ class MatPtrT : public MatPtr {
MatPtrT(const MatPtrT& other) = default;
MatPtrT& operator=(const MatPtrT& other) = default;
// Returns the entire tensor for use by `backprop/*`. Verifies layout is
// `kPacked`. Preferably call `Row` instead, which works for either layout.
MatT* Packed() {
HWY_DASSERT_M(IsPacked(), name_.c_str());
return HWY_RCAST_ALIGNED(MatT*, ptr_);
}
const MatT* Packed() const {
HWY_DASSERT_M(IsPacked(), name_.c_str());
return HWY_RCAST_ALIGNED(const MatT*, ptr_);
}
// As `Packed()`, plus checks the scale is 1.0 because callers will ignore it.
// This is typically used for `MatMul` bias vectors and norm weights.
// Returns the entire tensor after checking the scale is 1.0 because callers
// will ignore it. Used for `MatMul` bias vectors and norm weights.
const MatT* PackedScale1() const {
HWY_DASSERT(Scale() == 1.0f);
return Packed();
return HWY_RCAST_ALIGNED(const MatT*, ptr_);
}
MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
@ -268,8 +257,7 @@ class MatPtrT : public MatPtr {
};
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`. This supports all types used as weights, which excludes
// `kC64` and `kF64` (used only in `backprop/`).
// optional `args`. This supports all types used as weights.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
Args&&... args) {
@ -334,8 +322,6 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);
// F32/F64 only.
void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);
// Our tensors are always row-major. This enum indicates how much (if any)
// padding comes after each row.
@ -346,9 +332,6 @@ enum class MatPadding {
// `BlobStore`) are not padded, which also extends to memory-mapped tensors.
// However, `BlobStore` is able to insert padding via row-wise I/O when
// reading from disk via `Mode::kRead`.
//
// `backprop/*` also requires this layout because it indexes directly into
// the storage instead of calling `Row()`.
kPacked,
// Enough to round up to an odd number of cache lines, which can reduce
// cache conflict misses or 4K aliasing.
@ -378,9 +361,9 @@ class MatOwner {
AlignedPtr<uint8_t[]> storage_;
};
// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
// tests to allocate and access tensors of a known type. By contrast, the
// heterogeneous model weights are owned by vectors of `MatOwner`.
// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by tests to allocate
// and access tensors of a known type. By contrast, the heterogeneous model
// weights are owned by vectors of `MatOwner`.
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
public:
@ -393,7 +376,7 @@ class MatStorageT : public MatPtrT<MatT> {
: MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {}
~MatStorageT() = default;
// Allow move for backprop/activations.
// Allow move for KVCache.
MatStorageT(MatStorageT&&) = default;
MatStorageT& operator=(MatStorageT&&) = default;
@ -401,13 +384,6 @@ class MatStorageT : public MatPtrT<MatT> {
MatOwner owner_;
};
// Helper factory function for use by `backprop/` to avoid specifying the
// `MatPadding` argument everywhere.
template <typename T>
MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
return MatStorageT<T>(name, Extents2D(rows, cols), MatPadding::kPacked);
}
// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
// seekable (non-NUQ) T.
#pragma pack(push, 1) // power of two size