diff --git a/BUILD.bazel b/BUILD.bazel index 5502169..fc5c0c5 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -685,131 +685,3 @@ cc_binary( "@nlohmann_json//:json", ], ) - -cc_library( - name = "prompt", - hdrs = ["backprop/prompt.h"], - deps = [], -) - -cc_library( - name = "sampler", - hdrs = ["backprop/sampler.h"], - deps = [ - ":prompt", - ], -) - -cc_library( - name = "backprop", - srcs = [ - "backprop/backward.cc", - "backprop/forward.cc", - ], - hdrs = [ - "backprop/activations.h", - "backprop/backward.h", - "backprop/forward.h", - ], - textual_hdrs = [ - "backprop/backward-inl.h", - "backprop/forward-inl.h", - ], - deps = [ - ":allocator", - ":common", - ":configs", - ":mat", - ":ops", - ":prompt", - ":weights", - "@highway//:dot", - "@highway//:hwy", # base.h - "@highway//:thread_pool", - ], -) - -cc_library( - name = "backprop_scalar", - hdrs = [ - "backprop/activations.h", - "backprop/backward_scalar.h", - "backprop/common_scalar.h", - "backprop/forward_scalar.h", - ], - deps = [ - ":common", - ":configs", - ":mat", - ":prompt", - ":weights", - "@highway//:hwy", - ], -) - -cc_test( - name = "backward_test", - size = "large", - srcs = [ - "backprop/backward_test.cc", - "backprop/test_util.h", - ], - exec_properties = { - # Avoid linker OOMs when building with sanitizer instrumentation. - "mem": "28g", - }, - deps = [ - ":backprop", - ":backprop_scalar", - ":configs", - ":mat", - ":ops", - ":prompt", - ":sampler", - ":threading_context", - ":weights", - "@googletest//:gtest_main", # buildcleaner: keep - "@highway//:hwy", - "@highway//:hwy_test_util", - "@highway//:thread_pool", - ], -) - -cc_library( - name = "optimizer", - srcs = ["backprop/optimizer.cc"], - hdrs = ["backprop/optimizer.h"], - deps = [ - ":allocator", - ":mat", - ":weights", - "@highway//:hwy", - "@highway//:thread_pool", - ], -) - -cc_test( - name = "optimize_test", - srcs = ["backprop/optimize_test.cc"], - exec_properties = { - # Avoid linker OOMs when building with sanitizer instrumentation. - "mem": "28g", - }, - deps = [ - ":allocator", - ":backprop", - ":basics", - ":configs", - ":gemma_lib", - ":ops", - ":optimizer", - ":prompt", - ":sampler", - ":threading", - ":tokenizer", - ":weights", - "@googletest//:gtest_main", # buildcleaner: keep - "//compression:types", - "@highway//:thread_pool", - ], -) diff --git a/CMakeLists.txt b/CMakeLists.txt index 97f4ccb..6f6a8d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,18 +41,6 @@ FetchContent_MakeAvailable(benchmark) # Base source files set(SOURCES - backprop/activations.h - backprop/backward_scalar.h - backprop/backward-inl.h - backprop/backward.cc - backprop/backward.h - backprop/common_scalar.h - backprop/forward_scalar.h - backprop/forward-inl.h - backprop/forward.cc - backprop/forward.h - backprop/optimizer.cc - backprop/optimizer.h compression/compress-inl.h compression/compress.cc compression/compress.h @@ -202,8 +190,6 @@ enable_testing() include(GoogleTest) set(GEMMA_TEST_FILES - backprop/backward_test.cc - backprop/optimize_test.cc compression/compress_test.cc compression/distortion_test.cc compression/nuq_test.cc diff --git a/DEVELOPERS.md b/DEVELOPERS.md index 4248cde..a19bf9c 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -173,9 +173,6 @@ more custom you can call transformer which performs a single inference operation on a single token and mutates the Activations and the KVCache through the neural network computation. -Note that an experimental backward pass is available in backprop/, which may be -useful for fine tuning. - ### For low level operations, defining new architectures, call `ops.h` functions directly You use `ops.h` if you're writing other NN architectures or modifying the diff --git a/backprop/activations.h b/backprop/activations.h deleted file mode 100644 index 7e42032..0000000 --- a/backprop/activations.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ -#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ - -#include - -#include - -#include "gemma/configs.h" // ModelConfig -#include "util/mat.h" // MatStorageT - -namespace gcpp { - -template -struct ForwardLayer { - ForwardLayer(const LayerConfig& config, size_t seq_len) - : input(MakePacked("input", seq_len, config.model_dim)), - pre_att_rms_out( - MakePacked("pre_att_rms_out", seq_len, config.model_dim)), - qkv(MakePacked("qkv", seq_len * (config.heads + 2), config.qkv_dim)), - att(MakePacked("att", seq_len * config.heads, seq_len)), - att_out( - MakePacked("att_out", seq_len * config.heads, config.qkv_dim)), - att_post1(MakePacked("att_post1", seq_len, config.model_dim)), - attention_out( - MakePacked("attention_out", seq_len, config.model_dim)), - pre_ffw_rms_out( - MakePacked("preFF_rms_out", seq_len, config.model_dim)), - ffw_hidden( - MakePacked("ffw_hidden", seq_len, config.ff_hidden_dim * 2)), - ffw_hidden_gated( - MakePacked("ffw_hidden_gated", seq_len, config.ff_hidden_dim)), - layer_config(config) {} - - MatStorageT input; - MatStorageT pre_att_rms_out; - MatStorageT qkv; - MatStorageT att; - MatStorageT att_out; - MatStorageT att_post1; - MatStorageT attention_out; - MatStorageT pre_ffw_rms_out; - MatStorageT ffw_hidden; - MatStorageT ffw_hidden_gated; - const LayerConfig& layer_config; -}; - -template -struct ForwardPass { - ForwardPass(const ModelConfig& config) - : final_layer_output( - MakePacked("fin_layer_out", config.seq_len, config.model_dim)), - final_norm_output( - MakePacked("fin_norm_out", config.seq_len, config.model_dim)), - logits(MakePacked("logits", config.seq_len, config.vocab_size)), - probs(MakePacked("probs", config.seq_len, config.vocab_size)), - weights_config(config) { - for (const auto& layer_config : config.layer_configs) { - layers.emplace_back(layer_config, config.seq_len); - } - } - - std::vector> layers; - MatStorageT final_layer_output; - MatStorageT final_norm_output; - MatStorageT logits; - MatStorageT probs; - const ModelConfig& weights_config; -}; - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h deleted file mode 100644 index 783ed0b..0000000 --- a/backprop/backward-inl.h +++ /dev/null @@ -1,407 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Implementation of the Vector-Jacobian Products (VJP) of the individual -// operations of the forward pass. - -// Include guard for non-SIMD code. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ - -#include - -#include -#include - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/common.h" // EmbeddingScaling -#include "gemma/configs.h" // LayerConfig, ModelConfig -#include "gemma/weights.h" -#include "util/allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ - -// Include guard for (potentially) SIMD code. -#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE -#endif - -#include "hwy/highway.h" -// After highway.h -#include "ops/matmul-inl.h" -#include "ops/ops-inl.h" -#include "hwy/contrib/dot/dot-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; - -HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols, - const float* HWY_RESTRICT x, // num_tokens * kCols - const float* HWY_RESTRICT v, // num_tokens * kRows - size_t cols, size_t rows, size_t num_tokens, - float* HWY_RESTRICT grad_w, // kRows * kCols, - float* HWY_RESTRICT grad_x, // num_tokens * kCols - hwy::ThreadPool& pool) { - hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0])); - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t voffs = pos * rows; - const size_t xoffs = pos * cols; - for (size_t j = 0; j < rows; ++j) { - MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols); - MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols); - } - } -} - -HWY_INLINE void MultiHeadMatMulVJP( - const float* HWY_RESTRICT weights, // heads * kRows * kCols - const float* HWY_RESTRICT x, // num_tokens * heads * kCols - const float* HWY_RESTRICT v, // num_tokens * kRows - size_t heads, size_t cols, size_t rows, size_t num_tokens, - float* HWY_RESTRICT grad_w, // heads * kRows * kCols - float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols - hwy::ThreadPool& pool) { - hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0])); - for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t j = 0; j < rows; ++j) { - for (size_t h = 0; h < heads; ++h) { - MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols], - &grad_w[h * rows * cols + j * cols], cols); - MulByConstAndAdd(v[pos * rows + j], - &weights[h * rows * cols + j * cols], - &grad_x[pos * heads * cols + h * cols], cols); - } - } - } -} - -template -static HWY_INLINE hn::Vec DGelu(D d, hn::Vec v) { - const hn::Vec kMul = hn::Set(d, 0.044715f); - const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); - const hn::Vec kHalf = hn::Set(d, 0.5f); - const hn::Vec kOne = hn::Set(d, 1.0f); - // kSqrtOverPi*3*kMul - const hn::Vec kMulv2 = hn::Set(d, 0.1070322244f); - - const hn::Vec v2 = hn::Mul(v, v); - const hn::Vec v3 = hn::Mul(v2, v); - const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); - const hn::Vec tanh = hn::Tanh(d, arg); - const hn::Vec cdf = hn::MulAdd(kHalf, tanh, kHalf); - const hn::Vec dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh)); - const hn::Vec darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi); - return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf); -} - -static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward, - float* HWY_RESTRICT backward, - const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - const D d; - - const auto offset = - hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size)); - hn::Transform1( - d, backward, size, forward, - [&offset](const auto d, const auto v, const auto y) - HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); }); -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP( - const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x, - const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens, - float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x, - hwy::ThreadPool& pool) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = pos * model_dim; - const float ss = detail::RMSNormMul(x + offset, model_dim); - - for (size_t i = 0; i < model_dim; ++i) { - grad_w[i] += v[offset + i] * x[offset + i] * ss; - } - const float ss3 = ss * ss * ss / StaticCast(model_dim); - float tmp = 0.0f; - for (size_t i = 0; i < model_dim; ++i) { - tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i]; - } - tmp *= ss3; - for (size_t i = 0; i < model_dim; ++i) { - grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] - - tmp * x[offset + i]; - } - } -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP( - const float* weights, const std::vector& prompt, const float scaling, - const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) { - HWY_ASSERT(!prompt.empty()); - for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { - int token = prompt[pos]; - MulByConstAndAdd(scaling, v + pos * model_dim, - grad + token * model_dim, model_dim); - } -} - -template -void LayerVJP(const LayerWeightsPtrs& weights, - const ForwardLayer& forward, - const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, - LayerWeightsPtrs& grad, ForwardLayer& backward, - const MatStorageT& inv_timescale, hwy::ThreadPool& pool) { - const LayerConfig& config = weights.layer_config; - const size_t model_dim = config.model_dim; - const size_t qkv_dim = config.qkv_dim; - const size_t heads = config.heads; - const size_t seq_len = forward.input.Rows(); - const size_t ff_hidden_dim = config.ff_hidden_dim; - const float query_scale = - static_cast(1.0 / sqrt(static_cast(qkv_dim))); - HWY_ASSERT(num_tokens <= seq_len); - - MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), - next_layer_grad, ff_hidden_dim, model_dim, num_tokens, - grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t hidden_offset = pos * ff_hidden_dim * 2; - const float* HWY_RESTRICT f_out = - forward.ffw_hidden.Packed() + hidden_offset; - const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim; - const float* HWY_RESTRICT b_out_gated = - backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim; - float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset; - float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim; - namespace hn = hwy::HWY_NAMESPACE; - using DF = hn::ScalableTag; - DF df; - for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) { - const auto y = Load(df, f_out + i); - const auto x = Load(df, f_out_mul + i); - const auto v = Load(df, b_out_gated + i); - hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i); - hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i); - } - } - - MatMulVJP(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(), - backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2, - num_tokens, grad.gating_einsum_w.Packed(), - backward.pre_ffw_rms_out.Packed(), pool); - RMSNormVJP(weights.pre_ffw_norm_scale.Packed(), - forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(), - model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(), - backward.attention_out.Packed(), pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(next_layer_grad + pos * model_dim, - backward.attention_out.Packed() + pos * model_dim, model_dim); - } - - ZeroInit(backward.qkv); - - MultiHeadMatMulVJP( - weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(), - backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens, - grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool); - - for (size_t head = 0; head < heads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t aoffset = head * seq_len + pos * heads * seq_len; - const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset; - const float* HWY_RESTRICT b_att_out = - backward.att_out.Packed() + (pos * heads + head) * qkv_dim; - float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim; - const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs; - float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs; - b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim); - MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim); - } - } - } - - for (size_t head = 0; head < heads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t aoffset = head * seq_len + pos * heads * seq_len; - const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset; - float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset; - SoftmaxVJP(f_head_att, b_head_att, pos + 1); - } - } - - for (size_t head = 0; head < heads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim; - const size_t aoffs = head * seq_len + pos * heads * seq_len; - const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs; - const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs; - float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim; - const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs; - float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs; - MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim); - MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim); - } - } - } - - for (int pos = 0; pos < static_cast(num_tokens); ++pos) { - float* HWY_RESTRICT b_kv = - backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim; - Rope(b_kv, qkv_dim, inv_timescale.PackedScale1(), -pos); - } - - for (size_t head = 0; head < heads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - float* HWY_RESTRICT b_q = - backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim; - MulByConst(query_scale, b_q, qkv_dim); - Rope(b_q, qkv_dim, inv_timescale.PackedScale1(), -pos); - } - } - - MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(), - backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens, - grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(), - pool); - RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(), - backward.pre_att_rms_out.Packed(), model_dim, num_tokens, - grad.pre_attention_norm_scale.Packed(), backward.input.Packed(), - pool); - for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(backward.attention_out.Packed() + pos * model_dim, - backward.input.Packed() + pos * model_dim, model_dim); - } -} - -static HWY_NOINLINE void SoftcapVJP(const float cap, - const float* HWY_RESTRICT forward, - float* HWY_RESTRICT backward, - const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - const D d; - - const auto one = hn::Set(d, 1.0f); - const auto vcap = hn::Set(d, cap); - const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); - hn::Transform1(d, backward, size, forward, - [&](const auto d, const auto v, const auto y) HWY_ATTR { - const auto scaled = hn::Mul(vinv_cap, y); // = tanh - return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled))); - }); -} - -static HWY_NOINLINE void CrossEntropyLossGrad( - const float* HWY_RESTRICT x, float* HWY_RESTRICT grad, - const Prompt& prompt, size_t vocab_size) { - HWY_ASSERT(!prompt.tokens.empty()); - const float scaling = -1.0 / std::log(2.0); - size_t num_tokens = prompt.tokens.size() - 1; - hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0])); - for (size_t pos = 0; pos < num_tokens; ++pos) { - if (pos + 1 < prompt.context_size) { - continue; - } - const int next_token = prompt.tokens[pos + 1]; - grad[pos * vocab_size + next_token] = - scaling / x[pos * vocab_size + next_token]; - } -} - -template -void CrossEntropyLossBackwardPassInl(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward, - MatStorageT& inv_timescale, - hwy::ThreadPool& pool) { - const ModelConfig& config = weights.weights_config; - const size_t kVocabSize = config.vocab_size; - const size_t model_dim = config.model_dim; - const size_t kLayers = config.layer_configs.size(); - const float kEmbScaling = EmbeddingScaling(model_dim); - HWY_ASSERT(!config.absolute_pe); - HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None); - HWY_ASSERT(config.layer_configs[0].kv_heads == 1); - - HWY_DASSERT(prompt.context_size > 0); - HWY_DASSERT(prompt.context_size < prompt.tokens.size()); - const size_t num_tokens = prompt.tokens.size() - 1; - - CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt, - kVocabSize); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize, - backward.logits.Packed() + pos * kVocabSize, kVocabSize); - } - - if (config.final_cap > 0.0f) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize, - backward.logits.Packed() + pos * kVocabSize, kVocabSize); - } - } - - MatMulVJP(weights.embedder_input_embedding.Packed(), - forward.final_norm_output.Packed(), backward.logits.Packed(), - model_dim, kVocabSize, num_tokens, - grad.embedder_input_embedding.Packed(), - backward.final_norm_output.Packed(), pool); - - RMSNormVJP(weights.final_norm_scale.Packed(), - forward.final_layer_output.Packed(), - backward.final_norm_output.Packed(), model_dim, num_tokens, - grad.final_norm_scale.Packed(), - backward.final_layer_output.Packed(), pool); - - for (int layer = static_cast(kLayers) - 1; layer >= 0; --layer) { - auto layer_config = config.layer_configs[layer]; - // TODO(szabadka) Implement Griffin layer vjp. - HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma); - float* next_layer_grad = layer + 1 < kLayers - ? backward.layers[layer + 1].input.Packed() - : backward.final_layer_output.Packed(); - LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, - num_tokens, *grad.GetLayer(layer), backward.layers[layer], - inv_timescale, pool); - } - - InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens, - kEmbScaling, backward.layers[0].input.Packed(), - grad.embedder_input_embedding.Packed(), model_dim); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // NOLINT diff --git a/backprop/backward.cc b/backprop/backward.cc deleted file mode 100644 index 36edf74..0000000 --- a/backprop/backward.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "backprop/backward.h" - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/weights.h" -#include "util/mat.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT -// clang-format on -#include "hwy/foreach_target.h" // IWYU pragma: keep - -#include "hwy/highway.h" -// After highway.h -#include "backprop/backward-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -void CrossEntropyLossBackwardPassT(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward, - MatStorageT& inv_timescale, - hwy::ThreadPool& pool) { - CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward, - inv_timescale, pool); -} - -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE -namespace gcpp { - -HWY_EXPORT(CrossEntropyLossBackwardPassT); - -void CrossEntropyLossBackwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward, - MatStorageT& inv_timescale, - hwy::ThreadPool& pool) { - return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( - prompt, weights, forward, grad, backward, inv_timescale, pool); -} - -} // namespace gcpp -#endif // HWY_ONCE diff --git a/backprop/backward.h b/backprop/backward.h deleted file mode 100644 index f8de706..0000000 --- a/backprop/backward.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/weights.h" -#include "util/mat.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -void CrossEntropyLossBackwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward, - MatStorageT& inv_timescale, - hwy::ThreadPool& pool); - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h deleted file mode 100644 index b0c7f13..0000000 --- a/backprop/backward_scalar.h +++ /dev/null @@ -1,352 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ - -#include -#include - -#include -#include - -#include "backprop/activations.h" -#include "backprop/common_scalar.h" -#include "backprop/prompt.h" -#include "gemma/common.h" // EmbeddingScaling -#include "gemma/weights.h" - -namespace gcpp { -template -void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, - size_t N, size_t M, size_t K) { - memset(dx, 0, M * K * sizeof(dx[0])); - for (size_t i = 0; i < K; ++i) { - for (size_t j = 0; j < N; ++j) { - MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M); - MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M); - } - } -} -template -void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, - size_t H, size_t N, size_t M, size_t K) { - memset(dx, 0, H * M * K * sizeof(dx[0])); - for (size_t i = 0; i < K; ++i) { - for (size_t j = 0; j < N; ++j) { - for (size_t h = 0; h < H; ++h) { - MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M], - &dw[h * N * M + j * M], M); - MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M], - &dx[i * H * M + h * M], M); - } - } - } -} - -template -void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, - size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - constexpr T eps(1e-6); - T ss = SquaredL2(x + i * N, N); - ss = T(1.0) / std::sqrt(ss / T(N) + eps); - for (size_t j = 0; j < N; ++j) { - dw[j] += dy[i * N + j] * x[i * N + j] * ss; - } - const T ss3 = ss * ss * ss / T(N); - T tmp = 0.0; - for (size_t j = 0; j < N; ++j) { - tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j]; - } - tmp *= ss3; - for (size_t j = 0; j < N; ++j) { - dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j]; - } - } -} -template -void SoftmaxVJPT(const T* y, T* dy, size_t N) { - T sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += y[i] * dy[i]; - } - for (size_t i = 0; i < N; ++i) { - dy[i] = y[i] * (dy[i] - sum); - } -} -template -void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - SoftmaxVJPT(y + i * N, dy + i * N, N); - } -} - -template -T GeluDerivative(T x) { - static const T kMul = 0.044715; - static const T kSqrt2OverPi = 0.797884560804236; - static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul; - - const T x2 = x * x; - const T x3 = x2 * x; - const T arg = kSqrt2OverPi * (kMul * x3 + x); - const T tanh = std::tanh(arg); - const T cdf = T(0.5) * (T(1.0) + tanh); - const T dtanh = T(1.0) - tanh * tanh; - const T darg = kMul2 * x2 + kSqrt2OverPi; - return T(0.5) * x * dtanh * darg + cdf; -} - -template -void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - const T* x1 = in + i * 2 * N; - const T* x2 = x1 + N; - const T* v = d_out + i * N; - T* dx1 = d_in + i * 2 * N; - T* dx2 = dx1 + N; - for (size_t j = 0; j < N; ++j) { - dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]); - dx2[j] = v[j] * Gelu(x1[j]); - } - } -} - -template -void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv, - size_t num_tokens, size_t kHeads, size_t qkv_dim, - size_t seq_len) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = pos * (kHeads + 2) * qkv_dim; - memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0])); - } - for (size_t head = 0; head < kHeads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim; - const size_t aoffs = head * seq_len + pos * kHeads * seq_len; - const T* q = qkv + qoffs; - const T* dout = doutput + aoffs; - T* dq = dqkv + qoffs; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim; - const T* k = qkv + koffs; - T* dk = dqkv + koffs; - MulByConstAndAddT(dout[pos2], k, dq, qkv_dim); - MulByConstAndAddT(dout[pos2], q, dk, qkv_dim); - } - } - } -} - -template -void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads, - size_t seq_len) { - for (size_t head = 0; head < kHeads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - size_t offset = pos * kHeads * seq_len + head * seq_len; - SoftmaxVJPT(y + offset, dy + offset, pos + 1); - memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T)); - } - } -} - -template -void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput, - T* dqkv, T* dattention, size_t num_tokens, size_t kHeads, - size_t qkv_dim, size_t seq_len) { - auto v_offset = [&](size_t pos) { - return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim; - }; - for (size_t pos = 0; pos < num_tokens; ++pos) { - memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0])); - } - for (size_t head = 0; head < kHeads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim; - const size_t aoffset = head * seq_len + pos * kHeads * seq_len; - const T* att = &attention[aoffset]; - const T* dout = &doutput[offset]; - T* datt = &dattention[aoffset]; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim); - MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim); - } - } - } -} - -template -void InputEmbeddingVJPT(const T* w, const std::vector& tokens, T scaling, - const T* dy, T* dw, size_t N) { - const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - for (size_t i = 0; i < num_tokens; ++i) { - int token = tokens[i]; - MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N); - } -} - -template -void LayerVJP(const LayerWeightsPtrs& weights, - const ForwardLayer& forward, const T* dy, - LayerWeightsPtrs& grad, ForwardLayer& backward, - size_t num_tokens) { - const LayerConfig& layer_config = weights.layer_config; - const size_t model_dim = layer_config.model_dim; - const size_t seq_len = forward.input.Rows(); - const size_t qkv_dim = layer_config.qkv_dim; - const size_t kHeads = layer_config.heads; - const size_t kFFHiddenDim = layer_config.ff_hidden_dim; - const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim)); - - MatMulVJPT(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), dy, - grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), - model_dim, kFFHiddenDim, num_tokens); - - GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(), - backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens); - - MatMulVJPT(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(), - backward.ffw_hidden.Packed(), grad.gating_einsum_w.Packed(), - backward.pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim, - num_tokens); - - RMSNormVJPT(weights.pre_ffw_norm_scale.Packed(), - forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(), - grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), - model_dim, num_tokens); - - AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim); - - MultiHeadMatMulVJPT( - weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(), - backward.attention_out.Packed(), grad.attn_vec_einsum_w.Packed(), - backward.att_out.Packed(), kHeads, model_dim, qkv_dim, num_tokens); - - MixByAttentionVJP(forward.qkv.Packed(), forward.att.Packed(), - backward.att_out.Packed(), backward.qkv.Packed(), - backward.att.Packed(), num_tokens, kHeads, qkv_dim, - seq_len); - - MaskedSoftmaxVJPT(forward.att.Packed(), backward.att.Packed(), num_tokens, - kHeads, seq_len); - - MaskedAttentionVJP(forward.qkv.Packed(), backward.att.Packed(), - backward.qkv.Packed(), num_tokens, kHeads, qkv_dim, - seq_len); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim; - MulByConstT(kQueryScale, qkv, kHeads * qkv_dim); - } - - for (int pos = 0; pos < num_tokens; ++pos) { - T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim; - for (size_t h = 0; h <= kHeads; ++h) { - Rope(qkv + h * qkv_dim, qkv_dim, -pos); - } - } - - MatMulVJPT(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(), - backward.qkv.Packed(), grad.qkv_einsum_w.Packed(), - backward.pre_att_rms_out.Packed(), (kHeads + 2) * qkv_dim, - model_dim, num_tokens); - RMSNormVJPT(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(), - backward.pre_att_rms_out.Packed(), - grad.pre_attention_norm_scale.Packed(), backward.input.Packed(), - model_dim, num_tokens); - - AddFromT(backward.attention_out.Packed(), backward.input.Packed(), - num_tokens * model_dim); -} - -template -void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) { - const T inv_cap = T{1.0} / static_cast(cap); - for (size_t i = 0; i < N; ++i) { - T scaled = y[i] * inv_cap; // tanh - dy[i] *= (T{1.0} - scaled * scaled); - } -} - -template -void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) { - T scaling = -1.0 / std::log(2.0); - const std::vector tokens = prompt.tokens; - const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - memset(dx, 0, V * num_tokens * sizeof(x[0])); - for (size_t i = 0; i < num_tokens; ++i) { - if (i + 1 < prompt.context_size) { - continue; - } - const int next_token = tokens[i + 1]; - dx[i * V + next_token] = scaling / x[i * V + next_token]; - } -} - -template -void CrossEntropyLossBackwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward) { - const ModelConfig& config = weights.weights_config; - const size_t model_dim = config.model_dim; - const size_t vocab_size = config.vocab_size; - const size_t layers = config.layer_configs.size(); - const std::vector tokens = prompt.tokens; - const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - - CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt, - vocab_size); - - SoftmaxVJPT(forward.probs.Packed(), backward.logits.Packed(), vocab_size, - num_tokens); - - if (config.final_cap > 0.0f) { - for (size_t i = 0; i < num_tokens; ++i) { - SoftcapVJPT(config.final_cap, forward.logits.Packed() + i * vocab_size, - backward.logits.Packed() + i * vocab_size, vocab_size); - } - } - - MatMulVJPT(weights.embedder_input_embedding.Packed(), - forward.final_norm_output.Packed(), backward.logits.Packed(), - grad.embedder_input_embedding.Packed(), - backward.final_norm_output.Packed(), vocab_size, model_dim, - num_tokens); - - RMSNormVJPT( - weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(), - backward.final_norm_output.Packed(), grad.final_norm_scale.Packed(), - backward.final_layer_output.Packed(), model_dim, num_tokens); - - for (int layer = static_cast(layers) - 1; layer >= 0; --layer) { - T* next_layer_grad = layer + 1 < layers - ? backward.layers[layer + 1].input.Packed() - : backward.final_layer_output.Packed(); - LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, - *grad.GetLayer(layer), backward.layers[layer], num_tokens); - } - - const T kEmbScaling = EmbeddingScaling(model_dim); - InputEmbeddingVJPT(weights.embedder_input_embedding.Packed(), tokens, - kEmbScaling, backward.layers[0].input.Packed(), - grad.embedder_input_embedding.Packed(), model_dim); -} - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc deleted file mode 100644 index 5b220ca..0000000 --- a/backprop/backward_test.cc +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright 2023 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif - -#include - -#include -#include // std::abs -#include -#include - -#include "backprop/activations.h" -#include "backprop/common_scalar.h" // DotT -#include "backprop/forward_scalar.h" // MatMulT -#include "backprop/prompt.h" -#include "backprop/sampler.h" -#include "backprop/test_util.h" -#include "gemma/configs.h" -#include "ops/ops.h" -#include "util/mat.h" -#include "util/threading_context.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT -// clang-format on -#include "hwy/foreach_target.h" // IWYU pragma: keep -#include "hwy/highway.h" -#include "hwy/tests/test_util-inl.h" -// After highway.h -#include "backprop/backward-inl.h" -#include "backprop/forward-inl.h" -#include "ops/ops-inl.h" - -// 'include guard' so we only define this once. Note that HWY_ONCE is only -// defined during the last pass, but this is used in each pass. -#ifndef BACKWARD_TEST_ONCE -#define BACKWARD_TEST_ONCE -// TestEndToEnd is slow, so only run it for the best-available target. -static int run_once; -#endif - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -hwy::ThreadPool& ThreadHostileGetPool() { - // Assume this is only called at the top level, i.e. not in a thread. Then we - // can safely call `SetArgs` only once, because it would assert otherwise. - // This is preferable to calling `ThreadHostileInvalidate`, because we would - // repeat the topology initialization for every test. - if (!ThreadingContext::IsInitialized()) { - gcpp::ThreadingArgs threading_args; - threading_args.max_packages = 1; - threading_args.max_clusters = 8; - threading_args.pin = Tristate::kFalse; - ThreadingContext::SetArgs(threading_args); - } - return ThreadingContext::Get().pools.Pool(); -} - -void TestMatMulVJP() { - static const size_t kRows = 8; - static const size_t kCols = 64; - static const size_t kTokens = 5; - - hwy::ThreadPool& pool = ThreadHostileGetPool(); - std::mt19937 gen(42); - auto weights = MakePacked("weights", kRows, kCols); - auto x = MakePacked("x", kTokens, kCols); - auto dy = MakePacked("dy", kTokens, kRows); - auto grad = MakePacked("grad", kRows, kCols); - auto dx = MakePacked("dx", kTokens, kCols); - using TC = std::complex; - auto c_weights = MakePacked("c_weights", kRows, kCols); - auto c_x = MakePacked("c_x", kTokens, kCols); - auto c_y = MakePacked("c_y", kTokens, kRows); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0f * (1 << iter), gen); - RandInit(x, 1.0f * (1 << iter), gen); - RandInit(dy, 1.0f, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols, - kTokens); - return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows); - }; - - ZeroInit(grad); - MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens, - grad.Packed(), dx.Packed(), pool); - TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__); - TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__); - } -} - -void TestMultiHeadMatMulVJP() { - static const size_t kRows = 2; - static const size_t kCols = 16; - static const size_t kHeads = 4; - static const size_t kTokens = 3; - hwy::ThreadPool& pool = ThreadHostileGetPool(); - std::mt19937 gen(42); - auto weights = MakePacked("weights", kRows, kCols * kHeads); - auto x = MakePacked("x", kTokens, kCols * kHeads); - auto grad = MakePacked("grad", kRows, kCols * kHeads); - auto dx = MakePacked("dx", kTokens, kCols * kHeads); - auto dy = MakePacked("dy", kTokens, kRows); - using TC = std::complex; - auto c_weights = MakePacked("c_weights", kRows, kCols * kHeads); - auto c_x = MakePacked("c_x", kTokens, kCols * kHeads); - auto c_y = MakePacked("c_y", kTokens, kRows); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0f * (1 << iter), gen); - RandInit(x, 1.0f * (1 << iter), gen); - RandInit(dy, 1.0f, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads, - kRows, kCols, kTokens); - return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows); - }; - - ZeroInit(grad); - MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols, - kRows, kTokens, grad.Packed(), dx.Packed(), pool); - TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__); - TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__); - } -} - -void TestRMSNormVJP() { - static const size_t K = 2; - static const size_t N = 64; - hwy::ThreadPool& pool = ThreadHostileGetPool(); - std::mt19937 gen(42); - auto weights = MakePacked("weights", N, 1); - auto x = MakePacked("x", K, N); - auto grad = MakePacked("grad", N, 1); - auto dx = MakePacked("dx", K, N); - auto dy = MakePacked("dy", K, N); - using TC = std::complex; - auto c_weights = MakePacked("c_weights", N, 1); - auto c_x = MakePacked("c_x", K, N); - auto c_y = MakePacked("c_y", K, N); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0f * (1 << iter), gen); - RandInit(x, 1.0f * (1 << iter), gen); - RandInit(dy, 1.0f, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K); - return DotT(dy.Packed(), c_y.Packed(), K * N); - }; - - ZeroInit(grad); - RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(), - dx.Packed(), pool); - TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__); - TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__); - } -} - -void TestEndToEnd() { - if (++run_once > 1) return; // ~3 min on SKX, only run best available target - - std::mt19937 gen(42); - hwy::ThreadPool& pool = ThreadHostileGetPool(); - ModelConfig config(Model::GEMMA_TINY, Type::kF32, PromptWrapping::GEMMA_IT); - WeightsWrapper weights(config); - WeightsWrapper grad(config); - ForwardPass forward0(config); - ForwardPass forward1(config); - ForwardPass backward(config); - using TC = std::complex; - WeightsWrapper c_weights(config); - ForwardPass c_forward(config); - - ReverseSequenceSampler training_task({0, 0, 1, 1}); - std::vector batch = training_task.SampleBatch(3, gen); - - MatStorageT inv_timescale = CreateInvTimescale( - ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim, - config.layer_configs[0].post_qk == PostQKType::HalfRope); - for (const Prompt& prompt : batch) { - ReverseSequenceSampler::LogPrompt(prompt); - weights.get().RandInit(1.0f, gen); - - float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0); - - float loss1 = CrossEntropyLossForwardPass( - prompt.tokens, prompt.context_size, weights.get(), forward1, - inv_timescale, pool); - - EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); - - grad.get().ZeroInit(); - CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(), - backward, inv_timescale, pool); - - Complexify(weights.get(), c_weights.get()); - auto func = [&]() { - return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward); - }; - - TestGradient(grad.get(), c_weights.get(), func, 2e-3f, __LINE__); - } -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE - -namespace gcpp { -HWY_BEFORE_TEST(BackwardTest); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd); -HWY_AFTER_TEST(); - -} // namespace gcpp - -#endif diff --git a/backprop/common_scalar.h b/backprop/common_scalar.h deleted file mode 100644 index 9794636..0000000 --- a/backprop/common_scalar.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ - -#include - -#include - -#include "util/mat.h" - -namespace gcpp { - -template -U DotT(const T* a, const U* b, size_t N) { - U sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += a[i] * b[i]; - } - return sum; -} - -template<> -inline std::complex DotT(const float* a, const std::complex* b, - size_t N) { - std::complex sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += static_cast(a[i]) * b[i]; - } - return sum; -} - -template -void MulByConstT(T c, T* x, size_t N) { - for (size_t i = 0; i < N; ++i) { - x[i] *= c; - } -} - -// out += c * x -template -void MulByConstAndAddT(T c, const T* x, T* out, size_t N) { - for (size_t i = 0; i < N; ++i) { - out[i] += c * x[i]; - } -} - -template -void MulByConstAndAddT(T c, const MatPtrT& x, MatPtrT& out) { - for (size_t r = 0; r < x.Rows(); ++r) { - MulByConstAndAddT(c, x.Row(r), out.Row(r), x.Cols()); - } -} - -template -void AddFromT(const T* a, T* out, size_t N) { - for (size_t i = 0; i < N; ++i) { - out[i] += a[i]; - } -} - -template -T SquaredL2(const T* x, size_t N) { - T sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += x[i] * x[i]; - } - return sum; -} - -template -T Gelu(T x) { - static const T kMul = 0.044715; - static const T kSqrt2OverPi = 0.797884560804236; - - const T x3 = x * x * x; - const T arg = kSqrt2OverPi * (kMul * x3 + x); - const T cdf = T(0.5) * (T(1.0) + std::tanh(arg)); - return x * cdf; -} - -template -void Rope(T* x, U base, size_t N, int i) { - const size_t N2 = N / 2; - for (size_t dim = 0; dim < N2; ++dim) { - const T freq_exponents = T(2 * dim) / T(N); - const T timescale = std::pow(base, freq_exponents); - const T theta = T(i) / timescale; - const T cos_val = std::cos(theta); - const T sin_val = std::sin(theta); - const T x0 = x[dim]; - const T x1 = x[dim + N2]; - x[dim] = x0 * cos_val - x1 * sin_val; - x[dim + N2] = x0 * sin_val + x1 * cos_val; - } -} - -template -void Rope(T* x, size_t N, int i) { - Rope(x, T(10000.0), N, i); -} - -template -void Rope(std::complex* x, size_t N, int i) { - Rope(x, T(10000.0), N, i); -} - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h deleted file mode 100644 index 954243c..0000000 --- a/backprop/forward-inl.h +++ /dev/null @@ -1,297 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Include guard for non-SIMD code. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ - -#include -#include - -#include -#include - -#include "backprop/activations.h" -#include "gemma/common.h" // EmbeddingScaling -#include "gemma/configs.h" -#include "gemma/weights.h" -#include "util/allocator.h" -#include "util/mat.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ - -// Include guard for (potentially) SIMD code. -#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE -#endif - -#include "hwy/highway.h" -// After highway.h -#include "ops/matvec-inl.h" -#include "ops/ops-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -template -void InputEmbedding(const MatPtrT& weights, const std::vector& prompt, - const float scaling, float* HWY_RESTRICT output, - size_t model_dim, size_t vocab_size) { - const hn::ScalableTag df; - HWY_ASSERT(!prompt.empty()); - for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { - int token = prompt[pos]; - const auto span = weights.Span(); - HWY_ASSERT(span.num == model_dim * vocab_size); - DecompressAndZeroPad(df, span, token * model_dim, output + pos * model_dim, - model_dim); - MulByConst(scaling, output + pos * model_dim, model_dim); - } -} - -template -void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x, - size_t model_dim, size_t num_tokens, - OutT* HWY_RESTRICT output, - hwy::ThreadPool& pool) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = pos * model_dim; - RMSNorm(x + offset, weights, 0, output + offset, model_dim); - } -} - -static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, - const std::vector& prompt, - size_t context_size, - size_t vocab_size, - hwy::ThreadPool& pool) { - HWY_ASSERT(!prompt.empty()); - float loss = 0.0f; - for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { - if (pos + 1 < context_size) { - continue; // next token is part of context, don't try to predict it - } - const int next_token = prompt[pos + 1]; - loss += std::log(probs[pos * vocab_size + next_token]); - } - float scaling = -1.0 / std::log(2.0); - return loss * scaling; -} - -template -void ApplyForwardLayer(const LayerWeightsPtrs& weights, - ForwardLayer& activations, size_t num_tokens, - float* HWY_RESTRICT output, - const MatStorageT& inv_timescale, - hwy::ThreadPool& pool) { - const LayerConfig& config = weights.layer_config; - const size_t model_dim = config.model_dim; - const size_t kSeqLen = activations.input.Rows(); - const size_t kQKVDim = config.qkv_dim; - const size_t kHeads = config.heads; - static const float query_scale = - static_cast(1.0 / sqrt(static_cast(kQKVDim))); - HWY_ASSERT(num_tokens <= kSeqLen); - - ApplyRMSNorm(weights.pre_attention_norm_scale.Packed(), - activations.input.Packed(), model_dim, num_tokens, - activations.pre_att_rms_out.Packed(), pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim, - activations.pre_att_rms_out.Packed() + pos * model_dim, - activations.qkv.Packed() + pos * (kHeads + 2) * kQKVDim, pool); - } - const size_t num_tasks = kHeads * num_tokens; - - for (size_t pos = 0; pos < num_tokens; ++pos) { - float* HWY_RESTRICT k = - activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim; - Rope(k, kQKVDim, inv_timescale.PackedScale1(), pos); - } - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - float* HWY_RESTRICT q = - activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim; - Rope(q, kQKVDim, inv_timescale.PackedScale1(), pos); - MulByConst(query_scale, q, kQKVDim); - }); - - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - const float* HWY_RESTRICT q = - activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim; - float* HWY_RESTRICT head_att = - activations.att.Packed() + (pos * kHeads + head) * kSeqLen; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const float* HWY_RESTRICT k2 = - activations.qkv.Packed() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; - const float score = Dot(q, k2, kQKVDim); - head_att[pos2] = score; - } - }); - - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - float* HWY_RESTRICT head_att = - activations.att.Packed() + (pos * kHeads + head) * kSeqLen; - Softmax(head_att, pos + 1); - }); - - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - const float* HWY_RESTRICT head_att = - activations.att.Packed() + (pos * kHeads + head) * kSeqLen; - float* HWY_RESTRICT att_out = - activations.att_out.Packed() + (pos * kHeads + head) * kQKVDim; - hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - float* HWY_RESTRICT v2 = activations.qkv.Packed() + - (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; - MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); - } - }); - - ZeroInit(activations.attention_out); - for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - MatVec(weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim, - kQKVDim, - activations.att_out.Packed() + pos * kHeads * kQKVDim + - head * kQKVDim, - activations.att_post1.Packed() + pos * model_dim, pool); - AddFrom(activations.att_post1.Packed() + pos * model_dim, - activations.attention_out.Packed() + pos * model_dim, model_dim); - } - } - - for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(activations.input.Packed() + pos * model_dim, - activations.attention_out.Packed() + pos * model_dim, model_dim); - } - - ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(), - activations.attention_out.Packed(), model_dim, num_tokens, - activations.pre_ffw_rms_out.Packed(), pool); - const size_t kFFHiddenDim = config.ff_hidden_dim; - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim, - activations.pre_ffw_rms_out.Packed() + pos * model_dim, - activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool); - } - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t hidden_offset = pos * kFFHiddenDim * 2; - const float* HWY_RESTRICT out = - activations.ffw_hidden.Packed() + hidden_offset; - const float* HWY_RESTRICT out_mul = out + kFFHiddenDim; - float* HWY_RESTRICT out_gated = - activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim; - namespace hn = hwy::HWY_NAMESPACE; - using DF = hn::ScalableTag; - DF df; - for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { - const auto y = hn::Load(df, out + i); - const auto x = hn::Load(df, out_mul + i); - hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i); - } - } - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim, - activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim, - output + pos * model_dim, pool); - } - for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(activations.attention_out.Packed() + pos * model_dim, - output + pos * model_dim, model_dim); - } -} - -template -float CrossEntropyLossForwardPass(const std::vector& prompt, - size_t context_size, - const ModelWeightsPtrs& weights, - ForwardPass& forward, - const MatStorageT& inv_timescale, - hwy::ThreadPool& pool) { - const ModelConfig& config = weights.weights_config; - const size_t vocab_size = config.vocab_size; - const size_t model_dim = config.model_dim; - const size_t layers = config.layer_configs.size(); - const float emb_scaling = EmbeddingScaling(model_dim); - HWY_ASSERT(!config.absolute_pe); - HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None); - HWY_ASSERT(config.layer_configs[0].kv_heads == 1); - - HWY_DASSERT(context_size > 0); - HWY_DASSERT(context_size < prompt.size()); - const size_t num_tokens = prompt.size() - 1; - - InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling, - forward.layers[0].input.Packed(), model_dim, vocab_size); - - for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) { - auto type = config.layer_configs[layer].type; - // TODO(szabadka) Implement Griffin layer. - HWY_ASSERT(type == LayerAttentionType::kGemma); - float* HWY_RESTRICT output = layer + 1 < layers - ? forward.layers[layer + 1].input.Packed() - : forward.final_layer_output.Packed(); - ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer], - num_tokens, output, inv_timescale, pool); - } - - ApplyRMSNorm(weights.final_norm_scale.Packed(), - forward.final_layer_output.Packed(), model_dim, num_tokens, - forward.final_norm_output.Packed(), pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim, - forward.final_norm_output.Packed() + pos * model_dim, - forward.logits.Packed() + pos * vocab_size, pool); - } - - if (config.final_cap > 0.0f) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - LogitsSoftCap(config.final_cap, - forward.logits.Packed() + pos * vocab_size, vocab_size); - } - } - - CopyMat(forward.logits, forward.probs); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - Softmax(forward.probs.Packed() + pos * vocab_size, vocab_size); - } - - return CrossEntropyLoss(forward.probs.Packed(), prompt, context_size, - vocab_size, pool); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // NOLINT diff --git a/backprop/forward.cc b/backprop/forward.cc deleted file mode 100644 index c31f359..0000000 --- a/backprop/forward.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "backprop/forward.h" - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "util/mat.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "backprop/forward.cc" // NOLINT -#include "hwy/foreach_target.h" // IWYU pragma: keep - -#include "hwy/highway.h" -// After highway.h -#include "backprop/forward-inl.h" -#include "gemma/weights.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -float CrossEntropyLossForwardPassT(const Prompt& prompt, - const ModelWeightsPtrs& weights, - ForwardPass& forward, - MatStorageT& inv_timescale, - hwy::ThreadPool& pool) { - return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size, - weights, forward, inv_timescale, pool); -} - -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE -namespace gcpp { - -HWY_EXPORT(CrossEntropyLossForwardPassT); - -float CrossEntropyLossForwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - ForwardPass& forward, - MatStorageT& inv_timescale, - hwy::ThreadPool& pool) { - return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( - prompt, weights, forward, inv_timescale, pool); -} - -} // namespace gcpp -#endif // HWY_ONCE diff --git a/backprop/forward.h b/backprop/forward.h deleted file mode 100644 index 042d40d..0000000 --- a/backprop/forward.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/weights.h" -#include "util/mat.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -float CrossEntropyLossForwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - ForwardPass& forward, - MatStorageT& inv_timescale, - hwy::ThreadPool& pool); - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h deleted file mode 100644 index 45d0f18..0000000 --- a/backprop/forward_scalar.h +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ - -#include -#include - -#include -#include -#include - -#include "backprop/activations.h" -#include "backprop/common_scalar.h" -#include "backprop/prompt.h" -#include "gemma/common.h" // EmbeddingScaling -#include "gemma/weights.h" -#include "hwy/base.h" - -namespace gcpp { - -// w is N x M matrix in row-major order, x is M x K matrix in column-major order -// y = w * x is N x K matrix in column-major order. -template -void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) { - for (size_t i = 0; i < K; ++i) { - for (size_t j = 0; j < N; ++j) { - y[i * N + j] = DotT(&w[j * M], &x[i * M], M); - } - } -} - -// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in -// column-major order and y = w' * x is N x K matrix in column-major order, -// where w' is the rearrangement of w into an N x HM matrix. -template -void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N, - size_t M, size_t K) { - memset(y, 0, N * K * sizeof(y[0])); - for (size_t i = 0; i < K; ++i) { - for (size_t h = 0; h < H; ++h) { - for (size_t j = 0; j < N; ++j) { - y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M); - } - } - } -} - -template -void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) { - constexpr T eps(1e-6); - for (size_t i = 0; i < K; ++i) { - T ss = SquaredL2(x + i * N, N); - ss = T(1.0) / std::sqrt(ss / T(N) + eps); - for (size_t j = 0; j < N; j++) { - out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]); - } - } -} -template -void Softmax(T* x, size_t N) { - T sum = {}; - auto maxreal = std::real(x[0]); - for (size_t i = 1; i < N; ++i) { - if (std::real(x[i]) > maxreal) { - maxreal = std::real(x[i]); - } - } - for (size_t i = 0; i < N; ++i) { - x[i] = std::exp(x[i] - maxreal); - sum += x[i]; - } - T scale = T(1.0) / sum; - for (size_t i = 0; i < N; ++i) { - x[i] *= scale; - } -} -template -void Softmax(T* x, size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - Softmax(x + i * N, N); - } -} -template -void Softcap(float cap, T* x, size_t N) { - const T inv_cap = T{1.0} / static_cast(cap); - for (size_t i = 0; i < N; ++i) { - x[i] = static_cast(cap) * std::tanh(x[i] * inv_cap); - } -} - -template -void GatedGelu(const T* in, T* out, size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - const T* x1 = in + i * 2 * N; - const T* x2 = x1 + N; - T* y = out + i * N; - for (size_t j = 0; j < N; ++j) { - y[j] = x2[j] * Gelu(x1[j]); - } - } -} - -template -void InputEmbedding(const T* w, const std::vector& tokens, T scaling, - T* y, size_t N) { - HWY_ASSERT(w != nullptr); - HWY_ASSERT(y != nullptr); - const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - for (size_t i = 0; i < num_tokens; ++i) { - int token = tokens[i]; - memcpy(y + i * N, w + token * N, N * sizeof(y[0])); - MulByConstT(scaling, y + i * N, N); - } -} - -template -void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads, - size_t qkv_dim, size_t seq_len) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < heads; ++head) { - const size_t qoffset = pos * (heads + 2) * qkv_dim; - const size_t aoffset = pos * heads * seq_len + head * seq_len; - const T* q = qkv + qoffset + head * qkv_dim; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim; - output[aoffset + pos2] = DotT(q, k, qkv_dim); - } - } - } -} -template -void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < heads; ++head) { - size_t offset = pos * heads * seq_len + head * seq_len; - Softmax(x + offset, pos + 1); - memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T)); - } - } -} -template -void MixByAttention(const T* qkv, const T* attention, T* output, - size_t num_tokens, size_t heads, size_t qkv_dim, - size_t seq_len) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < heads; ++head) { - const T* att = &attention[pos * heads * seq_len + head * seq_len]; - T* out = &output[head * qkv_dim + pos * heads * qkv_dim]; - memset(out, 0, qkv_dim * sizeof(out[0])); - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim; - const T* v = &qkv[v_offset]; - MulByConstAndAddT(att[pos2], v, out, qkv_dim); - } - } - } -} -template -void ApplyLayer(const LayerWeightsPtrs& weights, - ForwardLayer& activations, size_t num_tokens, T* output) { - const LayerConfig& layer_config = weights.layer_config; - const size_t model_dim = layer_config.model_dim; - const size_t seq_len = activations.input.Rows(); - const size_t qkv_dim = layer_config.qkv_dim; - const size_t heads = layer_config.heads; - const size_t ff_hidden_dim = layer_config.ff_hidden_dim; - static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim)); - - RMSNormT(weights.pre_attention_norm_scale.Packed(), - activations.input.Packed(), activations.pre_att_rms_out.Packed(), - model_dim, num_tokens); - - MatMulT(weights.qkv_einsum_w.Packed(), activations.pre_att_rms_out.Packed(), - activations.qkv.Packed(), (heads + 2) * qkv_dim, model_dim, - num_tokens); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim; - for (size_t h = 0; h <= heads; ++h) { - Rope(qkv + h * qkv_dim, qkv_dim, pos); - } - } - - for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim; - MulByConstT(query_scale, qkv, heads * qkv_dim); - } - - MaskedAttention(activations.qkv.Packed(), activations.att.Packed(), - num_tokens, heads, qkv_dim, seq_len); - - MaskedSoftmax(activations.att.Packed(), num_tokens, heads, seq_len); - - MixByAttention(activations.qkv.Packed(), activations.att.Packed(), - activations.att_out.Packed(), num_tokens, heads, qkv_dim, - seq_len); - - MultiHeadMatMul(weights.attn_vec_einsum_w.Packed(), - activations.att_out.Packed(), - activations.attention_out.Packed(), heads, model_dim, qkv_dim, - num_tokens); - - AddFromT(activations.input.Packed(), activations.attention_out.Packed(), - num_tokens * model_dim); - - RMSNormT(weights.pre_ffw_norm_scale.Packed(), - activations.attention_out.Packed(), - activations.pre_ffw_rms_out.Packed(), model_dim, num_tokens); - - MatMulT(weights.gating_einsum_w.Packed(), - activations.pre_ffw_rms_out.Packed(), activations.ffw_hidden.Packed(), - ff_hidden_dim * 2, model_dim, num_tokens); - - GatedGelu(activations.ffw_hidden.Packed(), - activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens); - - MatMulT(weights.linear_w.Packed(), activations.ffw_hidden_gated.Packed(), - output, model_dim, ff_hidden_dim, num_tokens); - - AddFromT(activations.attention_out.Packed(), output, num_tokens * model_dim); -} - -template -T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) { - T loss = {}; - const std::vector tokens = prompt.tokens; - const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - for (size_t i = 0; i < num_tokens; ++i) { - if (i + 1 < prompt.context_size) { - continue; // next token is part of context, don't try to predict it - } - const int next_token = tokens[i + 1]; - loss += std::log(x[i * V + next_token]); - } - T scaling = -1.0 / std::log(2.0); - return loss * scaling; -} - -template -T CrossEntropyLossForwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - ForwardPass& forward) { - const ModelConfig& config = weights.weights_config; - const size_t model_dim = config.model_dim; - const size_t vocab_size = config.vocab_size; - const size_t layers = config.layer_configs.size(); - const std::vector tokens = prompt.tokens; - const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - - const T kEmbScaling = EmbeddingScaling(model_dim); - InputEmbedding(weights.embedder_input_embedding.Packed(), tokens, kEmbScaling, - forward.layers[0].input.Packed(), model_dim); - - for (size_t layer = 0; layer < layers; ++layer) { - T* output = layer + 1 < layers ? forward.layers[layer + 1].input.Packed() - : forward.final_layer_output.Packed(); - ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, - output); - } - - RMSNormT(weights.final_norm_scale.Packed(), - forward.final_layer_output.Packed(), - forward.final_norm_output.Packed(), model_dim, num_tokens); - - MatMulT(weights.embedder_input_embedding.Packed(), - forward.final_norm_output.Packed(), forward.logits.Packed(), - vocab_size, model_dim, num_tokens); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - if (config.final_cap > 0.0f) { - Softcap(config.final_cap, forward.logits.Packed() + pos * vocab_size, - vocab_size); - } - } - - CopyMat(forward.logits, forward.probs); - Softmax(forward.probs.Packed(), vocab_size, num_tokens); - - return CrossEntropyLoss(forward.probs.Packed(), prompt, vocab_size); -} - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc deleted file mode 100644 index fdf73ec..0000000 --- a/backprop/optimize_test.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "backprop/activations.h" -#include "backprop/backward.h" -#include "backprop/forward.h" -#include "backprop/optimizer.h" -#include "backprop/prompt.h" -#include "backprop/sampler.h" -#include "compression/types.h" -#include "gemma/configs.h" -#include "gemma/gemma.h" -#include "gemma/tokenizer.h" -#include "gemma/weights.h" -#include "ops/ops.h" -#include "util/allocator.h" -#include "util/basics.h" -#include "util/threading.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -TEST(OptimizeTest, GradientDescent) { - gcpp::ThreadingArgs threading_args; - threading_args.max_packages = 1; - threading_args.max_clusters = 1; - threading_args.pin = Tristate::kFalse; - ThreadingContext::SetArgs(threading_args); - MatMulEnv env(ThreadingContext::Get()); - const Allocator& allocator = env.ctx.allocator; - hwy::ThreadPool& pool = env.ctx.pools.Pool(); - std::mt19937 gen(42); - - ModelConfig config(Model::GEMMA_TINY, Type::kF32, - ChooseWrapping(Model::GEMMA_TINY)); - config.eos_id = ReverseSequenceSampler::kEndToken; - - WeightsOwner grad(Type::kF32), grad_m(Type::kF32), grad_v(Type::kF32); - grad.AllocateForTest(config, pool); - grad_m.AllocateForTest(config, pool); - grad_v.AllocateForTest(config, pool); - grad_m.ZeroInit(); - grad_v.ZeroInit(); - ForwardPass forward(config), backward(config); - KVCache kv_cache(config, /*prefill_tbatch_size=*/16); - - MatStorageT inv_timescale = CreateInvTimescale( - allocator, config.layer_configs[0].qkv_dim, - config.layer_configs[0].post_qk == PostQKType::HalfRope); - - Gemma gemma(config, GemmaTokenizer(kMockTokenizer), env); - - const auto generate = [&](const std::vector& prompt) { - std::vector reply; - auto stream_token = [&reply](int token, float) { - reply.push_back(token); - return token != ReverseSequenceSampler::kEndToken; - }; - RuntimeConfig runtime = { - .max_generated_tokens = 16, - .temperature = 1.0f, - .gen = &gen, - .verbosity = 0, - .stream_token = stream_token, - }; - TimingInfo timing_info; - gemma.Generate(runtime, prompt, 0, kv_cache, timing_info); - return reply; - }; - - // Sanity check of reply tokens. - // 1) Its length should be greater than the prompt. - // 2) The prompt should be a prefix of the reply. - auto verify = [&](const Prompt& prompt) { - const std::vector& context = prompt.context(); - std::vector reply = generate(context); - if (reply.size() <= context.size()) return false; - return std::equal(context.begin(), context.end(), reply.begin(), - reply.begin() + context.size()); - }; - - gemma.MutableWeights().RandInit(1.0f, gen); - gemma.MutableWeights().Fixup(pool); - - printf("Initial weights:\n"); - gemma.MutableWeights().LogWeightStatsF32(); - - constexpr size_t kBatchSize = 8; - constexpr float kAlpha = 0.001f; - constexpr float kBeta1 = 0.9f; - constexpr float kBeta2 = 0.999f; - constexpr float kEpsilon = 1e-8f; - - constexpr float kMaxLoss = 20.0f; - - ReverseSequenceSampler training_task({ - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1}); - size_t steps = 0; - size_t num_ok; - for (; steps < 1000; ++steps) { - std::mt19937 sgen(42); - grad.ZeroInit(); - float total_loss = 0.0f; - num_ok = 0; - for (size_t i = 0; i < kBatchSize; ++i) { - Prompt prompt = training_task.Sample(sgen); - total_loss += CrossEntropyLossForwardPass( - prompt, *gemma.Weights().GetF32(), forward, inv_timescale, pool); - CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward, - *grad.GetF32(), backward, inv_timescale, - pool); - gemma.MutableWeights().Fixup(pool); - num_ok += verify(prompt) ? 1 : 0; - } - total_loss /= kBatchSize; - - AdamUpdate(grad, kAlpha, kBeta1, kBeta2, kEpsilon, steps + 1, - gemma.Weights(), grad_m, grad_v, pool); - printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", - steps, total_loss, num_ok, kBatchSize); - if (steps % 100 == 0) { - printf("Batch gradient:\n"); - grad.LogWeightStatsF32(); - } - if (total_loss < kMaxLoss) break; // Done - } - printf("Num steps: %zu\n", steps); - printf("Final weights:\n"); - gemma.MutableWeights().LogWeightStatsF32(); - EXPECT_LT(steps, 80); - EXPECT_EQ(num_ok, kBatchSize); -} - -} // namespace gcpp diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc deleted file mode 100644 index 93e3efe..0000000 --- a/backprop/optimizer.cc +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "backprop/optimizer.h" - -#include - -#include "gemma/weights.h" -#include "util/allocator.h" -#include "util/mat.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -namespace { - -using MatPtrF = MatPtrT; - -// Split into two classes so that ForEachTensor only requires two "other" -// arguments. This is anyway useful for locality, because `grad` only feeds -// into `grad_m` and `grad_v` here. -class AdamUpdateMV { - public: - AdamUpdateMV(float beta1, float beta2, size_t t) - : beta1_(beta1), - beta2_(beta2), - cbeta1_(1.0f - beta1), - cbeta2_(1.0f - beta2), - norm1_(1.0 / (1.0 - std::pow(beta1, t))), - norm2_(1.0 / (1.0 - std::pow(beta2, t))) {} - - void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) { - for (size_t r = 0; r < grad.Rows(); ++r) { - const float* HWY_RESTRICT g = grad.Row(r); - float* HWY_RESTRICT m = grad_m.Row(r); - float* HWY_RESTRICT v = grad_v.Row(r); - for (size_t c = 0; c < grad.Cols(); ++c) { - m[c] *= beta1_; - m[c] += cbeta1_ * g[c]; - v[c] *= beta2_; - v[c] += cbeta2_ * g[c] * g[c]; - } - } - } - - private: - float beta1_; - float beta2_; - float cbeta1_; - float cbeta2_; - float norm1_; - float norm2_; -}; - -// Updates `weights` based on the updated `grad_m` and `grad_v` from above. -class AdamUpdateW { - public: - AdamUpdateW(float alpha, float beta1, float beta2, float epsilon, size_t t) - : alpha_(alpha), - norm1_(1.0 / (1.0 - std::pow(beta1, t))), - norm2_(1.0 / (1.0 - std::pow(beta2, t))), - epsilon_(epsilon) {} - - void operator()(MatPtrF& weights, const MatPtrF& grad_m, - const MatPtrF& grad_v) { - for (size_t r = 0; r < weights.Rows(); ++r) { - float* HWY_RESTRICT w = weights.Row(r); - const float* HWY_RESTRICT m = grad_m.Row(r); - const float* HWY_RESTRICT v = grad_v.Row(r); - for (size_t c = 0; c < weights.Cols(); ++c) { - const float mhat = m[c] * norm1_; - const float vhat = v[c] * norm2_; - w[c] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_); - } - } - } - - private: - float alpha_; - float norm1_; - float norm2_; - float epsilon_; -}; - -void AdamUpdate(ModelWeightsPtrs* grad, float alpha, float beta1, - float beta2, float epsilon, size_t t, - ModelWeightsPtrs* weights, - ModelWeightsPtrs* grad_m, - ModelWeightsPtrs* grad_v, hwy::ThreadPool& pool) { - AdamUpdateMV update_mv(beta1, beta2, t); - grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) { - const MatPtrF grad_f(t.mat); - MatPtrF grad_m_f(*t.other_mat1); - MatPtrF grad_v_f(*t.other_mat2); - update_mv(grad_f, grad_m_f, grad_v_f); - }); - - AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t); - weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) { - MatPtrF weights_f(t.mat); - const MatPtrF grad_m_f(*t.other_mat1); - const MatPtrF grad_v_f(*t.other_mat2); - update_w(weights_f, grad_m_f, grad_v_f); - }); -} - -} // namespace - -void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2, - float epsilon, size_t t, const WeightsOwner& weights, - const WeightsOwner& grad_m, const WeightsOwner& grad_v, - hwy::ThreadPool& pool) { - AdamUpdate(grad.GetF32(), alpha, beta1, beta2, epsilon, t, weights.GetF32(), - grad_m.GetF32(), grad_v.GetF32(), pool); -} - -} // namespace gcpp diff --git a/backprop/optimizer.h b/backprop/optimizer.h deleted file mode 100644 index daf2d82..0000000 --- a/backprop/optimizer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ - -#include - -#include "gemma/weights.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2, - float epsilon, size_t t, const WeightsOwner& weights, - const WeightsOwner& grad_m, const WeightsOwner& grad_v, - hwy::ThreadPool& pool); - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ diff --git a/backprop/prompt.h b/backprop/prompt.h deleted file mode 100644 index 76acb56..0000000 --- a/backprop/prompt.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_ - -#include -#include - -namespace gcpp { - -struct Prompt { - std::vector tokens; - size_t context_size; - std::vector context() const { - return std::vector(tokens.begin(), tokens.begin() + context_size); - } -}; - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_ diff --git a/backprop/sampler.h b/backprop/sampler.h deleted file mode 100644 index 17f5762..0000000 --- a/backprop/sampler.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ - -#include -#include - -#include -#include - -#include "backprop/prompt.h" - -namespace gcpp { - -class PromptSampler { - public: - virtual Prompt Sample(std::mt19937& gen) = 0; - virtual ~PromptSampler() = default; - - std::vector SampleBatch(size_t batch_size, std::mt19937& gen) { - std::vector batch; - batch.reserve(batch_size); - for (size_t i = 0; i < batch_size; ++i) { - batch.emplace_back(Sample(gen)); - } - return batch; - } -}; - -class ReverseSequenceSampler : public PromptSampler { - public: - explicit ReverseSequenceSampler(const std::vector& length_histo) - : token_dist_(0, 9) { - for (int i = 0; i < length_histo.size(); ++i) { - const int count = length_histo[i]; - for (int j = 0; j < count; ++j) { - length_lut_.push_back(i + 1); - } - } - length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1); - } - virtual ~ReverseSequenceSampler() = default; - - static constexpr int kReverseToken = 10; - static constexpr int kEndToken = 11; - - Prompt Sample(std::mt19937& gen) override { - Prompt prompt; - int len = length_lut_[length_dist_(gen)]; - prompt.tokens.resize(2 * len + 2); - prompt.tokens[len] = kReverseToken; - prompt.tokens[2 * len + 1] = kEndToken; - for (size_t i = 0; i < len; ++i) { - prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen); - } - prompt.context_size = len + 1; - return prompt; - } - - static void LogPrompt(const Prompt& prompt) { - static const char* kVocab[] = { - "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|", - }; - for (int token : prompt.tokens) printf("%s", kVocab[token]); - printf(" [context_size: %zu]\n", prompt.context_size); - } - - private: - std::uniform_int_distribution<> token_dist_; - std::uniform_int_distribution<> length_dist_; - std::vector length_lut_; -}; - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ diff --git a/backprop/test_util.h b/backprop/test_util.h deleted file mode 100644 index 8f32cbf..0000000 --- a/backprop/test_util.h +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ - -#include - -#include -#include -#include - -#include "gtest/gtest.h" -#include "gemma/configs.h" -#include "gemma/weights.h" -#include "util/mat.h" -#include "util/threading_context.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -template -void Complexify(const MatPtrT& x, MatPtrT>& c_x) { - for (size_t r = 0; r < x.Rows(); ++r) { - const T* row = x.Row(r); - std::complex* c_row = c_x.Row(r); - for (size_t c = 0; c < x.Cols(); ++c) { - c_row[c] = std::complex(row[c], 0.0); - } - } -} - -template -void Complexify(const LayerWeightsPtrs& w, LayerWeightsPtrs& c_w) { - Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale); - Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w); - Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w); - Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale); - Complexify(w.gating_einsum_w, c_w.gating_einsum_w); - Complexify(w.linear_w, c_w.linear_w); -} - -template -void Complexify(const ModelWeightsPtrs& w, ModelWeightsPtrs& c_w) { - const size_t kLayers = w.c_layers.size(); - Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding); - Complexify(w.final_norm_scale, c_w.final_norm_scale); - for (size_t i = 0; i < kLayers; ++i) { - Complexify(*w.GetLayer(i), *c_w.GetLayer(i)); - } -} - -// Somewhat duplicates `WeightsOwner`, but that has neither double nor -// complex types allowed and it would cause code bloat to add them there. -template -class WeightsWrapper { - public: - explicit WeightsWrapper(const ModelConfig& config) : weights_(config) { - hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool(); - weights_.AllocateForTest(owners_, pool); - } - - const ModelWeightsPtrs& get() const { return weights_; } - ModelWeightsPtrs& get() { return weights_; } - - private: - std::vector owners_; - ModelWeightsPtrs weights_; -}; - -template -void TestNear(const MatPtrT& actual, const MatPtrT& expected, - double max_abs_err, double max_rel_err, int line_test, - int line_util) { - // TODO: consider compensated sum. - double sum0 = 0; - double sum1 = 0; - double sum01 = 0; - for (size_t r = 0; r < actual.Rows(); ++r) { - const T* actual_row = actual.Row(r); - const U* expected_row = expected.Row(r); - for (size_t c = 0; c < actual.Cols(); ++c) { - sum0 += actual_row[c] * actual_row[c]; - sum1 += expected_row[c] * expected_row[c]; - sum01 += actual_row[c] * expected_row[c]; - ASSERT_NEAR( - actual_row[c], expected_row[c], - std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err)) - << "test line " << line_test << "test_util.h line " << line_util - << " r " << r << " c " << c; - } - } - if (sum0 > 1e-16) { - double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1); - ASSERT_NEAR(norm_dot, 1.0, 3e-6) - << "test line " << line_test << " test_util.h line " << line_util - << " sum0: " << sum0 << " sum1: " << sum1 << " sum01: " << sum01; - } -} - -// Compute gradient with the finite difference method in the complex plane. -// If f : R->R is the tested function and F : C->C is its extension on the -// complex plane so that F is complex differentiable in x, then -// -// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x) -// -// which means that -// -// F'(x) ~= Imag(F(x + ih)) / h -// -// This method is more numerically stable than the real-valued finite difference -// method since we don't need to subtract floating point numbers that are near -// to each other. -template -void TestGradient(const MatPtrT& grad, MatPtrT>& x, - FUNC func, U step, T max_abs_err, T max_rel_err, - int line_test, int line_util) { - MatStorageT exp_grad = MakePacked("exp_grad", x.Rows(), x.Cols()); - const U inv_step = 1.0 / step; - for (size_t r = 0; r < x.Rows(); ++r) { - std::complex* x_row = x.Row(r); - T* exp_row = exp_grad.Row(r); - for (size_t c = 0; c < x.Cols(); ++c) { - const U x0 = std::real(x_row[c]); - const std::complex x1 = std::complex(x0, step); - x_row[c] = x1; - const std::complex f1 = func(); - exp_row[c] = std::imag(f1) * inv_step; - x_row[c] = x0; - } - } - TestNear(grad, exp_grad, max_abs_err, max_rel_err, line_test, line_util); -} - -template -void TestGradient(const MatPtrT& grad, MatPtrT>& x, - FUNC func, float max_abs_err, float max_rel_error, - int line_test, int line_util) { - TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line_test, - line_util); -} - -template -void TestGradient(const MatPtrT& grad, MatPtrT>& x, - FUNC func, T max_abs_err, T max_rel_error, int line_test, - int line_util) { - TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line_test, - line_util); -} - -template -void TestGradient(const LayerWeightsPtrs& grad, - LayerWeightsPtrs& c_weights, FUNC func, T max_err, - int line_test) { - TestGradient(grad.pre_attention_norm_scale, - c_weights.pre_attention_norm_scale, func, max_err, max_err, - line_test, __LINE__); - TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w, func, - max_err, max_err, line_test, __LINE__); - TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w, func, max_err, - max_err, line_test, __LINE__); - TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale, func, - max_err, max_err, line_test, __LINE__); - TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w, func, max_err, - max_err, line_test, __LINE__); - TestGradient(grad.linear_w, c_weights.linear_w, func, max_err, max_err, - line_test, __LINE__); -} - -template -void TestGradient(const ModelWeightsPtrs& grad, - ModelWeightsPtrs& c_weights, FUNC func, T max_err, - int line_test) { - TestGradient(grad.embedder_input_embedding, - c_weights.embedder_input_embedding, func, 2 * max_err, max_err, - line_test, __LINE__); - TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, func, max_err, - max_err, line_test, __LINE__); - for (size_t i = 0; i < grad.c_layers.size(); ++i) { - TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err, - line_test); - } -} - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ diff --git a/compression/compress-inl.h b/compression/compress-inl.h index f9a9a67..57f0979 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -65,8 +65,8 @@ static constexpr bool kIsTest = false; template // primary, must specialize struct CompressTraits {}; -// Used by backprop/, where weights are currently f32; also MatMul for f32 -// weights or activations, if native `ReorderWidenMulAccumulate` is available. +// Used by MatMul for f32 weights or activations, if native +// `ReorderWidenMulAccumulate` is available. template <> struct CompressTraits { using Packed = float; diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index f7a887f..d10625c 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -81,7 +81,7 @@ MatStorageT GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) { } }); - Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(), + Compress(raw.PackedScale1(), raw.Extents().Area(), ws, compressed.Span(), /*packed_ofs=*/0, pool); compressed.SetScale(0.6f); // Arbitrary value, different from 1. return compressed; @@ -104,7 +104,7 @@ MatStorageT GenerateTransposedMat(const Extents2D extents, } }); - Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(), + Compress(raw.PackedScale1(), raw.Extents().Area(), ws, compressed.Span(), /*packed_ofs=*/0, pool); // Arbitrary value, different from 1, must match `GenerateMat`. compressed.SetScale(0.6f); diff --git a/compression/types.h b/compression/types.h index 2865f8b..a699b8f 100644 --- a/compression/types.h +++ b/compression/types.h @@ -21,9 +21,6 @@ #include #include -#include -#include - // IWYU pragma: begin_exports #include "util/basics.h" // BF16 #include "hwy/aligned_allocator.h" @@ -160,24 +157,22 @@ constexpr bool IsNuqStream() { return hwy::IsSame, NuqStream>(); } -// Tensor types for loading weights. Note that not all types are supported as -// weights for a model, but can be used for other purposes, such as types for -// `WeightsPtrs`. When adding a new type that is supported, also -// update gemma.cc, weights.*, and add instantiations/new_one.cc. -enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64 }; +// Tensor types for loading weights. +enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. -static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", - "nuq", "f64", "c64"}; +static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", + "sfp", "nuq", "f64"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); -static constexpr size_t kTypeBits[] = {0, - 8 * sizeof(float), - 8 * sizeof(BF16), - 8 * sizeof(SfpStream), - 4 /* NuqStream, actually 4.5 */, - 8 * sizeof(double), - 8 * sizeof(std::complex)}; +static constexpr size_t kTypeBits[] = { + 0, + 8 * sizeof(float), + 8 * sizeof(BF16), + 8 * sizeof(SfpStream), + 4 /* NuqStream, actually 4.5 */, + 8 * sizeof(double), +}; static inline bool EnumValid(Type type) { return static_cast(type) < kNumTypes; @@ -197,8 +192,6 @@ Type TypeEnum() { return Type::kNUQ; } else if constexpr (hwy::IsSame()) { return Type::kF64; - } else if constexpr (hwy::IsSame>()) { - return Type::kC64; } else { HWY_DASSERT(false); return Type::kUnknown; diff --git a/gemma/common.h b/gemma/common.h index a71b9fb..934c6a7 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -29,7 +29,6 @@ namespace gcpp { void Wrap(const ModelConfig& config, size_t pos, std::string& prompt); // Returns the scale value to use for the embedding (basically sqrt model_dim). -// Also used by backprop/. float EmbeddingScaling(size_t model_dim); // Returns the scale value to use for the query in the attention computation. diff --git a/gemma/configs.cc b/gemma/configs.cc index 82a1fa9..613b31a 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -152,13 +152,12 @@ static ModelConfig ConfigGemmaTiny() { config.wrapping = PromptWrapping::GEMMA_IT; config.model_dim = 32; config.vocab_size = 32; // at least two f32 vectors - config.seq_len = 32; // optimize_test requires more than 24 + config.seq_len = 32; LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim); config.num_layers = 2; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; config.attention_window_sizes = FixedAttentionWindowSizes<2>(32); - // This is required for optimize_test to pass. config.att_cap = 50.0f; config.final_cap = 30.0f; config.eos_id = 11; @@ -203,7 +202,6 @@ static ModelConfig ConfigGriffin2B() { } config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len); config.use_local_attention = true; - // This is required for optimize_test to pass. config.final_cap = 0.0f; return config; } diff --git a/gemma/configs.h b/gemma/configs.h index 0e3bb0a..caffc81 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -162,7 +162,7 @@ enum class Model { GEMMA2_9B = 3, GEMMA2_27B, GRIFFIN_2B, - GEMMA_TINY, // for backprop/ only + GEMMA_TINY, // for testing only GEMMA2_2B, // 8 and 9 are obsolete. PALIGEMMA2_3B_224 = 10, @@ -330,9 +330,8 @@ struct ModelConfig : public IFields { // from a blob. Also used by `config_converter.py`, which sets sufficient // fields for `TestEqual` and then calls `OverwriteWithCanonical()`. ModelConfig() = default; - // For use by `backprop/`, and `model_store.cc` for pre-2025 format after - // deducing the model from tensors plus a user-specified `wrapping` override - // (see `ChooseWrapping`). + // For use by `model_store.cc` for pre-2025 format after deducing the model + // from tensors plus a user-specified `wrapping` override (`ChooseWrapping`). ModelConfig(Model model, Type weight, PromptWrapping wrapping); // Parses a string returned by `Specifier()`. Used by the exporter to select // the model from command line arguments. Do not use this elsewhere - the diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 4e7912f..6a2f888 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -230,13 +230,13 @@ class GemmaAttention { const float mul) { // qk is either q or k, so qkv_dim is the length we operate on. const size_t qkv_dim = layer_config_.qkv_dim; - const float* inv_timescale = activations_.inv_timescale.Packed(); + const float* inv_timescale = activations_.inv_timescale.PackedScale1(); bool is_global_layer = activations_.weights_config.attention_window_sizes[layer] == activations_.seq_len; // TODO: add a config flag instead of hardcoding the model. if (is_global_layer && IsVLM(activations_.weights_config.model)) { - inv_timescale = activations_.inv_timescale_global.Packed(); + inv_timescale = activations_.inv_timescale_global.PackedScale1(); } // PostQKType::Rope (void)layer; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 81b27d2..cea1559 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -24,7 +24,6 @@ #include #include -#include // std::move #include // Placeholder for internal header, do not modify. @@ -61,16 +60,6 @@ Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env) reader_.reset(); } -Gemma::Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer, - MatMulEnv& env) - : env_(env), - model_(config, std::move(tokenizer)), - weights_(config.weight), - chat_template_(model_.Tokenizer(), model_.Config().model) { - HWY_ASSERT(config.weight == Type::kF32); - weights_.AllocateForTest(config, env_.ctx.pools.Pool(0)); -} - Gemma::~Gemma() = default; void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const { diff --git a/gemma/gemma.h b/gemma/gemma.h index cf027d1..908346f 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -107,11 +107,6 @@ class Gemma { // `env` must remain valid for the lifetime of this Gemma. Gemma(const LoaderArgs& loader, MatMulEnv& env); - // Only allocates weights, caller is responsible for filling them. Only used - // by `optimize_test.cc`. - // `env` must remain valid for the lifetime of this Gemma. - Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer, MatMulEnv& env); - ~Gemma(); MatMulEnv& Env() const { return env_; } diff --git a/gemma/model_store.h b/gemma/model_store.h index 88b536c..0c0803a 100644 --- a/gemma/model_store.h +++ b/gemma/model_store.h @@ -53,10 +53,7 @@ class ModelStore { // Reads from file(s) or aborts on error. The latter two arguments are only // used for pre-2025 files. ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(), - Tristate wrapping = Tristate::kDefault); - // For optimize_test.cc. - ModelStore(const ModelConfig& config, GemmaTokenizer&& tokenizer) - : config_(config), tokenizer_(std::move(tokenizer)) {} + Tristate wrapping = Tristate::kDefault); ~ModelStore(); const ModelConfig& Config() const { diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h index 9e921c1..a8c081b 100644 --- a/gemma/tokenizer.h +++ b/gemma/tokenizer.h @@ -60,8 +60,7 @@ class GemmaTokenizer { class GemmaChatTemplate { public: - // No effect if `tokenizer` is unavailable (as happens in optimize_test.cc), - // but then any other method may abort. + // No effect if `tokenizer` is unavailable, but any other method may abort. GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model); // Given prompt tokens, this returns the wrapped prompt including BOS and diff --git a/gemma/weights.cc b/gemma/weights.cc index 509a0ee..9e926a0 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -21,7 +21,6 @@ #include #include -#include #include #include @@ -37,7 +36,6 @@ #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" #include "hwy/profiler.h" -#include "hwy/stats.h" // TODO: move into foreach_target; this is only used for NUQ Fixup. #include "compression/compress-inl.h" @@ -297,9 +295,6 @@ void WeightsOwner::AllocatePointer(const ModelConfig& config) { case Type::kNUQ: nuq_weights_.reset(new ModelWeightsPtrs(config)); break; - case Type::kF32: - float_weights_.reset(new ModelWeightsPtrs(config)); - break; case Type::kBF16: bf16_weights_.reset(new ModelWeightsPtrs(config)); break; @@ -308,55 +303,6 @@ void WeightsOwner::AllocatePointer(const ModelConfig& config) { } } -// Gemma calls `WeightsOwner::ReadOrAllocate`, but test code instead calls -// `WeightsPtrs::AllocateForTest`, so the implementation is there, and here -// we only type-dispatch. -void WeightsOwner::AllocateForTest(const ModelConfig& config, - hwy::ThreadPool& pool) { - PROFILER_ZONE("Startup.AllocateWeights"); - - AllocatePointer(config); - CallT([&](const auto& weights) { - weights->AllocateForTest(mat_owners_, pool); - }); -} - -void WeightsOwner::ZeroInit() { - PROFILER_FUNC; - CallT([](const auto& weights) { weights->ZeroInit(); }); -} - -void WeightsOwner::RandInit(float stddev, std::mt19937& gen) { - PROFILER_FUNC; - float_weights_->RandInit(stddev, gen); -} - -void WeightsOwner::LogWeightStatsF32() { - size_t total_weights = 0; - HWY_ASSERT(weight_type_ == Type::kF32); // Only for float weights. - float_weights_->ForEachTensor( - nullptr, nullptr, [&total_weights](const TensorArgs& t) { - if (!t.mat.HasPtr()) return; - if (t.mat.Scale() != 1.0f) { - printf("[scale=%f] ", t.mat.Scale()); - } - hwy::Stats stats; - const MatPtrT mat_f(t.mat); - for (size_t r = 0; r < t.mat.Rows(); ++r) { - const float* HWY_RESTRICT row = mat_f.Row(r); - for (size_t c = 0; c < t.mat.Cols(); ++c) { - stats.Notify(row[c]); - } - } - printf("%-20s %12zu %13.10f %8.5f %13.10f\n", t.mat.Name(), - t.mat.Rows() * t.mat.Cols(), stats.Min(), stats.Mean(), - stats.Max()); - - total_weights += t.mat.Rows() * t.mat.Cols(); - }); - printf("%-20s %12zu\n", "Total", total_weights); -} - void WeightsOwner::Fixup(hwy::ThreadPool& pool) { PROFILER_ZONE("Startup.Fixup"); CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); }); diff --git a/gemma/weights.h b/gemma/weights.h index 7f6a12c..d045404 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -22,7 +22,6 @@ #include #include #include // NOLINT -#include #include #include @@ -298,23 +297,6 @@ struct LayerWeightsPtrs { }); } - void RandInit(float stddev, std::mt19937& gen) { - ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { - if (!t.mat.HasPtr()) return; - gcpp::RandInit(t.mat, stddev, gen); - }); - } - - // Allocates memory for all the tensors in the layer. Note that this is slow - // (non-parallel) and only used for a stand-alone layer. - void AllocateForTest(std::vector& mat_owners) { - ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { - // `backprop/` does not use row accessors and hence requires kPacked. - mat_owners.push_back(MatOwner()); - mat_owners.back().AllocateFor(t.mat, MatPadding::kPacked); - }); - } - // Must be called after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. @@ -330,8 +312,8 @@ struct LayerWeightsPtrs { // We only use this tensor for Gemma layers. if (layer_config.type != LayerAttentionType::kGemma) return; - // Files must have one or the other, and backprop/ allocates both. - HWY_ASSERT(attn_vec_einsum_w.HasPtr() || att_weights.HasPtr()); + // Files must have one or the other. + HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr()); // Done if we already read the transposed tensor. if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return; @@ -373,11 +355,10 @@ struct LayerWeightsPtrs { // Used for Gemma and Griffin layers; FFWVit uses different tensors. if (layer_config.type == LayerAttentionType::kVit) return; - // Files have both or neither of w1 and w2, and backprop/ allocates both. + // Files have both or neither of w1 and w2. HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); - // w is mutually exclusive with w1 and w2 in the file, but backprop/ - // allocates both, so we can only rule out both being null. - HWY_ASSERT(gating_einsum_w.HasPtr() || gating_einsum_w1.HasPtr()); + // w is mutually exclusive with w1 and w2 in the file. + HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr()); // Done if we already read split tensors. Note that they are not // necessarily the same type. if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return; @@ -397,7 +378,7 @@ struct LayerWeightsPtrs { gating_einsum_w2.SetType(gating_einsum_w.GetType()); gating_einsum_w1.SetScale(gating_einsum_w.Scale()); gating_einsum_w2.SetScale(gating_einsum_w.Scale()); - // Do not invalidate gating_einsum_w: backprop/ calls this repeatedly. + gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols()); } // For attention, which might not have a w2. Fast, only updates pointers. @@ -405,9 +386,8 @@ struct LayerWeightsPtrs { // We only use this tensor for Gemma layers. if (layer_config.type != LayerAttentionType::kGemma) return; - // w is mutually exclusive with w1 in the file, but backprop/ allocates - // both, so we can only rule out both being null. - HWY_ASSERT(qkv_einsum_w.HasPtr() || qkv_einsum_w1.HasPtr()); + // w is mutually exclusive with w1 in the file. + HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); // Done if we already read split tensors. Note that w2 does not exist for // MHA, and otherwise might not be the same type. if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return; @@ -436,7 +416,7 @@ struct LayerWeightsPtrs { qkv_einsum_w2.SetType(qkv_einsum_w.GetType()); qkv_einsum_w1.SetScale(qkv_einsum_w.Scale()); qkv_einsum_w2.SetScale(qkv_einsum_w.Scale()); - // Do not invalidate qkv_einsum_w: backprop/ calls this repeatedly. + qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); } }; @@ -567,13 +547,6 @@ struct ModelWeightsPtrs { }); } - void RandInit(float stddev, std::mt19937& gen) { - ForEachTensor(nullptr, nullptr, [stddev, &gen](const TensorArgs& t) { - if (!t.mat.HasPtr()) return; - gcpp::RandInit(t.mat, stddev, gen); - }); - } - // Copies only the allocated tensors in `*this` from tensors in `other`. void CopyFrom(const ModelWeightsPtrs& other) { ForEachTensor(const_cast*>(&other), nullptr, @@ -584,27 +557,6 @@ struct ModelWeightsPtrs { }); } - // Instead of reading, only allocates memory for all tensors. Used by - // `optimizer.cc` via the `Gemma` constructor without weights. - void AllocateForTest(std::vector& mat_owners, - hwy::ThreadPool& pool) { - // First get a list of all the tensors. - std::vector all_mat; - all_mat.reserve(10 * c_layers.size()); - ForEachTensor(nullptr, nullptr, [&all_mat](const TensorArgs& t) { - all_mat.push_back(&t.mat); - }); - - const size_t start = mat_owners.size(); - mat_owners.resize(start + all_mat.size()); - - // Allocate in parallel because faulting in large tensors is slow. - pool.Run(0, all_mat.size(), [&](uint64_t task, size_t /*thread*/) { - // `backprop/` does not use row accessors and hence requires kPacked. - mat_owners[start + task].AllocateFor(*all_mat[task], MatPadding::kPacked); - }); - } - // For reshaping file tensors to the shape expected by the code. This would // ideally already happen in the importer. Must be called after reading and // updating the attention weights. @@ -640,8 +592,6 @@ class WeightsOwner { return func(sfp_weights_, std::forward(args)...); } else if (weight_type_ == Type::kNUQ) { return func(nuq_weights_, std::forward(args)...); - } else if (weight_type_ == Type::kF32) { - return func(float_weights_, std::forward(args)...); } else if (weight_type_ == Type::kBF16) { return func(bf16_weights_, std::forward(args)...); } @@ -653,23 +603,6 @@ class WeightsOwner { // Adds one blob for each tensor's data and returns all serialized MatPtr. std::vector AddTensorDataToWriter(BlobWriter& writer) const; - // For backprop/: - - // Only allocates; must subsequently call `ZeroInit` or `RandInit`. - void AllocateForTest(const ModelConfig& config, hwy::ThreadPool& pool); - void ZeroInit(); - void RandInit(float stddev, std::mt19937& gen); // F32 or F64 only. - void LogWeightStatsF32(); - - ModelWeightsPtrs* GetF32() const { - HWY_ASSERT(weight_type_ == Type::kF32); - return float_weights_.get(); - } - - // Usually taken care of by `ReadFromBlobs`, but must also be called by - // `optimize_test, which updates the attention weights from which this copies. - void Fixup(hwy::ThreadPool& pool); - private: Type weight_type_; @@ -677,8 +610,10 @@ class WeightsOwner { // of `CallT` so that can be const. void AllocatePointer(const ModelConfig& config); + // Called by `ReadFromBlobs`. + void Fixup(hwy::ThreadPool& pool); + // Only one is non-null, determined by `weight_type_`. - std::unique_ptr> float_weights_; std::unique_ptr> bf16_weights_; std::unique_ptr> sfp_weights_; std::unique_ptr> nuq_weights_; diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 731bb01..e089e64 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1018,14 +1018,14 @@ struct TestShortDotsT { const PackedSpan v = vectors.Span(); MatStorageT bufs("bufs", num); - double* HWY_RESTRICT buf = bufs.Packed(); + double* HWY_RESTRICT buf = bufs.Row(0); for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) { - GenerateWellConditionedInputs(num, raw_w.Packed(), rng, w, work); - GenerateWellConditionedInputs(num, raw_v.Packed(), rng, v, work); + GenerateWellConditionedInputs(num, raw_w.Row(0), rng, w, work); + GenerateWellConditionedInputs(num, raw_v.Row(0), rng, v, work); const float dot_exact = - ExactDot(raw_w.Packed(), raw_v.Packed(), num, buf); + ExactDot(raw_w.PackedScale1(), raw_v.PackedScale1(), num, buf); float dots[kVariants]; for (size_t variant = 0; variant < kVariants; ++variant) { // Here Packed is not always float, so we must not call kDouble. diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 9308685..ea27c68 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -387,7 +387,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( */ // `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate. -// This overload is called from backprop/ and if kUseHalfRope. +// This overload is called if kUseHalfRope. static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( float* HWY_RESTRICT x, size_t dim_qkv, const float* HWY_RESTRICT inv_timescale, int pos) { diff --git a/ops/ops.h b/ops/ops.h index 19f6daf..91b4100 100644 --- a/ops/ops.h +++ b/ops/ops.h @@ -35,7 +35,7 @@ static inline HWY_MAYBE_UNUSED MatStorageT CreateInvTimescale( static_cast(2 * dim) / static_cast(rope_dim); // Replacing with expf(ln(1E4) * freq_exponents) changes results // noticeably. - inv_timescale.Packed()[dim] = + inv_timescale.Row(0)[dim] = static_cast(1.0 / std::pow(base_frequency, freq_exponents)); } return inv_timescale; diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 780eb61..2fd3236 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -399,7 +399,7 @@ void TestRopeAndMulBy() { auto random_float = [&r, &gen] { return r(gen); }; for (int i = 0; i < dim_qkv; ++i) { - x.Packed()[i] = random_float(); + x.Row(0)[i] = random_float(); } const float qmul = ChooseQueryScale(config); @@ -417,25 +417,25 @@ void TestRopeAndMulBy() { // Rope'd Q embeddings CopyMat(x, qactual); CopyMat(x, qexpected); - ScalarRopeAndMulBy(qmul, qexpected.Packed(), dim_qkv, - inv_timescale.Packed(), pos); - RopeAndMulBy(qmul, qactual.Packed(), dim_qkv, inv_timescale.Packed(), pos); + ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), + pos); + RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos); for (int i = 0; i < dim_qkv; ++i) { - EXPECT_NEAR(qactual.Packed()[i], qexpected.Packed()[i], 1e-4) - << "qIndex:" << i << "qInput:" << qactual.Packed()[i]; + EXPECT_NEAR(qactual.Row(0)[i], qexpected.Row(0)[i], 1e-4) + << "qIndex:" << i << "qInput:" << qactual.Row(0)[i]; } // Rope'd K embeddings CopyMat(x, kactual); CopyMat(x, kexpected); - ScalarRopeAndMulBy(kmul, kexpected.Packed(), dim_qkv, - inv_timescale.Packed(), pos); - RopeAndMulBy(kmul, kactual.Packed(), dim_qkv, inv_timescale.Packed(), pos); + ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0), + pos); + RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos); for (int i = 0; i < dim_qkv; ++i) { - EXPECT_NEAR(kactual.Packed()[i], kexpected.Packed()[i], 1e-4) - << "kIndex:" << i << "kInput:" << kactual.Packed()[i]; + EXPECT_NEAR(kactual.Row(0)[i], kexpected.Row(0)[i], 1e-4) + << "kIndex:" << i << "kInput:" << kactual.Row(0)[i]; } } } diff --git a/python/configs.cc b/python/configs.cc index db40bc0..2fa6252 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -53,9 +53,7 @@ PYBIND11_MODULE(configs, py_module) { .value("kF32", Type::kF32) .value("kBF16", Type::kBF16) .value("kSFP", Type::kSFP) - .value("kNUQ", Type::kNUQ) - .value("kF64", Type::kF64) - .value("kC64", Type::kC64); + .value("kNUQ", Type::kNUQ); enum_(py_module, "LayerAttentionType") .value("kGemma", LayerAttentionType::kGemma) diff --git a/util/mat.cc b/util/mat.cc index 28763ba..3344cad 100644 --- a/util/mat.cc +++ b/util/mat.cc @@ -18,8 +18,6 @@ #include #include -#include - #include "util/threading_context.h" #include "hwy/base.h" #include "hwy/per_target.h" // VectorBytes @@ -62,35 +60,6 @@ void ZeroInit(MatPtr& mat) { } } -void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) { - PROFILER_FUNC; - HWY_ASSERT_M(mat.HasPtr(), mat.Name()); - // Only generates float/double for use by backprop/. - HWY_ASSERT(mat.GetType() == Type::kF32 || mat.GetType() == Type::kF64); - mat.SetScale(1.0f); - - std::normal_distribution dist(0.0, stddev); - if (mat.GetType() == Type::kF32) { - MatPtrT mat_f(mat); - - for (size_t r = 0; r < mat.Rows(); ++r) { - float* HWY_RESTRICT row = mat_f.Row(r); - for (size_t c = 0; c < mat.Cols(); ++c) { - row[c] = dist(gen); - } - } - } else { - MatPtrT mat_d(mat); - - for (size_t r = 0; r < mat.Rows(); ++r) { - double* HWY_RESTRICT row = mat_d.Row(r); - for (size_t c = 0; c < mat.Cols(); ++c) { - row[c] = dist(gen); - } - } - } -} - size_t Stride(MatPadding padding, size_t cols, size_t element_bytes, size_t line_bytes) { switch (padding) { diff --git a/util/mat.h b/util/mat.h index 0314e20..a3de89b 100644 --- a/util/mat.h +++ b/util/mat.h @@ -20,7 +20,6 @@ #include #include -#include #include // IWYU pragma: begin_exports @@ -229,21 +228,11 @@ class MatPtrT : public MatPtr { MatPtrT(const MatPtrT& other) = default; MatPtrT& operator=(const MatPtrT& other) = default; - // Returns the entire tensor for use by `backprop/*`. Verifies layout is - // `kPacked`. Preferably call `Row` instead, which works for either layout. - MatT* Packed() { - HWY_DASSERT_M(IsPacked(), name_.c_str()); - return HWY_RCAST_ALIGNED(MatT*, ptr_); - } - const MatT* Packed() const { - HWY_DASSERT_M(IsPacked(), name_.c_str()); - return HWY_RCAST_ALIGNED(const MatT*, ptr_); - } - // As `Packed()`, plus checks the scale is 1.0 because callers will ignore it. - // This is typically used for `MatMul` bias vectors and norm weights. + // Returns the entire tensor after checking the scale is 1.0 because callers + // will ignore it. Used for `MatMul` bias vectors and norm weights. const MatT* PackedScale1() const { HWY_DASSERT(Scale() == 1.0f); - return Packed(); + return HWY_RCAST_ALIGNED(const MatT*, ptr_); } MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); } @@ -268,8 +257,7 @@ class MatPtrT : public MatPtr { }; // Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT`, plus the -// optional `args`. This supports all types used as weights, which excludes -// `kC64` and `kF64` (used only in `backprop/`). +// optional `args`. This supports all types used as weights. template decltype(auto) CallUpcasted(const MatPtr* base, const Func& func, Args&&... args) { @@ -334,8 +322,6 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func, void CopyMat(const MatPtr& from, MatPtr& to); void ZeroInit(MatPtr& mat); -// F32/F64 only. -void RandInit(MatPtr& mat, float stddev, std::mt19937& gen); // Our tensors are always row-major. This enum indicates how much (if any) // padding comes after each row. @@ -346,9 +332,6 @@ enum class MatPadding { // `BlobStore`) are not padded, which also extends to memory-mapped tensors. // However, `BlobStore` is able to insert padding via row-wise I/O when // reading from disk via `Mode::kRead`. - // - // `backprop/*` also requires this layout because it indexes directly into - // the storage instead of calling `Row()`. kPacked, // Enough to round up to an odd number of cache lines, which can reduce // cache conflict misses or 4K aliasing. @@ -378,9 +361,9 @@ class MatOwner { AlignedPtr storage_; }; -// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and -// tests to allocate and access tensors of a known type. By contrast, the -// heterogeneous model weights are owned by vectors of `MatOwner`. +// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by tests to allocate +// and access tensors of a known type. By contrast, the heterogeneous model +// weights are owned by vectors of `MatOwner`. template class MatStorageT : public MatPtrT { public: @@ -393,7 +376,7 @@ class MatStorageT : public MatPtrT { : MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {} ~MatStorageT() = default; - // Allow move for backprop/activations. + // Allow move for KVCache. MatStorageT(MatStorageT&&) = default; MatStorageT& operator=(MatStorageT&&) = default; @@ -401,13 +384,6 @@ class MatStorageT : public MatPtrT { MatOwner owner_; }; -// Helper factory function for use by `backprop/` to avoid specifying the -// `MatPadding` argument everywhere. -template -MatStorageT MakePacked(const char* name, size_t rows, size_t cols) { - return MatStorageT(name, Extents2D(rows, cols), MatPadding::kPacked); -} - // Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with // seekable (non-NUQ) T. #pragma pack(push, 1) // power of two size