mirror of https://github.com/google/gemma.cpp.git

Remove backprop/

Also remove MatPtrT::Packed(); use PackedScale1() instead where const, or Row(0).

PiperOrigin-RevId: 764243198

This commit is contained in:
parent 627cc04db9
commit 3890eb5412

BUILD.bazel (128 changed lines)
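For readers updating downstream code, a minimal sketch of the migration the message describes. The two helper functions are hypothetical; `PackedScale1()`, `Row()`, and `Cols()` are the `MatPtrT` accessors from `util/mat.h` that appear in the code removed below, and the reading of `PackedScale1()` as "const access valid when the per-tensor scale is 1" is an assumption based on the commit message:

#include <stddef.h>

#include "util/mat.h"  // MatPtrT

namespace gcpp {

// Hypothetical example: read-only access where the scale is known to be 1.
float SumRow0(const MatPtrT<float>& mat) {
  const float* row = mat.PackedScale1();  // was: mat.Packed()
  float sum = 0.0f;
  for (size_t c = 0; c < mat.Cols(); ++c) sum += row[c];
  return sum;
}

// Hypothetical example: mutable access via a row pointer.
void ZeroRow0(MatPtrT<float>& mat) {
  float* row = mat.Row(0);  // was: mat.Packed()
  for (size_t c = 0; c < mat.Cols(); ++c) row[c] = 0.0f;
}

}  // namespace gcpp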
BUILD.bazel
@@ -685,131 +685,3 @@ cc_binary(
        "@nlohmann_json//:json",
    ],
)

cc_library(
    name = "prompt",
    hdrs = ["backprop/prompt.h"],
    deps = [],
)

cc_library(
    name = "sampler",
    hdrs = ["backprop/sampler.h"],
    deps = [
        ":prompt",
    ],
)

cc_library(
    name = "backprop",
    srcs = [
        "backprop/backward.cc",
        "backprop/forward.cc",
    ],
    hdrs = [
        "backprop/activations.h",
        "backprop/backward.h",
        "backprop/forward.h",
    ],
    textual_hdrs = [
        "backprop/backward-inl.h",
        "backprop/forward-inl.h",
    ],
    deps = [
        ":allocator",
        ":common",
        ":configs",
        ":mat",
        ":ops",
        ":prompt",
        ":weights",
        "@highway//:dot",
        "@highway//:hwy",  # base.h
        "@highway//:thread_pool",
    ],
)

cc_library(
    name = "backprop_scalar",
    hdrs = [
        "backprop/activations.h",
        "backprop/backward_scalar.h",
        "backprop/common_scalar.h",
        "backprop/forward_scalar.h",
    ],
    deps = [
        ":common",
        ":configs",
        ":mat",
        ":prompt",
        ":weights",
        "@highway//:hwy",
    ],
)

cc_test(
    name = "backward_test",
    size = "large",
    srcs = [
        "backprop/backward_test.cc",
        "backprop/test_util.h",
    ],
    exec_properties = {
        # Avoid linker OOMs when building with sanitizer instrumentation.
        "mem": "28g",
    },
    deps = [
        ":backprop",
        ":backprop_scalar",
        ":configs",
        ":mat",
        ":ops",
        ":prompt",
        ":sampler",
        ":threading_context",
        ":weights",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:thread_pool",
    ],
)

cc_library(
    name = "optimizer",
    srcs = ["backprop/optimizer.cc"],
    hdrs = ["backprop/optimizer.h"],
    deps = [
        ":allocator",
        ":mat",
        ":weights",
        "@highway//:hwy",
        "@highway//:thread_pool",
    ],
)

cc_test(
    name = "optimize_test",
    srcs = ["backprop/optimize_test.cc"],
    exec_properties = {
        # Avoid linker OOMs when building with sanitizer instrumentation.
        "mem": "28g",
    },
    deps = [
        ":allocator",
        ":backprop",
        ":basics",
        ":configs",
        ":gemma_lib",
        ":ops",
        ":optimizer",
        ":prompt",
        ":sampler",
        ":threading",
        ":tokenizer",
        ":weights",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:types",
        "@highway//:thread_pool",
    ],
)
CMakeLists.txt
@@ -41,18 +41,6 @@ FetchContent_MakeAvailable(benchmark)

# Base source files
set(SOURCES
  backprop/activations.h
  backprop/backward_scalar.h
  backprop/backward-inl.h
  backprop/backward.cc
  backprop/backward.h
  backprop/common_scalar.h
  backprop/forward_scalar.h
  backprop/forward-inl.h
  backprop/forward.cc
  backprop/forward.h
  backprop/optimizer.cc
  backprop/optimizer.h
  compression/compress-inl.h
  compression/compress.cc
  compression/compress.h

@@ -202,8 +190,6 @@ enable_testing()
include(GoogleTest)

set(GEMMA_TEST_FILES
  backprop/backward_test.cc
  backprop/optimize_test.cc
  compression/compress_test.cc
  compression/distortion_test.cc
  compression/nuq_test.cc
README.md
@@ -173,9 +173,6 @@ more custom you can call transformer which performs a single inference operation
on a single token and mutates the Activations and the KVCache through the neural
network computation.

Note that an experimental backward pass is available in backprop/, which may be
useful for fine tuning.

### For low level operations, defining new architectures, call `ops.h` functions directly

You use `ops.h` if you're writing other NN architectures or modifying the
@ -1,87 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "util/mat.h" // MatStorageT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <typename T>
|
||||
struct ForwardLayer {
|
||||
ForwardLayer(const LayerConfig& config, size_t seq_len)
|
||||
: input(MakePacked<T>("input", seq_len, config.model_dim)),
|
||||
pre_att_rms_out(
|
||||
MakePacked<T>("pre_att_rms_out", seq_len, config.model_dim)),
|
||||
qkv(MakePacked<T>("qkv", seq_len * (config.heads + 2), config.qkv_dim)),
|
||||
att(MakePacked<T>("att", seq_len * config.heads, seq_len)),
|
||||
att_out(
|
||||
MakePacked<T>("att_out", seq_len * config.heads, config.qkv_dim)),
|
||||
att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
|
||||
attention_out(
|
||||
MakePacked<T>("attention_out", seq_len, config.model_dim)),
|
||||
pre_ffw_rms_out(
|
||||
MakePacked<T>("preFF_rms_out", seq_len, config.model_dim)),
|
||||
ffw_hidden(
|
||||
MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
|
||||
ffw_hidden_gated(
|
||||
MakePacked<T>("ffw_hidden_gated", seq_len, config.ff_hidden_dim)),
|
||||
layer_config(config) {}
|
||||
|
||||
MatStorageT<T> input;
|
||||
MatStorageT<T> pre_att_rms_out;
|
||||
MatStorageT<T> qkv;
|
||||
MatStorageT<T> att;
|
||||
MatStorageT<T> att_out;
|
||||
MatStorageT<T> att_post1;
|
||||
MatStorageT<T> attention_out;
|
||||
MatStorageT<T> pre_ffw_rms_out;
|
||||
MatStorageT<T> ffw_hidden;
|
||||
MatStorageT<T> ffw_hidden_gated;
|
||||
const LayerConfig& layer_config;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ForwardPass {
|
||||
ForwardPass(const ModelConfig& config)
|
||||
: final_layer_output(
|
||||
MakePacked<T>("fin_layer_out", config.seq_len, config.model_dim)),
|
||||
final_norm_output(
|
||||
MakePacked<T>("fin_norm_out", config.seq_len, config.model_dim)),
|
||||
logits(MakePacked<T>("logits", config.seq_len, config.vocab_size)),
|
||||
probs(MakePacked<T>("probs", config.seq_len, config.vocab_size)),
|
||||
weights_config(config) {
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
layers.emplace_back(layer_config, config.seq_len);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ForwardLayer<T>> layers;
|
||||
MatStorageT<T> final_layer_output;
|
||||
MatStorageT<T> final_norm_output;
|
||||
MatStorageT<T> logits;
|
||||
MatStorageT<T> probs;
|
||||
const ModelConfig& weights_config;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
|
|
@ -1,407 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Implementation of the Vector-Jacobian Products (VJP) of the individual
|
||||
// operations of the forward pass.
|
||||
|
||||
// Include guard for non-SIMD code.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/configs.h" // LayerConfig, ModelConfig
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "hwy/contrib/dot/dot-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
|
||||
const float* HWY_RESTRICT x, // num_tokens * kCols
|
||||
const float* HWY_RESTRICT v, // num_tokens * kRows
|
||||
size_t cols, size_t rows, size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, // kRows * kCols,
|
||||
float* HWY_RESTRICT grad_x, // num_tokens * kCols
|
||||
hwy::ThreadPool& pool) {
|
||||
hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t voffs = pos * rows;
|
||||
const size_t xoffs = pos * cols;
|
||||
for (size_t j = 0; j < rows; ++j) {
|
||||
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
|
||||
MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HWY_INLINE void MultiHeadMatMulVJP(
|
||||
const float* HWY_RESTRICT weights, // heads * kRows * kCols
|
||||
const float* HWY_RESTRICT x, // num_tokens * heads * kCols
|
||||
const float* HWY_RESTRICT v, // num_tokens * kRows
|
||||
size_t heads, size_t cols, size_t rows, size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, // heads * kRows * kCols
|
||||
float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols
|
||||
hwy::ThreadPool& pool) {
|
||||
hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
for (size_t j = 0; j < rows; ++j) {
|
||||
for (size_t h = 0; h < heads; ++h) {
|
||||
MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols],
|
||||
&grad_w[h * rows * cols + j * cols], cols);
|
||||
MulByConstAndAdd(v[pos * rows + j],
|
||||
&weights[h * rows * cols + j * cols],
|
||||
&grad_x[pos * heads * cols + h * cols], cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
|
||||
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
|
||||
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
||||
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
||||
const hn::Vec<D> kOne = hn::Set(d, 1.0f);
|
||||
// kSqrtOverPi*3*kMul
|
||||
const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);
|
||||
|
||||
const hn::Vec<D> v2 = hn::Mul(v, v);
|
||||
const hn::Vec<D> v3 = hn::Mul(v2, v);
|
||||
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
|
||||
const hn::Vec<D> tanh = hn::Tanh(d, arg);
|
||||
const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
|
||||
const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
|
||||
const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
|
||||
return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
|
||||
float* HWY_RESTRICT backward,
|
||||
const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
|
||||
const auto offset =
|
||||
hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
|
||||
hn::Transform1(
|
||||
d, backward, size, forward,
|
||||
[&offset](const auto d, const auto v, const auto y)
|
||||
HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP(
|
||||
const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
|
||||
const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
|
||||
hwy::ThreadPool& pool) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t offset = pos * model_dim;
|
||||
const float ss = detail::RMSNormMul(x + offset, model_dim);
|
||||
|
||||
for (size_t i = 0; i < model_dim; ++i) {
|
||||
grad_w[i] += v[offset + i] * x[offset + i] * ss;
|
||||
}
|
||||
const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
|
||||
float tmp = 0.0f;
|
||||
for (size_t i = 0; i < model_dim; ++i) {
|
||||
tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
|
||||
}
|
||||
tmp *= ss3;
|
||||
for (size_t i = 0; i < model_dim; ++i) {
|
||||
grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
|
||||
tmp * x[offset + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP(
|
||||
const float* weights, const std::vector<int>& prompt, const float scaling,
|
||||
const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) {
|
||||
HWY_ASSERT(!prompt.empty());
|
||||
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
|
||||
int token = prompt[pos];
|
||||
MulByConstAndAdd(scaling, v + pos * model_dim,
|
||||
grad + token * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
||||
const ForwardLayer<float>& forward,
|
||||
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
|
||||
LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
|
||||
const MatStorageT<float>& inv_timescale, hwy::ThreadPool& pool) {
|
||||
const LayerConfig& config = weights.layer_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t qkv_dim = config.qkv_dim;
|
||||
const size_t heads = config.heads;
|
||||
const size_t seq_len = forward.input.Rows();
|
||||
const size_t ff_hidden_dim = config.ff_hidden_dim;
|
||||
const float query_scale =
|
||||
static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
|
||||
HWY_ASSERT(num_tokens <= seq_len);
|
||||
|
||||
MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(),
|
||||
next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
|
||||
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t hidden_offset = pos * ff_hidden_dim * 2;
|
||||
const float* HWY_RESTRICT f_out =
|
||||
forward.ffw_hidden.Packed() + hidden_offset;
|
||||
const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
|
||||
const float* HWY_RESTRICT b_out_gated =
|
||||
backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim;
|
||||
float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset;
|
||||
float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
DF df;
|
||||
for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
|
||||
const auto y = Load(df, f_out + i);
|
||||
const auto x = Load(df, f_out_mul + i);
|
||||
const auto v = Load(df, b_out_gated + i);
|
||||
hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
|
||||
hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJP(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
|
||||
backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2,
|
||||
num_tokens, grad.gating_einsum_w.Packed(),
|
||||
backward.pre_ffw_rms_out.Packed(), pool);
|
||||
RMSNormVJP(weights.pre_ffw_norm_scale.Packed(),
|
||||
forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
|
||||
model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(),
|
||||
backward.attention_out.Packed(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(next_layer_grad + pos * model_dim,
|
||||
backward.attention_out.Packed() + pos * model_dim, model_dim);
|
||||
}
|
||||
|
||||
ZeroInit(backward.qkv);
|
||||
|
||||
MultiHeadMatMulVJP(
|
||||
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
|
||||
backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens,
|
||||
grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool);
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t aoffset = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
|
||||
const float* HWY_RESTRICT b_att_out =
|
||||
backward.att_out.Packed() + (pos * heads + head) * qkv_dim;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
|
||||
const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs;
|
||||
float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs;
|
||||
b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
|
||||
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t aoffset = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
|
||||
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
|
||||
const size_t aoffs = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs;
|
||||
const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs;
|
||||
float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
|
||||
const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs;
|
||||
float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs;
|
||||
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
|
||||
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
|
||||
float* HWY_RESTRICT b_kv =
|
||||
backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
|
||||
Rope(b_kv, qkv_dim, inv_timescale.PackedScale1(), -pos);
|
||||
}
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
float* HWY_RESTRICT b_q =
|
||||
backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
|
||||
MulByConst(query_scale, b_q, qkv_dim);
|
||||
Rope(b_q, qkv_dim, inv_timescale.PackedScale1(), -pos);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
|
||||
backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens,
|
||||
grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(),
|
||||
pool);
|
||||
RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
|
||||
backward.pre_att_rms_out.Packed(), model_dim, num_tokens,
|
||||
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
|
||||
pool);
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(backward.attention_out.Packed() + pos * model_dim,
|
||||
backward.input.Packed() + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void SoftcapVJP(const float cap,
|
||||
const float* HWY_RESTRICT forward,
|
||||
float* HWY_RESTRICT backward,
|
||||
const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
|
||||
const auto one = hn::Set(d, 1.0f);
|
||||
const auto vcap = hn::Set(d, cap);
|
||||
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
||||
hn::Transform1(d, backward, size, forward,
|
||||
[&](const auto d, const auto v, const auto y) HWY_ATTR {
|
||||
const auto scaled = hn::Mul(vinv_cap, y); // = tanh
|
||||
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
|
||||
});
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void CrossEntropyLossGrad(
|
||||
const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
|
||||
const Prompt& prompt, size_t vocab_size) {
|
||||
HWY_ASSERT(!prompt.tokens.empty());
|
||||
const float scaling = -1.0 / std::log(2.0);
|
||||
size_t num_tokens = prompt.tokens.size() - 1;
|
||||
hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
if (pos + 1 < prompt.context_size) {
|
||||
continue;
|
||||
}
|
||||
const int next_token = prompt.tokens[pos + 1];
|
||||
grad[pos * vocab_size + next_token] =
|
||||
scaling / x[pos * vocab_size + next_token];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<T>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
MatStorageT<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t kVocabSize = config.vocab_size;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t kLayers = config.layer_configs.size();
|
||||
const float kEmbScaling = EmbeddingScaling(model_dim);
|
||||
HWY_ASSERT(!config.absolute_pe);
|
||||
HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
|
||||
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
|
||||
|
||||
HWY_DASSERT(prompt.context_size > 0);
|
||||
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
|
||||
const size_t num_tokens = prompt.tokens.size() - 1;
|
||||
|
||||
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
|
||||
kVocabSize);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize,
|
||||
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
|
||||
}
|
||||
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize,
|
||||
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJP(weights.embedder_input_embedding.Packed(),
|
||||
forward.final_norm_output.Packed(), backward.logits.Packed(),
|
||||
model_dim, kVocabSize, num_tokens,
|
||||
grad.embedder_input_embedding.Packed(),
|
||||
backward.final_norm_output.Packed(), pool);
|
||||
|
||||
RMSNormVJP(weights.final_norm_scale.Packed(),
|
||||
forward.final_layer_output.Packed(),
|
||||
backward.final_norm_output.Packed(), model_dim, num_tokens,
|
||||
grad.final_norm_scale.Packed(),
|
||||
backward.final_layer_output.Packed(), pool);
|
||||
|
||||
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
|
||||
auto layer_config = config.layer_configs[layer];
|
||||
// TODO(szabadka) Implement Griffin layer vjp.
|
||||
HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
|
||||
float* next_layer_grad = layer + 1 < kLayers
|
||||
? backward.layers[layer + 1].input.Packed()
|
||||
: backward.final_layer_output.Packed();
|
||||
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
||||
num_tokens, *grad.GetLayer(layer), backward.layers[layer],
|
||||
inv_timescale, pool);
|
||||
}
|
||||
|
||||
InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens,
|
||||
kEmbScaling, backward.layers[0].input.Packed(),
|
||||
grad.embedder_input_embedding.Packed(), model_dim);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
|
||||
|
|
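For reference, the identity behind `SoftmaxVJP` in the file removed above: for $y = \operatorname{softmax}(x)$ the Jacobian is $\partial y_i/\partial x_j = y_i(\delta_{ij} - y_j)$, so the vector-Jacobian product with an incoming cotangent $v$ is

$$\dot{x}_j = \sum_i v_i\,\frac{\partial y_i}{\partial x_j} = y_j\Big(v_j - \sum_i y_i v_i\Big),$$

i.e. the single dot product (`offset`) followed by the elementwise `y * (v - offset)` in the code.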
@ -1,72 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "backprop/backward.h"
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "backprop/backward-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
void CrossEntropyLossBackwardPassT(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<float>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
MatStorageT<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
|
||||
inv_timescale, pool);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(CrossEntropyLossBackwardPassT);
|
||||
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<float>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
MatStorageT<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
|
||||
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<float>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
MatStorageT<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
|
|
@ -1,352 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/weights.h"
|
||||
|
||||
namespace gcpp {
|
||||
template<typename T>
|
||||
void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
|
||||
size_t N, size_t M, size_t K) {
|
||||
memset(dx, 0, M * K * sizeof(dx[0]));
|
||||
for (size_t i = 0; i < K; ++i) {
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M);
|
||||
MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M);
|
||||
}
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
|
||||
size_t H, size_t N, size_t M, size_t K) {
|
||||
memset(dx, 0, H * M * K * sizeof(dx[0]));
|
||||
for (size_t i = 0; i < K; ++i) {
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
for (size_t h = 0; h < H; ++h) {
|
||||
MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M],
|
||||
&dw[h * N * M + j * M], M);
|
||||
MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M],
|
||||
&dx[i * H * M + h * M], M);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
|
||||
size_t N, size_t K) {
|
||||
for (size_t i = 0; i < K; ++i) {
|
||||
constexpr T eps(1e-6);
|
||||
T ss = SquaredL2(x + i * N, N);
|
||||
ss = T(1.0) / std::sqrt(ss / T(N) + eps);
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
dw[j] += dy[i * N + j] * x[i * N + j] * ss;
|
||||
}
|
||||
const T ss3 = ss * ss * ss / T(N);
|
||||
T tmp = 0.0;
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j];
|
||||
}
|
||||
tmp *= ss3;
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
void SoftmaxVJPT(const T* y, T* dy, size_t N) {
|
||||
T sum = {};
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
sum += y[i] * dy[i];
|
||||
}
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
dy[i] = y[i] * (dy[i] - sum);
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
|
||||
for (size_t i = 0; i < K; ++i) {
|
||||
SoftmaxVJPT(y + i * N, dy + i * N, N);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T GeluDerivative(T x) {
|
||||
static const T kMul = 0.044715;
|
||||
static const T kSqrt2OverPi = 0.797884560804236;
|
||||
static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;
|
||||
|
||||
const T x2 = x * x;
|
||||
const T x3 = x2 * x;
|
||||
const T arg = kSqrt2OverPi * (kMul * x3 + x);
|
||||
const T tanh = std::tanh(arg);
|
||||
const T cdf = T(0.5) * (T(1.0) + tanh);
|
||||
const T dtanh = T(1.0) - tanh * tanh;
|
||||
const T darg = kMul2 * x2 + kSqrt2OverPi;
|
||||
return T(0.5) * x * dtanh * darg + cdf;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
|
||||
for (size_t i = 0; i < K; ++i) {
|
||||
const T* x1 = in + i * 2 * N;
|
||||
const T* x2 = x1 + N;
|
||||
const T* v = d_out + i * N;
|
||||
T* dx1 = d_in + i * 2 * N;
|
||||
T* dx2 = dx1 + N;
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]);
|
||||
dx2[j] = v[j] * Gelu(x1[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
|
||||
size_t num_tokens, size_t kHeads, size_t qkv_dim,
|
||||
size_t seq_len) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t offset = pos * (kHeads + 2) * qkv_dim;
|
||||
memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0]));
|
||||
}
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim;
|
||||
const size_t aoffs = head * seq_len + pos * kHeads * seq_len;
|
||||
const T* q = qkv + qoffs;
|
||||
const T* dout = doutput + aoffs;
|
||||
T* dq = dqkv + qoffs;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim;
|
||||
const T* k = qkv + koffs;
|
||||
T* dk = dqkv + koffs;
|
||||
MulByConstAndAddT(dout[pos2], k, dq, qkv_dim);
|
||||
MulByConstAndAddT(dout[pos2], q, dk, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads,
|
||||
size_t seq_len) {
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
size_t offset = pos * kHeads * seq_len + head * seq_len;
|
||||
SoftmaxVJPT(y + offset, dy + offset, pos + 1);
|
||||
memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
|
||||
T* dqkv, T* dattention, size_t num_tokens, size_t kHeads,
|
||||
size_t qkv_dim, size_t seq_len) {
|
||||
auto v_offset = [&](size_t pos) {
|
||||
return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim;
|
||||
};
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0]));
|
||||
}
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
|
||||
const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
|
||||
const T* att = &attention[aoffset];
|
||||
const T* dout = &doutput[offset];
|
||||
T* datt = &dattention[aoffset];
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
|
||||
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
|
||||
const T* dy, T* dw, size_t N) {
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
int token = tokens[i];
|
||||
MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
||||
const ForwardLayer<T>& forward, const T* dy,
|
||||
LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
|
||||
size_t num_tokens) {
|
||||
const LayerConfig& layer_config = weights.layer_config;
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
const size_t seq_len = forward.input.Rows();
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t kHeads = layer_config.heads;
|
||||
const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
|
||||
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
|
||||
|
||||
MatMulVJPT(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), dy,
|
||||
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(),
|
||||
model_dim, kFFHiddenDim, num_tokens);
|
||||
|
||||
GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
|
||||
backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);
|
||||
|
||||
MatMulVJPT(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
|
||||
backward.ffw_hidden.Packed(), grad.gating_einsum_w.Packed(),
|
||||
backward.pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
|
||||
num_tokens);
|
||||
|
||||
RMSNormVJPT(weights.pre_ffw_norm_scale.Packed(),
|
||||
forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
|
||||
grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(),
|
||||
model_dim, num_tokens);
|
||||
|
||||
AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);
|
||||
|
||||
MultiHeadMatMulVJPT(
|
||||
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
|
||||
backward.attention_out.Packed(), grad.attn_vec_einsum_w.Packed(),
|
||||
backward.att_out.Packed(), kHeads, model_dim, qkv_dim, num_tokens);
|
||||
|
||||
MixByAttentionVJP(forward.qkv.Packed(), forward.att.Packed(),
|
||||
backward.att_out.Packed(), backward.qkv.Packed(),
|
||||
backward.att.Packed(), num_tokens, kHeads, qkv_dim,
|
||||
seq_len);
|
||||
|
||||
MaskedSoftmaxVJPT(forward.att.Packed(), backward.att.Packed(), num_tokens,
|
||||
kHeads, seq_len);
|
||||
|
||||
MaskedAttentionVJP(forward.qkv.Packed(), backward.att.Packed(),
|
||||
backward.qkv.Packed(), num_tokens, kHeads, qkv_dim,
|
||||
seq_len);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
|
||||
MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
|
||||
}
|
||||
|
||||
for (int pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
|
||||
for (size_t h = 0; h <= kHeads; ++h) {
|
||||
Rope(qkv + h * qkv_dim, qkv_dim, -pos);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
|
||||
backward.qkv.Packed(), grad.qkv_einsum_w.Packed(),
|
||||
backward.pre_att_rms_out.Packed(), (kHeads + 2) * qkv_dim,
|
||||
model_dim, num_tokens);
|
||||
RMSNormVJPT(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
|
||||
backward.pre_att_rms_out.Packed(),
|
||||
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
|
||||
model_dim, num_tokens);
|
||||
|
||||
AddFromT(backward.attention_out.Packed(), backward.input.Packed(),
|
||||
num_tokens * model_dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
|
||||
const T inv_cap = T{1.0} / static_cast<T>(cap);
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
T scaled = y[i] * inv_cap; // tanh
|
||||
dy[i] *= (T{1.0} - scaled * scaled);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
|
||||
T scaling = -1.0 / std::log(2.0);
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
memset(dx, 0, V * num_tokens * sizeof(x[0]));
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
if (i + 1 < prompt.context_size) {
|
||||
continue;
|
||||
}
|
||||
const int next_token = tokens[i + 1];
|
||||
dx[i * V + next_token] = scaling / x[i * V + next_token];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const ForwardPass<T>& forward,
|
||||
ModelWeightsPtrs<T>& grad,
|
||||
ForwardPass<T>& backward) {
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t vocab_size = config.vocab_size;
|
||||
const size_t layers = config.layer_configs.size();
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
|
||||
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
|
||||
vocab_size);
|
||||
|
||||
SoftmaxVJPT(forward.probs.Packed(), backward.logits.Packed(), vocab_size,
|
||||
num_tokens);
|
||||
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
SoftcapVJPT(config.final_cap, forward.logits.Packed() + i * vocab_size,
|
||||
backward.logits.Packed() + i * vocab_size, vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(weights.embedder_input_embedding.Packed(),
|
||||
forward.final_norm_output.Packed(), backward.logits.Packed(),
|
||||
grad.embedder_input_embedding.Packed(),
|
||||
backward.final_norm_output.Packed(), vocab_size, model_dim,
|
||||
num_tokens);
|
||||
|
||||
RMSNormVJPT(
|
||||
weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(),
|
||||
backward.final_norm_output.Packed(), grad.final_norm_scale.Packed(),
|
||||
backward.final_layer_output.Packed(), model_dim, num_tokens);
|
||||
|
||||
for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
|
||||
T* next_layer_grad = layer + 1 < layers
|
||||
? backward.layers[layer + 1].input.Packed()
|
||||
: backward.final_layer_output.Packed();
|
||||
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
||||
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
|
||||
}
|
||||
|
||||
const T kEmbScaling = EmbeddingScaling(model_dim);
|
||||
InputEmbeddingVJPT(weights.embedder_input_embedding.Packed(), tokens,
|
||||
kEmbScaling, backward.layers[0].input.Packed(),
|
||||
grad.embedder_input_embedding.Packed(), model_dim);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
|
||||
|
|
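For reference, `GeluDerivative` in the file removed above differentiates the tanh approximation of GELU used by `Gelu`: with $u = \sqrt{2/\pi}\,(x + 0.044715\,x^3)$,

$$\operatorname{gelu}(x) = \frac{x}{2}\big(1+\tanh u\big), \qquad \operatorname{gelu}'(x) = \underbrace{\tfrac{1}{2}(1+\tanh u)}_{\texttt{cdf}} + \frac{x}{2}\,\underbrace{(1-\tanh^2 u)}_{\texttt{dtanh}}\,\underbrace{\sqrt{2/\pi}\,\big(1 + 3\cdot 0.044715\,x^2\big)}_{\texttt{darg}}.$$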
@ -1,250 +0,0 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <complex>
|
||||
#include <cstdlib> // std::abs
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h" // DotT
|
||||
#include "backprop/forward_scalar.h" // MatMulT
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "ops/ops.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
// After highway.h
|
||||
#include "backprop/backward-inl.h"
|
||||
#include "backprop/forward-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
// 'include guard' so we only define this once. Note that HWY_ONCE is only
|
||||
// defined during the last pass, but this is used in each pass.
|
||||
#ifndef BACKWARD_TEST_ONCE
|
||||
#define BACKWARD_TEST_ONCE
|
||||
// TestEndToEnd is slow, so only run it for the best-available target.
|
||||
static int run_once;
|
||||
#endif
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
hwy::ThreadPool& ThreadHostileGetPool() {
|
||||
// Assume this is only called at the top level, i.e. not in a thread. Then we
|
||||
// can safely call `SetArgs` only once, because it would assert otherwise.
|
||||
// This is preferable to calling `ThreadHostileInvalidate`, because we would
|
||||
// repeat the topology initialization for every test.
|
||||
if (!ThreadingContext::IsInitialized()) {
|
||||
gcpp::ThreadingArgs threading_args;
|
||||
threading_args.max_packages = 1;
|
||||
threading_args.max_clusters = 8;
|
||||
threading_args.pin = Tristate::kFalse;
|
||||
ThreadingContext::SetArgs(threading_args);
|
||||
}
|
||||
return ThreadingContext::Get().pools.Pool();
|
||||
}
|
||||
|
||||
void TestMatMulVJP() {
|
||||
static const size_t kRows = 8;
|
||||
static const size_t kCols = 64;
|
||||
static const size_t kTokens = 5;
|
||||
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
std::mt19937 gen(42);
|
||||
auto weights = MakePacked<float>("weights", kRows, kCols);
|
||||
auto x = MakePacked<float>("x", kTokens, kCols);
|
||||
auto dy = MakePacked<float>("dy", kTokens, kRows);
|
||||
auto grad = MakePacked<float>("grad", kRows, kCols);
|
||||
auto dx = MakePacked<float>("dx", kTokens, kCols);
|
||||
using TC = std::complex<double>;
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
RandInit(x, 1.0f * (1 << iter), gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
|
||||
kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
|
||||
ZeroInit(grad);
|
||||
MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens,
|
||||
grad.Packed(), dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
void TestMultiHeadMatMulVJP() {
|
||||
static const size_t kRows = 2;
|
||||
static const size_t kCols = 16;
|
||||
static const size_t kHeads = 4;
|
||||
static const size_t kTokens = 3;
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
std::mt19937 gen(42);
|
||||
auto weights = MakePacked<float>("weights", kRows, kCols * kHeads);
|
||||
auto x = MakePacked<float>("x", kTokens, kCols * kHeads);
|
||||
auto grad = MakePacked<float>("grad", kRows, kCols * kHeads);
|
||||
auto dx = MakePacked<float>("dx", kTokens, kCols * kHeads);
|
||||
auto dy = MakePacked<float>("dy", kTokens, kRows);
|
||||
using TC = std::complex<double>;
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
RandInit(x, 1.0f * (1 << iter), gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
|
||||
kRows, kCols, kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
|
||||
ZeroInit(grad);
|
||||
MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols,
|
||||
kRows, kTokens, grad.Packed(), dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
void TestRMSNormVJP() {
|
||||
static const size_t K = 2;
|
||||
static const size_t N = 64;
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
std::mt19937 gen(42);
|
||||
auto weights = MakePacked<float>("weights", N, 1);
|
||||
auto x = MakePacked<float>("x", K, N);
|
||||
auto grad = MakePacked<float>("grad", N, 1);
|
||||
auto dx = MakePacked<float>("dx", K, N);
|
||||
auto dy = MakePacked<float>("dy", K, N);
|
||||
using TC = std::complex<double>;
|
||||
auto c_weights = MakePacked<TC>("c_weights", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", K, N);
|
||||
auto c_y = MakePacked<TC>("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
RandInit(x, 1.0f * (1 << iter), gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
|
||||
return DotT(dy.Packed(), c_y.Packed(), K * N);
|
||||
};
|
||||
|
||||
ZeroInit(grad);
|
||||
RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(),
|
||||
dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
void TestEndToEnd() {
|
||||
if (++run_once > 1) return; // ~3 min on SKX, only run best available target
|
||||
|
||||
std::mt19937 gen(42);
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
ModelConfig config(Model::GEMMA_TINY, Type::kF32, PromptWrapping::GEMMA_IT);
|
||||
WeightsWrapper<float> weights(config);
|
||||
WeightsWrapper<float> grad(config);
|
||||
ForwardPass<float> forward0(config);
|
||||
ForwardPass<float> forward1(config);
|
||||
ForwardPass<float> backward(config);
|
||||
using TC = std::complex<double>;
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
|
||||
ReverseSequenceSampler training_task({0, 0, 1, 1});
|
||||
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
||||
|
||||
MatStorageT<float> inv_timescale = CreateInvTimescale(
|
||||
ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
|
||||
config.layer_configs[0].post_qk == PostQKType::HalfRope);
|
||||
for (const Prompt& prompt : batch) {
|
||||
ReverseSequenceSampler::LogPrompt(prompt);
|
||||
weights.get().RandInit(1.0f, gen);
|
||||
|
||||
float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
|
||||
|
||||
float loss1 = CrossEntropyLossForwardPass(
|
||||
prompt.tokens, prompt.context_size, weights.get(), forward1,
|
||||
inv_timescale, pool);
|
||||
|
||||
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
|
||||
|
||||
grad.get().ZeroInit();
|
||||
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
|
||||
backward, inv_timescale, pool);
|
||||
|
||||
Complexify(weights.get(), c_weights.get());
|
||||
auto func = [&]() {
|
||||
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
|
||||
};
|
||||
|
||||
TestGradient(grad.get(), c_weights.get(), func, 2e-3f, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(BackwardTest);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
|
||||
HWY_AFTER_TEST();
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif
|
||||
|
|
@@ -1,123 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_

#include <stddef.h>

#include <cmath>  // std::tanh, std::pow (used by Gelu/Rope below)
#include <complex>

#include "util/mat.h"

namespace gcpp {

template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
  U sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += a[i] * b[i];
  }
  return sum;
}

template<>
inline std::complex<double> DotT(const float* a, const std::complex<double>* b,
                                 size_t N) {
  std::complex<double> sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += static_cast<double>(a[i]) * b[i];
  }
  return sum;
}

template<typename T>
void MulByConstT(T c, T* x, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    x[i] *= c;
  }
}

// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    out[i] += c * x[i];
  }
}

template <typename T>
void MulByConstAndAddT(T c, const MatPtrT<T>& x, MatPtrT<T>& out) {
  for (size_t r = 0; r < x.Rows(); ++r) {
    MulByConstAndAddT(c, x.Row(r), out.Row(r), x.Cols());
  }
}

template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    out[i] += a[i];
  }
}

template<typename T>
T SquaredL2(const T* x, size_t N) {
  T sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += x[i] * x[i];
  }
  return sum;
}

template<typename T>
T Gelu(T x) {
  static const T kMul = 0.044715;
  static const T kSqrt2OverPi = 0.797884560804236;

  const T x3 = x * x * x;
  const T arg = kSqrt2OverPi * (kMul * x3 + x);
  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
  return x * cdf;
}

template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
  const size_t N2 = N / 2;
  for (size_t dim = 0; dim < N2; ++dim) {
    const T freq_exponents = T(2 * dim) / T(N);
    const T timescale = std::pow(base, freq_exponents);
    const T theta = T(i) / timescale;
    const T cos_val = std::cos(theta);
    const T sin_val = std::sin(theta);
    const T x0 = x[dim];
    const T x1 = x[dim + N2];
    x[dim] = x0 * cos_val - x1 * sin_val;
    x[dim + N2] = x0 * sin_val + x1 * cos_val;
  }
}

template<typename T>
void Rope(T* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}

template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
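(Editor's note on the header above: `Gelu` is the usual tanh approximation Gelu(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))); kSqrt2OverPi is √(2/π) ≈ 0.79788456. `Rope` rotates each pair (x[dim], x[dim + N/2]) by the angle θ = i / base^(2·dim/N), i.e. standard rotary position embedding with default base 10000, and is generic over T so the same code also runs on std::complex for the gradient checks further down. The `#include <cmath>` is an editorial addition so the header is self-contained.)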
@@ -1,297 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_

#include <stddef.h>
#include <stdint.h>

#include <cmath>
#include <vector>

#include "backprop/activations.h"
#include "gemma/common.h"  // EmbeddingScaling
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_

// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif

#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

template <typename T>
void InputEmbedding(const MatPtrT<T>& weights, const std::vector<int>& prompt,
                    const float scaling, float* HWY_RESTRICT output,
                    size_t model_dim, size_t vocab_size) {
  const hn::ScalableTag<float> df;
  HWY_ASSERT(!prompt.empty());
  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
    int token = prompt[pos];
    const auto span = weights.Span();
    HWY_ASSERT(span.num == model_dim * vocab_size);
    DecompressAndZeroPad(df, span, token * model_dim, output + pos * model_dim,
                         model_dim);
    MulByConst(scaling, output + pos * model_dim, model_dim);
  }
}

template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
                  size_t model_dim, size_t num_tokens,
                  OutT* HWY_RESTRICT output,
                  hwy::ThreadPool& pool) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t offset = pos * model_dim;
    RMSNorm(x + offset, weights, 0, output + offset, model_dim);
  }
}

static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
                                           const std::vector<int>& prompt,
                                           size_t context_size,
                                           size_t vocab_size,
                                           hwy::ThreadPool& pool) {
  HWY_ASSERT(!prompt.empty());
  float loss = 0.0f;
  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
    if (pos + 1 < context_size) {
      continue;  // next token is part of context, don't try to predict it
    }
    const int next_token = prompt[pos + 1];
    loss += std::log(probs[pos * vocab_size + next_token]);
  }
  float scaling = -1.0 / std::log(2.0);
  return loss * scaling;
}

template <typename T>
void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
                       ForwardLayer<float>& activations, size_t num_tokens,
                       float* HWY_RESTRICT output,
                       const MatStorageT<float>& inv_timescale,
                       hwy::ThreadPool& pool) {
  const LayerConfig& config = weights.layer_config;
  const size_t model_dim = config.model_dim;
  const size_t kSeqLen = activations.input.Rows();
  const size_t kQKVDim = config.qkv_dim;
  const size_t kHeads = config.heads;
  static const float query_scale =
      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
  HWY_ASSERT(num_tokens <= kSeqLen);

  ApplyRMSNorm(weights.pre_attention_norm_scale.Packed(),
               activations.input.Packed(), model_dim, num_tokens,
               activations.pre_att_rms_out.Packed(), pool);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
           activations.pre_att_rms_out.Packed() + pos * model_dim,
           activations.qkv.Packed() + pos * (kHeads + 2) * kQKVDim, pool);
  }
  const size_t num_tasks = kHeads * num_tokens;

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    float* HWY_RESTRICT k =
        activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
    Rope(k, kQKVDim, inv_timescale.PackedScale1(), pos);
  }
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT q =
        activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
    Rope(q, kQKVDim, inv_timescale.PackedScale1(), pos);
    MulByConst(query_scale, q, kQKVDim);
  });

  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT q =
        activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
    float* HWY_RESTRICT head_att =
        activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      const float* HWY_RESTRICT k2 =
          activations.qkv.Packed() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
      const float score = Dot(q, k2, kQKVDim);
      head_att[pos2] = score;
    }
  });

  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT head_att =
        activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
    Softmax(head_att, pos + 1);
  });

  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT head_att =
        activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
    float* HWY_RESTRICT att_out =
        activations.att_out.Packed() + (pos * kHeads + head) * kQKVDim;
    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      float* HWY_RESTRICT v2 = activations.qkv.Packed() +
                               (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
    }
  });

  ZeroInit(activations.attention_out);
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      MatVec(weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
             kQKVDim,
             activations.att_out.Packed() + pos * kHeads * kQKVDim +
                 head * kQKVDim,
             activations.att_post1.Packed() + pos * model_dim, pool);
      AddFrom(activations.att_post1.Packed() + pos * model_dim,
              activations.attention_out.Packed() + pos * model_dim, model_dim);
    }
  }

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.input.Packed() + pos * model_dim,
            activations.attention_out.Packed() + pos * model_dim, model_dim);
  }

  ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
               activations.attention_out.Packed(), model_dim, num_tokens,
               activations.pre_ffw_rms_out.Packed(), pool);
  const size_t kFFHiddenDim = config.ff_hidden_dim;
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
           activations.pre_ffw_rms_out.Packed() + pos * model_dim,
           activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
  }
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t hidden_offset = pos * kFFHiddenDim * 2;
    const float* HWY_RESTRICT out =
        activations.ffw_hidden.Packed() + hidden_offset;
    const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
    float* HWY_RESTRICT out_gated =
        activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim;
    namespace hn = hwy::HWY_NAMESPACE;
    using DF = hn::ScalableTag<float>;
    DF df;
    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
      const auto y = hn::Load(df, out + i);
      const auto x = hn::Load(df, out_mul + i);
      hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
    }
  }
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
           activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim,
           output + pos * model_dim, pool);
  }
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.attention_out.Packed() + pos * model_dim,
            output + pos * model_dim, model_dim);
  }
}

template <typename T>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
                                  size_t context_size,
                                  const ModelWeightsPtrs<T>& weights,
                                  ForwardPass<float>& forward,
                                  const MatStorageT<float>& inv_timescale,
                                  hwy::ThreadPool& pool) {
  const ModelConfig& config = weights.weights_config;
  const size_t vocab_size = config.vocab_size;
  const size_t model_dim = config.model_dim;
  const size_t layers = config.layer_configs.size();
  const float emb_scaling = EmbeddingScaling(model_dim);
  HWY_ASSERT(!config.absolute_pe);
  HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
  HWY_ASSERT(config.layer_configs[0].kv_heads == 1);

  HWY_DASSERT(context_size > 0);
  HWY_DASSERT(context_size < prompt.size());
  const size_t num_tokens = prompt.size() - 1;

  InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
                 forward.layers[0].input.Packed(), model_dim, vocab_size);

  for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
    auto type = config.layer_configs[layer].type;
    // TODO(szabadka) Implement Griffin layer.
    HWY_ASSERT(type == LayerAttentionType::kGemma);
    float* HWY_RESTRICT output = layer + 1 < layers
                                     ? forward.layers[layer + 1].input.Packed()
                                     : forward.final_layer_output.Packed();
    ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
                      num_tokens, output, inv_timescale, pool);
  }

  ApplyRMSNorm(weights.final_norm_scale.Packed(),
               forward.final_layer_output.Packed(), model_dim, num_tokens,
               forward.final_norm_output.Packed(), pool);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
           forward.final_norm_output.Packed() + pos * model_dim,
           forward.logits.Packed() + pos * vocab_size, pool);
  }

  if (config.final_cap > 0.0f) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      LogitsSoftCap(config.final_cap,
                    forward.logits.Packed() + pos * vocab_size, vocab_size);
    }
  }

  CopyMat(forward.logits, forward.probs);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    Softmax(forward.probs.Packed() + pos * vocab_size, vocab_size);
  }

  return CrossEntropyLoss(forward.probs.Packed(), prompt, context_size,
                          vocab_size, pool);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // NOLINT
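(Editor's note on `CrossEntropyLoss` above: the -1/ln(2) factor converts the summed natural-log likelihoods into bits, i.e. loss = -Σ_pos log2 p(next_token | prefix). As a worked example, a model that assigned probability 0.5 to every predicted token would score exactly 1 bit per token.)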
@@ -1,66 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "backprop/forward.h"

#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/forward.cc"  // NOLINT
#include "hwy/foreach_target.h"  // IWYU pragma: keep

#include "hwy/highway.h"
// After highway.h
#include "backprop/forward-inl.h"
#include "gemma/weights.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

float CrossEntropyLossForwardPassT(const Prompt& prompt,
                                   const ModelWeightsPtrs<float>& weights,
                                   ForwardPass<float>& forward,
                                   MatStorageT<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
  return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
                                     weights, forward, inv_timescale, pool);
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace gcpp {

HWY_EXPORT(CrossEntropyLossForwardPassT);

float CrossEntropyLossForwardPass(const Prompt& prompt,
                                  const ModelWeightsPtrs<float>& weights,
                                  ForwardPass<float>& forward,
                                  MatStorageT<float>& inv_timescale,
                                  hwy::ThreadPool& pool) {
  return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
      prompt, weights, forward, inv_timescale, pool);
}

}  // namespace gcpp
#endif  // HWY_ONCE
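(Editor's note: the file above is an instance of Highway's per-target compilation pattern. A minimal self-contained sketch of that pattern follows, with a hypothetical `SumT` op standing in for the forward pass; `demo`, `SumT` and `my_op.cc` are illustrative names only, not part of gemma.cpp:

#include <stddef.h>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "my_op.cc"  // this very file, recompiled per target
#include "hwy/foreach_target.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {
// Compiled once for every enabled SIMD target.
float SumT(const float* x, size_t n) {
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) sum += x[i];
  return sum;
}
}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // emitted only once, after all per-target passes
namespace demo {
HWY_EXPORT(SumT);  // builds the table of per-target function pointers
float Sum(const float* x, size_t n) {
  return HWY_DYNAMIC_DISPATCH(SumT)(x, n);  // best target chosen at runtime
}
}  // namespace demo
#endif  // HWY_ONCE

In the removed file, `CrossEntropyLossForwardPassT` plays the role of `SumT`, and the plain `CrossEntropyLossForwardPass` wrapper is the `Sum` analogue.)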
@@ -1,35 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_

#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

float CrossEntropyLossForwardPass(const Prompt& prompt,
                                  const ModelWeightsPtrs<float>& weights,
                                  ForwardPass<float>& forward,
                                  MatStorageT<float>& inv_timescale,
                                  hwy::ThreadPool& pool);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
@@ -1,298 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_

#include <stddef.h>
#include <string.h>

#include <cmath>
#include <complex>
#include <vector>

#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/common.h"  // EmbeddingScaling
#include "gemma/weights.h"
#include "hwy/base.h"

namespace gcpp {

// w is N x M matrix in row-major order, x is M x K matrix in column-major
// order, y = w * x is N x K matrix in column-major order.
template<typename T>
void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    for (size_t j = 0; j < N; ++j) {
      y[i * N + j] = DotT(&w[j * M], &x[i * M], M);
    }
  }
}

// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in
// column-major order and y = w' * x is N x K matrix in column-major order,
// where w' is the rearrangement of w into an N x HM matrix.
template<typename T>
void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
                     size_t M, size_t K) {
  memset(y, 0, N * K * sizeof(y[0]));
  for (size_t i = 0; i < K; ++i) {
    for (size_t h = 0; h < H; ++h) {
      for (size_t j = 0; j < N; ++j) {
        y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M);
      }
    }
  }
}

template<typename T>
void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
  constexpr T eps(1e-6);
  for (size_t i = 0; i < K; ++i) {
    T ss = SquaredL2(x + i * N, N);
    ss = T(1.0) / std::sqrt(ss / T(N) + eps);
    for (size_t j = 0; j < N; j++) {
      out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]);
    }
  }
}
template<typename T>
void Softmax(T* x, size_t N) {
  T sum = {};
  auto maxreal = std::real(x[0]);
  for (size_t i = 1; i < N; ++i) {
    if (std::real(x[i]) > maxreal) {
      maxreal = std::real(x[i]);
    }
  }
  for (size_t i = 0; i < N; ++i) {
    x[i] = std::exp(x[i] - maxreal);
    sum += x[i];
  }
  T scale = T(1.0) / sum;
  for (size_t i = 0; i < N; ++i) {
    x[i] *= scale;
  }
}
template<typename T>
void Softmax(T* x, size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    Softmax(x + i * N, N);
  }
}
template <typename T>
void Softcap(float cap, T* x, size_t N) {
  const T inv_cap = T{1.0} / static_cast<T>(cap);
  for (size_t i = 0; i < N; ++i) {
    x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
  }
}

template<typename T>
void GatedGelu(const T* in, T* out, size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    const T* x1 = in + i * 2 * N;
    const T* x2 = x1 + N;
    T* y = out + i * N;
    for (size_t j = 0; j < N; ++j) {
      y[j] = x2[j] * Gelu(x1[j]);
    }
  }
}

template<typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
                    T* y, size_t N) {
  HWY_ASSERT(w != nullptr);
  HWY_ASSERT(y != nullptr);
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  for (size_t i = 0; i < num_tokens; ++i) {
    int token = tokens[i];
    memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
    MulByConstT(scaling, y + i * N, N);
  }
}

template <typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads,
                     size_t qkv_dim, size_t seq_len) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < heads; ++head) {
      const size_t qoffset = pos * (heads + 2) * qkv_dim;
      const size_t aoffset = pos * heads * seq_len + head * seq_len;
      const T* q = qkv + qoffset + head * qkv_dim;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim;
        output[aoffset + pos2] = DotT(q, k, qkv_dim);
      }
    }
  }
}
template <typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < heads; ++head) {
      size_t offset = pos * heads * seq_len + head * seq_len;
      Softmax(x + offset, pos + 1);
      memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
    }
  }
}
template <typename T>
void MixByAttention(const T* qkv, const T* attention, T* output,
                    size_t num_tokens, size_t heads, size_t qkv_dim,
                    size_t seq_len) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < heads; ++head) {
      const T* att = &attention[pos * heads * seq_len + head * seq_len];
      T* out = &output[head * qkv_dim + pos * heads * qkv_dim];
      memset(out, 0, qkv_dim * sizeof(out[0]));
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
        const T* v = &qkv[v_offset];
        MulByConstAndAddT(att[pos2], v, out, qkv_dim);
      }
    }
  }
}
template <typename T>
void ApplyLayer(const LayerWeightsPtrs<T>& weights,
                ForwardLayer<T>& activations, size_t num_tokens, T* output) {
  const LayerConfig& layer_config = weights.layer_config;
  const size_t model_dim = layer_config.model_dim;
  const size_t seq_len = activations.input.Rows();
  const size_t qkv_dim = layer_config.qkv_dim;
  const size_t heads = layer_config.heads;
  const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
  static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));

  RMSNormT(weights.pre_attention_norm_scale.Packed(),
           activations.input.Packed(), activations.pre_att_rms_out.Packed(),
           model_dim, num_tokens);

  MatMulT(weights.qkv_einsum_w.Packed(), activations.pre_att_rms_out.Packed(),
          activations.qkv.Packed(), (heads + 2) * qkv_dim, model_dim,
          num_tokens);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
    for (size_t h = 0; h <= heads; ++h) {
      Rope(qkv + h * qkv_dim, qkv_dim, pos);
    }
  }

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
    MulByConstT(query_scale, qkv, heads * qkv_dim);
  }

  MaskedAttention(activations.qkv.Packed(), activations.att.Packed(),
                  num_tokens, heads, qkv_dim, seq_len);

  MaskedSoftmax(activations.att.Packed(), num_tokens, heads, seq_len);

  MixByAttention(activations.qkv.Packed(), activations.att.Packed(),
                 activations.att_out.Packed(), num_tokens, heads, qkv_dim,
                 seq_len);

  MultiHeadMatMul(weights.attn_vec_einsum_w.Packed(),
                  activations.att_out.Packed(),
                  activations.attention_out.Packed(), heads, model_dim, qkv_dim,
                  num_tokens);

  AddFromT(activations.input.Packed(), activations.attention_out.Packed(),
           num_tokens * model_dim);

  RMSNormT(weights.pre_ffw_norm_scale.Packed(),
           activations.attention_out.Packed(),
           activations.pre_ffw_rms_out.Packed(), model_dim, num_tokens);

  MatMulT(weights.gating_einsum_w.Packed(),
          activations.pre_ffw_rms_out.Packed(), activations.ffw_hidden.Packed(),
          ff_hidden_dim * 2, model_dim, num_tokens);

  GatedGelu(activations.ffw_hidden.Packed(),
            activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);

  MatMulT(weights.linear_w.Packed(), activations.ffw_hidden_gated.Packed(),
          output, model_dim, ff_hidden_dim, num_tokens);

  AddFromT(activations.attention_out.Packed(), output, num_tokens * model_dim);
}

template<typename T>
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
  T loss = {};
  const std::vector<int> tokens = prompt.tokens;
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  for (size_t i = 0; i < num_tokens; ++i) {
    if (i + 1 < prompt.context_size) {
      continue;  // next token is part of context, don't try to predict it
    }
    const int next_token = tokens[i + 1];
    loss += std::log(x[i * V + next_token]);
  }
  T scaling = -1.0 / std::log(2.0);
  return loss * scaling;
}

template <typename T>
T CrossEntropyLossForwardPass(const Prompt& prompt,
                              const ModelWeightsPtrs<T>& weights,
                              ForwardPass<T>& forward) {
  const ModelConfig& config = weights.weights_config;
  const size_t model_dim = config.model_dim;
  const size_t vocab_size = config.vocab_size;
  const size_t layers = config.layer_configs.size();
  const std::vector<int> tokens = prompt.tokens;
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;

  const T kEmbScaling = EmbeddingScaling(model_dim);
  InputEmbedding(weights.embedder_input_embedding.Packed(), tokens, kEmbScaling,
                 forward.layers[0].input.Packed(), model_dim);

  for (size_t layer = 0; layer < layers; ++layer) {
    T* output = layer + 1 < layers ? forward.layers[layer + 1].input.Packed()
                                   : forward.final_layer_output.Packed();
    ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
               output);
  }

  RMSNormT(weights.final_norm_scale.Packed(),
           forward.final_layer_output.Packed(),
           forward.final_norm_output.Packed(), model_dim, num_tokens);

  MatMulT(weights.embedder_input_embedding.Packed(),
          forward.final_norm_output.Packed(), forward.logits.Packed(),
          vocab_size, model_dim, num_tokens);

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    if (config.final_cap > 0.0f) {
      Softcap(config.final_cap, forward.logits.Packed() + pos * vocab_size,
              vocab_size);
    }
  }

  CopyMat(forward.logits, forward.probs);
  Softmax(forward.probs.Packed(), vocab_size, num_tokens);

  return CrossEntropyLoss(forward.probs.Packed(), prompt, vocab_size);
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
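(Editor's note on the layout convention stated above `MatMulT`: a minimal standalone check, with a hypothetical `MatMulRef` that restates the same contract — illustration only, not part of the diff:

#include <stddef.h>
#include <cstdio>

// Same contract as MatMulT above: w is N x M row-major, x is M x K
// column-major, y = w * x is N x K column-major.
template <typename T>
void MatMulRef(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    for (size_t j = 0; j < N; ++j) {
      T sum = {};
      for (size_t m = 0; m < M; ++m) sum += w[j * M + m] * x[i * M + m];
      y[i * N + j] = sum;
    }
  }
}

int main() {
  const float w[6] = {1, 2, 3, 4, 5, 6};  // 2x3, rows (1,2,3) and (4,5,6)
  const float x[3] = {1, 1, 1};           // a single column, K = 1
  float y[2];
  MatMulRef(w, x, y, /*N=*/2, /*M=*/3, /*K=*/1);
  printf("%g %g\n", y[0], y[1]);  // prints 6 15
  return 0;
}

Because w's rows and x's columns are both contiguous under this convention, the inner loop is a plain dot product, which is what `DotT` exploits.)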
@@ -1,154 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stddef.h>

#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "backprop/activations.h"
#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "compression/types.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "ops/ops.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

TEST(OptimizeTest, GradientDescent) {
  gcpp::ThreadingArgs threading_args;
  threading_args.max_packages = 1;
  threading_args.max_clusters = 1;
  threading_args.pin = Tristate::kFalse;
  ThreadingContext::SetArgs(threading_args);
  MatMulEnv env(ThreadingContext::Get());
  const Allocator& allocator = env.ctx.allocator;
  hwy::ThreadPool& pool = env.ctx.pools.Pool();
  std::mt19937 gen(42);

  ModelConfig config(Model::GEMMA_TINY, Type::kF32,
                     ChooseWrapping(Model::GEMMA_TINY));
  config.eos_id = ReverseSequenceSampler::kEndToken;

  WeightsOwner grad(Type::kF32), grad_m(Type::kF32), grad_v(Type::kF32);
  grad.AllocateForTest(config, pool);
  grad_m.AllocateForTest(config, pool);
  grad_v.AllocateForTest(config, pool);
  grad_m.ZeroInit();
  grad_v.ZeroInit();
  ForwardPass<float> forward(config), backward(config);
  KVCache kv_cache(config, /*prefill_tbatch_size=*/16);

  MatStorageT<float> inv_timescale = CreateInvTimescale(
      allocator, config.layer_configs[0].qkv_dim,
      config.layer_configs[0].post_qk == PostQKType::HalfRope);

  Gemma gemma(config, GemmaTokenizer(kMockTokenizer), env);

  const auto generate = [&](const std::vector<int>& prompt) {
    std::vector<int> reply;
    auto stream_token = [&reply](int token, float) {
      reply.push_back(token);
      return token != ReverseSequenceSampler::kEndToken;
    };
    RuntimeConfig runtime = {
        .max_generated_tokens = 16,
        .temperature = 1.0f,
        .gen = &gen,
        .verbosity = 0,
        .stream_token = stream_token,
    };
    TimingInfo timing_info;
    gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
    return reply;
  };

  // Sanity check of reply tokens.
  // 1) Its length should be greater than the prompt.
  // 2) The prompt should be a prefix of the reply.
  auto verify = [&](const Prompt& prompt) {
    const std::vector<int>& context = prompt.context();
    std::vector<int> reply = generate(context);
    if (reply.size() <= context.size()) return false;
    return std::equal(context.begin(), context.end(), reply.begin(),
                      reply.begin() + context.size());
  };

  gemma.MutableWeights().RandInit(1.0f, gen);
  gemma.MutableWeights().Fixup(pool);

  printf("Initial weights:\n");
  gemma.MutableWeights().LogWeightStatsF32();

  constexpr size_t kBatchSize = 8;
  constexpr float kAlpha = 0.001f;
  constexpr float kBeta1 = 0.9f;
  constexpr float kBeta2 = 0.999f;
  constexpr float kEpsilon = 1e-8f;

  constexpr float kMaxLoss = 20.0f;

  ReverseSequenceSampler training_task({
      0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
  size_t steps = 0;
  size_t num_ok;
  for (; steps < 1000; ++steps) {
    std::mt19937 sgen(42);
    grad.ZeroInit();
    float total_loss = 0.0f;
    num_ok = 0;
    for (size_t i = 0; i < kBatchSize; ++i) {
      Prompt prompt = training_task.Sample(sgen);
      total_loss += CrossEntropyLossForwardPass(
          prompt, *gemma.Weights().GetF32(), forward, inv_timescale, pool);
      CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward,
                                   *grad.GetF32(), backward, inv_timescale,
                                   pool);
      gemma.MutableWeights().Fixup(pool);
      num_ok += verify(prompt) ? 1 : 0;
    }
    total_loss /= kBatchSize;

    AdamUpdate(grad, kAlpha, kBeta1, kBeta2, kEpsilon, steps + 1,
               gemma.Weights(), grad_m, grad_v, pool);
    printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
           steps, total_loss, num_ok, kBatchSize);
    if (steps % 100 == 0) {
      printf("Batch gradient:\n");
      grad.LogWeightStatsF32();
    }
    if (total_loss < kMaxLoss) break;  // Done
  }
  printf("Num steps: %zu\n", steps);
  printf("Final weights:\n");
  gemma.MutableWeights().LogWeightStatsF32();
  EXPECT_LT(steps, 80);
  EXPECT_EQ(num_ok, kBatchSize);
}

}  // namespace gcpp
@@ -1,130 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "backprop/optimizer.h"

#include <cmath>

#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

namespace {

using MatPtrF = MatPtrT<float>;

// Split into two classes so that ForEachTensor only requires two "other"
// arguments. This is anyway useful for locality, because `grad` only feeds
// into `grad_m` and `grad_v` here.
class AdamUpdateMV {
 public:
  AdamUpdateMV(float beta1, float beta2, size_t t)
      : beta1_(beta1),
        beta2_(beta2),
        cbeta1_(1.0f - beta1),
        cbeta2_(1.0f - beta2),
        norm1_(1.0 / (1.0 - std::pow(beta1, t))),
        norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}

  void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) {
    for (size_t r = 0; r < grad.Rows(); ++r) {
      const float* HWY_RESTRICT g = grad.Row(r);
      float* HWY_RESTRICT m = grad_m.Row(r);
      float* HWY_RESTRICT v = grad_v.Row(r);
      for (size_t c = 0; c < grad.Cols(); ++c) {
        m[c] *= beta1_;
        m[c] += cbeta1_ * g[c];
        v[c] *= beta2_;
        v[c] += cbeta2_ * g[c] * g[c];
      }
    }
  }

 private:
  float beta1_;
  float beta2_;
  float cbeta1_;
  float cbeta2_;
  float norm1_;
  float norm2_;
};

// Updates `weights` based on the updated `grad_m` and `grad_v` from above.
class AdamUpdateW {
 public:
  AdamUpdateW(float alpha, float beta1, float beta2, float epsilon, size_t t)
      : alpha_(alpha),
        norm1_(1.0 / (1.0 - std::pow(beta1, t))),
        norm2_(1.0 / (1.0 - std::pow(beta2, t))),
        epsilon_(epsilon) {}

  void operator()(MatPtrF& weights, const MatPtrF& grad_m,
                  const MatPtrF& grad_v) {
    for (size_t r = 0; r < weights.Rows(); ++r) {
      float* HWY_RESTRICT w = weights.Row(r);
      const float* HWY_RESTRICT m = grad_m.Row(r);
      const float* HWY_RESTRICT v = grad_v.Row(r);
      for (size_t c = 0; c < weights.Cols(); ++c) {
        const float mhat = m[c] * norm1_;
        const float vhat = v[c] * norm2_;
        w[c] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
      }
    }
  }

 private:
  float alpha_;
  float norm1_;
  float norm2_;
  float epsilon_;
};

void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
                float beta2, float epsilon, size_t t,
                ModelWeightsPtrs<float>* weights,
                ModelWeightsPtrs<float>* grad_m,
                ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
  AdamUpdateMV update_mv(beta1, beta2, t);
  grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
    const MatPtrF grad_f(t.mat);
    MatPtrF grad_m_f(*t.other_mat1);
    MatPtrF grad_v_f(*t.other_mat2);
    update_mv(grad_f, grad_m_f, grad_v_f);
  });

  AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
  weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
    MatPtrF weights_f(t.mat);
    const MatPtrF grad_m_f(*t.other_mat1);
    const MatPtrF grad_v_f(*t.other_mat2);
    update_w(weights_f, grad_m_f, grad_v_f);
  });
}

}  // namespace

void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2,
                float epsilon, size_t t, const WeightsOwner& weights,
                const WeightsOwner& grad_m, const WeightsOwner& grad_v,
                hwy::ThreadPool& pool) {
  AdamUpdate(grad.GetF32(), alpha, beta1, beta2, epsilon, t, weights.GetF32(),
             grad_m.GetF32(), grad_v.GetF32(), pool);
}

}  // namespace gcpp
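(Editor's note: the `AdamUpdateMV`/`AdamUpdateW` split above is the textbook Adam step with bias correction. A per-coordinate scalar sketch, with a hypothetical `AdamStep` name and the removed test's hyperparameters as defaults — illustration only:

#include <cmath>
#include <cstddef>

// One Adam step for a single weight w with gradient g at timestep t >= 1.
void AdamStep(float g, float& m, float& v, float& w, size_t t,
              float alpha = 0.001f, float beta1 = 0.9f, float beta2 = 0.999f,
              float epsilon = 1e-8f) {
  m = beta1 * m + (1.0f - beta1) * g;      // first moment (AdamUpdateMV)
  v = beta2 * v + (1.0f - beta2) * g * g;  // second moment (AdamUpdateMV)
  const float mhat = m / (1.0f - std::pow(beta1, static_cast<float>(t)));
  const float vhat = v / (1.0f - std::pow(beta2, static_cast<float>(t)));
  w -= alpha * mhat / (std::sqrt(vhat) + epsilon);  // AdamUpdateW
}

Splitting the moment updates from the weight update, as the two classes above do, keeps `grad` traffic local to `grad_m`/`grad_v` and touches `weights` only once per step.)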
@@ -1,33 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_

#include <stddef.h>

#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

void AdamUpdate(const WeightsOwner& grad, float alpha, float beta1, float beta2,
                float epsilon, size_t t, const WeightsOwner& weights,
                const WeightsOwner& grad_m, const WeightsOwner& grad_v,
                hwy::ThreadPool& pool);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
@@ -1,34 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_

#include <stddef.h>
#include <vector>

namespace gcpp {

struct Prompt {
  std::vector<int> tokens;
  size_t context_size;
  std::vector<int> context() const {
    return std::vector<int>(tokens.begin(), tokens.begin() + context_size);
  }
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
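(Editor's note: `context_size` splits `tokens` into a conditioning prefix and the continuation to be learned. For example, with tokens = {3, 1, 10, 1, 3, 11} and context_size = 3, context() returns {3, 1, 10}, and the loss code above skips predictions whose target still lies inside that prefix.)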
@@ -1,90 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_

#include <stddef.h>
#include <stdio.h>

#include <random>
#include <vector>

#include "backprop/prompt.h"

namespace gcpp {

class PromptSampler {
 public:
  virtual Prompt Sample(std::mt19937& gen) = 0;
  virtual ~PromptSampler() = default;

  std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
    std::vector<Prompt> batch;
    batch.reserve(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
      batch.emplace_back(Sample(gen));
    }
    return batch;
  }
};

class ReverseSequenceSampler : public PromptSampler {
 public:
  explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
      : token_dist_(0, 9) {
    for (int i = 0; i < length_histo.size(); ++i) {
      const int count = length_histo[i];
      for (int j = 0; j < count; ++j) {
        length_lut_.push_back(i + 1);
      }
    }
    length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
  }
  virtual ~ReverseSequenceSampler() = default;

  static constexpr int kReverseToken = 10;
  static constexpr int kEndToken = 11;

  Prompt Sample(std::mt19937& gen) override {
    Prompt prompt;
    int len = length_lut_[length_dist_(gen)];
    prompt.tokens.resize(2 * len + 2);
    prompt.tokens[len] = kReverseToken;
    prompt.tokens[2 * len + 1] = kEndToken;
    for (size_t i = 0; i < len; ++i) {
      prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
    }
    prompt.context_size = len + 1;
    return prompt;
  }

  static void LogPrompt(const Prompt& prompt) {
    static const char* kVocab[] = {
        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
    };
    for (int token : prompt.tokens) printf("%s", kVocab[token]);
    printf(" [context_size: %zu]\n", prompt.context_size);
  }

 private:
  std::uniform_int_distribution<> token_dist_;
  std::uniform_int_distribution<> length_dist_;
  std::vector<int> length_lut_;
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
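(Editor's note: a worked `Sample` call for len = 3 drawing digits 4, 7, 2 produces tokens {4, 7, 2, 10, 2, 7, 4, 11}, which `LogPrompt` renders as "472-->274|", with context_size = 4; training therefore only predicts the reversed half plus the end token.)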
@@ -1,199 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_

#include <stddef.h>

#include <cmath>
#include <complex>
#include <vector>

#include "gtest/gtest.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

template <typename T, typename U>
void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
  for (size_t r = 0; r < x.Rows(); ++r) {
    const T* row = x.Row(r);
    std::complex<U>* c_row = c_x.Row(r);
    for (size_t c = 0; c < x.Cols(); ++c) {
      c_row[c] = std::complex<U>(row[c], 0.0);
    }
  }
}

template <typename T, typename U>
void Complexify(const LayerWeightsPtrs<T>& w, LayerWeightsPtrs<U>& c_w) {
  Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
  Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
  Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
  Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
  Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
  Complexify(w.linear_w, c_w.linear_w);
}

template <typename T, typename U>
void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
  const size_t kLayers = w.c_layers.size();
  Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
  Complexify(w.final_norm_scale, c_w.final_norm_scale);
  for (size_t i = 0; i < kLayers; ++i) {
    Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
  }
}

// Somewhat duplicates `WeightsOwner`, but that has neither double nor
// complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper {
 public:
  explicit WeightsWrapper(const ModelConfig& config) : weights_(config) {
    hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
    weights_.AllocateForTest(owners_, pool);
  }

  const ModelWeightsPtrs<T>& get() const { return weights_; }
  ModelWeightsPtrs<T>& get() { return weights_; }

 private:
  std::vector<MatOwner> owners_;
  ModelWeightsPtrs<T> weights_;
};

template <typename T, typename U>
void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
              double max_abs_err, double max_rel_err, int line_test,
              int line_util) {
  // TODO: consider compensated sum.
  double sum0 = 0;
  double sum1 = 0;
  double sum01 = 0;
  for (size_t r = 0; r < actual.Rows(); ++r) {
    const T* actual_row = actual.Row(r);
    const U* expected_row = expected.Row(r);
    for (size_t c = 0; c < actual.Cols(); ++c) {
      sum0 += actual_row[c] * actual_row[c];
      sum1 += expected_row[c] * expected_row[c];
      sum01 += actual_row[c] * expected_row[c];
      ASSERT_NEAR(
          actual_row[c], expected_row[c],
          std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err))
          << "test line " << line_test << " test_util.h line " << line_util
          << " r " << r << " c " << c;
    }
  }
  if (sum0 > 1e-16) {
    double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
    ASSERT_NEAR(norm_dot, 1.0, 3e-6)
        << "test line " << line_test << " test_util.h line " << line_util
        << " sum0: " << sum0 << " sum1: " << sum1 << " sum01: " << sum01;
  }
}

// Compute gradient with the finite difference method in the complex plane.
// If f : R->R is the tested function and F : C->C is its extension on the
// complex plane so that F is complex differentiable in x, then
//
//   F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
//
// which means that
//
//   F'(x) ~= Imag(F(x + ih)) / h
//
// This method is more numerically stable than the real-valued finite
// difference method, since we don't need to subtract floating point numbers
// that are near to each other.
template <typename FUNC, typename T, typename U>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
                  FUNC func, U step, T max_abs_err, T max_rel_err,
                  int line_test, int line_util) {
  MatStorageT<T> exp_grad = MakePacked<T>("exp_grad", x.Rows(), x.Cols());
  const U inv_step = 1.0 / step;
  for (size_t r = 0; r < x.Rows(); ++r) {
    std::complex<U>* x_row = x.Row(r);
    T* exp_row = exp_grad.Row(r);
    for (size_t c = 0; c < x.Cols(); ++c) {
      const U x0 = std::real(x_row[c]);
      const std::complex<U> x1 = std::complex<U>(x0, step);
      x_row[c] = x1;
      const std::complex<U> f1 = func();
      exp_row[c] = std::imag(f1) * inv_step;
      x_row[c] = x0;
    }
  }
  TestNear(grad, exp_grad, max_abs_err, max_rel_err, line_test, line_util);
}

template <typename FUNC>
void TestGradient(const MatPtrT<float>& grad, MatPtrT<std::complex<float>>& x,
                  FUNC func, float max_abs_err, float max_rel_error,
                  int line_test, int line_util) {
  TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line_test,
               line_util);
}

template <typename FUNC, typename T>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
                  FUNC func, T max_abs_err, T max_rel_error, int line_test,
                  int line_util) {
  TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line_test,
               line_util);
}

template <typename T, typename U, typename FUNC>
void TestGradient(const LayerWeightsPtrs<T>& grad,
                  LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err,
                  int line_test) {
  TestGradient(grad.pre_attention_norm_scale,
               c_weights.pre_attention_norm_scale, func, max_err, max_err,
               line_test, __LINE__);
  TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w, func,
               max_err, max_err, line_test, __LINE__);
  TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w, func, max_err,
               max_err, line_test, __LINE__);
  TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale, func,
               max_err, max_err, line_test, __LINE__);
  TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w, func, max_err,
               max_err, line_test, __LINE__);
  TestGradient(grad.linear_w, c_weights.linear_w, func, max_err, max_err,
               line_test, __LINE__);
}

template <typename T, typename U, typename FUNC>
void TestGradient(const ModelWeightsPtrs<T>& grad,
                  ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err,
                  int line_test) {
  TestGradient(grad.embedder_input_embedding,
               c_weights.embedder_input_embedding, func, 2 * max_err, max_err,
               line_test, __LINE__);
  TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, func, max_err,
               max_err, line_test, __LINE__);
  for (size_t i = 0; i < grad.c_layers.size(); ++i) {
    TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err,
                 line_test);
  }
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
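(Editor's note: the complex-step rule documented above is easy to verify in isolation. A standalone sketch for f(x) = x², whose derivative is 2x — illustration only, not part of the diff:

#include <complex>
#include <cstdio>

int main() {
  const double x = 3.0, h = 1e-30;  // h can be tiny: no cancellation occurs
  const std::complex<double> fx =
      std::complex<double>(x, h) * std::complex<double>(x, h);  // F(x + ih)
  printf("derivative ~= %.17g (exact: %g)\n", fx.imag() / h, 2.0 * x);
  return 0;
}

Unlike (f(x+h) - f(x)) / h, nothing is subtracted from a nearly equal number, which is why `TestGradient` can use steps as small as 1e-30 and 1e-50.)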
@@ -65,8 +65,8 @@ static constexpr bool kIsTest = false;
 template <typename T>  // primary, must specialize
 struct CompressTraits {};
 
-// Used by backprop/, where weights are currently f32; also MatMul for f32
-// weights or activations, if native `ReorderWidenMulAccumulate` is available.
+// Used by MatMul for f32 weights or activations, if native
+// `ReorderWidenMulAccumulate` is available.
 template <>
 struct CompressTraits<float> {
   using Packed = float;
@@ -81,7 +81,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) {
     }
   });
 
-  Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
+  Compress(raw.PackedScale1(), raw.Extents().Area(), ws, compressed.Span(),
            /*packed_ofs=*/0, pool);
   compressed.SetScale(0.6f);  // Arbitrary value, different from 1.
   return compressed;
@@ -104,7 +104,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
     }
   });
 
-  Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
+  Compress(raw.PackedScale1(), raw.Extents().Area(), ws, compressed.Span(),
            /*packed_ofs=*/0, pool);
   // Arbitrary value, different from 1, must match `GenerateMat`.
   compressed.SetScale(0.6f);
@@ -21,9 +21,6 @@
 #include <stddef.h>
 #include <stdint.h>
 
-#include <complex>
-#include <cstdio>
-
 // IWYU pragma: begin_exports
 #include "util/basics.h"  // BF16
 #include "hwy/aligned_allocator.h"
@@ -160,24 +157,22 @@ constexpr bool IsNuqStream() {
   return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
 }
 
-// Tensor types for loading weights. Note that not all types are supported as
-// weights for a model, but can be used for other purposes, such as types for
-// `WeightsPtrs`. When adding a new type that is supported, also
-// update gemma.cc, weights.*, and add instantiations/new_one.cc.
-enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64 };
+// Tensor types for loading weights.
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
 // These are used in `ModelConfig.Specifier`, hence the strings will not
 // change, though new ones may be added.
-static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
-                                               "nuq", "f64", "c64"};
+static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
+                                               "sfp", "nuq", "f64"};
 static constexpr size_t kNumTypes =
     sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
-static constexpr size_t kTypeBits[] = {0,
-                                       8 * sizeof(float),
-                                       8 * sizeof(BF16),
-                                       8 * sizeof(SfpStream),
-                                       4 /* NuqStream, actually 4.5 */,
-                                       8 * sizeof(double),
-                                       8 * sizeof(std::complex<double>)};
+static constexpr size_t kTypeBits[] = {
+    0,
+    8 * sizeof(float),
+    8 * sizeof(BF16),
+    8 * sizeof(SfpStream),
+    4 /* NuqStream, actually 4.5 */,
+    8 * sizeof(double),
+};
 
 static inline bool EnumValid(Type type) {
   return static_cast<size_t>(type) < kNumTypes;
@@ -197,8 +192,6 @@ Type TypeEnum() {
     return Type::kNUQ;
   } else if constexpr (hwy::IsSame<Packed, double>()) {
     return Type::kF64;
-  } else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
-    return Type::kC64;
   } else {
     HWY_DASSERT(false);
     return Type::kUnknown;
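A standalone sketch of what the `kTypeBits` table buys: packed byte counts per type. The constants below are our restatement of the table above, assuming the usual 4-byte `float`, 2-byte `BF16`, 1-byte `SfpStream`, and the 4-bit NUQ approximation noted in the comment; they are not the repo's definitions.

// Standalone sketch: bytes needed for n elements of each tensor type.
#include <cstddef>
#include <cstdio>

// Bits per element, mirroring kTypeBits (indices follow the Type enum).
static constexpr size_t kBits[] = {0, 32, 16, 8, 4, 64};
static constexpr const char* kNames[] = {"unknown", "f32", "bf16",
                                         "sfp", "nuq", "f64"};

int main() {
  const size_t n = 1024;  // elements
  for (size_t i = 1; i < sizeof(kBits) / sizeof(kBits[0]); ++i) {
    std::printf("%-8s %zu bytes\n", kNames[i], kBits[i] * n / 8);
  }
  return 0;
}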
@@ -29,7 +29,6 @@ namespace gcpp {
 void Wrap(const ModelConfig& config, size_t pos, std::string& prompt);

 // Returns the scale value to use for the embedding (basically sqrt model_dim).
-// Also used by backprop/.
 float EmbeddingScaling(size_t model_dim);

 // Returns the scale value to use for the query in the attention computation.
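Per the comment, a minimal sketch of the scaling it describes (our own standalone restatement; the repo's definition may differ in rounding or casting):

#include <cmath>
#include <cstddef>

// Embeddings are multiplied by roughly sqrt(model_dim), per the comment above.
float EmbeddingScalingSketch(size_t model_dim) {
  return std::sqrt(static_cast<float>(model_dim));
}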
@@ -152,13 +152,12 @@ static ModelConfig ConfigGemmaTiny() {
   config.wrapping = PromptWrapping::GEMMA_IT;
   config.model_dim = 32;
   config.vocab_size = 32;  // at least two f32 vectors
-  config.seq_len = 32;  // optimize_test requires more than 24
+  config.seq_len = 32;
   LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
   config.num_layers = 2;
   config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtKeySize;
   config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
-  // This is required for optimize_test to pass.
   config.att_cap = 50.0f;
   config.final_cap = 30.0f;
   config.eos_id = 11;
@@ -203,7 +202,6 @@ static ModelConfig ConfigGriffin2B() {
   }
   config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
   config.use_local_attention = true;
-  // This is required for optimize_test to pass.
   config.final_cap = 0.0f;
   return config;
 }
@@ -162,7 +162,7 @@ enum class Model {
   GEMMA2_9B = 3,
   GEMMA2_27B,
   GRIFFIN_2B,
-  GEMMA_TINY,  // for backprop/ only
+  GEMMA_TINY,  // for testing only
   GEMMA2_2B,
   // 8 and 9 are obsolete.
   PALIGEMMA2_3B_224 = 10,
@@ -330,9 +330,8 @@ struct ModelConfig : public IFields {
   // from a blob. Also used by `config_converter.py`, which sets sufficient
   // fields for `TestEqual` and then calls `OverwriteWithCanonical()`.
   ModelConfig() = default;
-  // For use by `backprop/`, and `model_store.cc` for pre-2025 format after
-  // deducing the model from tensors plus a user-specified `wrapping` override
-  // (see `ChooseWrapping`).
+  // For use by `model_store.cc` for pre-2025 format after deducing the model
+  // from tensors plus a user-specified `wrapping` override (`ChooseWrapping`).
   ModelConfig(Model model, Type weight, PromptWrapping wrapping);
   // Parses a string returned by `Specifier()`. Used by the exporter to select
   // the model from command line arguments. Do not use this elsewhere - the
@@ -230,13 +230,13 @@ class GemmaAttention {
                 const float mul) {
     // qk is either q or k, so qkv_dim is the length we operate on.
     const size_t qkv_dim = layer_config_.qkv_dim;
-    const float* inv_timescale = activations_.inv_timescale.Packed();
+    const float* inv_timescale = activations_.inv_timescale.PackedScale1();
     bool is_global_layer =
         activations_.weights_config.attention_window_sizes[layer] ==
         activations_.seq_len;
     // TODO: add a config flag instead of hardcoding the model.
     if (is_global_layer && IsVLM(activations_.weights_config.model)) {
-      inv_timescale = activations_.inv_timescale_global.Packed();
+      inv_timescale = activations_.inv_timescale_global.PackedScale1();
     }
     // PostQKType::Rope
     (void)layer;
@@ -24,7 +24,6 @@
 #include <string.h>

 #include <memory>
-#include <utility>  // std::move
 #include <vector>

 // Placeholder for internal header, do not modify.
@@ -61,16 +60,6 @@ Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
   reader_.reset();
 }

-Gemma::Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer,
-             MatMulEnv& env)
-    : env_(env),
-      model_(config, std::move(tokenizer)),
-      weights_(config.weight),
-      chat_template_(model_.Tokenizer(), model_.Config().model) {
-  HWY_ASSERT(config.weight == Type::kF32);
-  weights_.AllocateForTest(config, env_.ctx.pools.Pool(0));
-}
-
 Gemma::~Gemma() = default;

 void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
@@ -107,11 +107,6 @@ class Gemma {
   // `env` must remain valid for the lifetime of this Gemma.
   Gemma(const LoaderArgs& loader, MatMulEnv& env);

-  // Only allocates weights, caller is responsible for filling them. Only used
-  // by `optimize_test.cc`.
-  // `env` must remain valid for the lifetime of this Gemma.
-  Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer, MatMulEnv& env);
-
   ~Gemma();

   MatMulEnv& Env() const { return env_; }
@@ -53,10 +53,7 @@ class ModelStore {
   // Reads from file(s) or aborts on error. The latter two arguments are only
   // used for pre-2025 files.
   ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
              Tristate wrapping = Tristate::kDefault);
-  // For optimize_test.cc.
-  ModelStore(const ModelConfig& config, GemmaTokenizer&& tokenizer)
-      : config_(config), tokenizer_(std::move(tokenizer)) {}
   ~ModelStore();

   const ModelConfig& Config() const {
@@ -60,8 +60,7 @@ class GemmaTokenizer {

 class GemmaChatTemplate {
  public:
-  // No effect if `tokenizer` is unavailable (as happens in optimize_test.cc),
-  // but then any other method may abort.
+  // No effect if `tokenizer` is unavailable, but any other method may abort.
   GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model);

   // Given prompt tokens, this returns the wrapped prompt including BOS and
@@ -21,7 +21,6 @@
 #include <stdlib.h>

 #include <memory>
-#include <random>
 #include <string>
 #include <vector>
@@ -37,7 +36,6 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
-#include "hwy/stats.h"

 // TODO: move into foreach_target; this is only used for NUQ Fixup.
 #include "compression/compress-inl.h"
@@ -297,9 +295,6 @@ void WeightsOwner::AllocatePointer(const ModelConfig& config) {
     case Type::kNUQ:
       nuq_weights_.reset(new ModelWeightsPtrs<NuqStream>(config));
       break;
     case Type::kF32:
       float_weights_.reset(new ModelWeightsPtrs<float>(config));
       break;
     case Type::kBF16:
       bf16_weights_.reset(new ModelWeightsPtrs<BF16>(config));
       break;
@@ -308,55 +303,6 @@ void WeightsOwner::AllocatePointer(const ModelConfig& config) {
   }
 }

-// Gemma calls `WeightsOwner::ReadOrAllocate`, but test code instead calls
-// `WeightsPtrs::AllocateForTest`, so the implementation is there, and here
-// we only type-dispatch.
-void WeightsOwner::AllocateForTest(const ModelConfig& config,
-                                   hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Startup.AllocateWeights");
-
-  AllocatePointer(config);
-  CallT([&](const auto& weights) {
-    weights->AllocateForTest(mat_owners_, pool);
-  });
-}
-
-void WeightsOwner::ZeroInit() {
-  PROFILER_FUNC;
-  CallT([](const auto& weights) { weights->ZeroInit(); });
-}
-
-void WeightsOwner::RandInit(float stddev, std::mt19937& gen) {
-  PROFILER_FUNC;
-  float_weights_->RandInit(stddev, gen);
-}
-
-void WeightsOwner::LogWeightStatsF32() {
-  size_t total_weights = 0;
-  HWY_ASSERT(weight_type_ == Type::kF32);  // Only for float weights.
-  float_weights_->ForEachTensor(
-      nullptr, nullptr, [&total_weights](const TensorArgs& t) {
-        if (!t.mat.HasPtr()) return;
-        if (t.mat.Scale() != 1.0f) {
-          printf("[scale=%f] ", t.mat.Scale());
-        }
-        hwy::Stats stats;
-        const MatPtrT<float> mat_f(t.mat);
-        for (size_t r = 0; r < t.mat.Rows(); ++r) {
-          const float* HWY_RESTRICT row = mat_f.Row(r);
-          for (size_t c = 0; c < t.mat.Cols(); ++c) {
-            stats.Notify(row[c]);
-          }
-        }
-        printf("%-20s %12zu %13.10f %8.5f %13.10f\n", t.mat.Name(),
-               t.mat.Rows() * t.mat.Cols(), stats.Min(), stats.Mean(),
-               stats.Max());
-
-        total_weights += t.mat.Rows() * t.mat.Cols();
-      });
-  printf("%-20s %12zu\n", "Total", total_weights);
-}
-
 void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
   PROFILER_ZONE("Startup.Fixup");
   CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });
@@ -22,7 +22,6 @@
-#include <complex>
 #include <memory>
 #include <mutex>  // NOLINT
 #include <random>
 #include <string>
 #include <vector>
@@ -298,23 +297,6 @@ struct LayerWeightsPtrs {
     });
   }

-  void RandInit(float stddev, std::mt19937& gen) {
-    ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      if (!t.mat.HasPtr()) return;
-      gcpp::RandInit(t.mat, stddev, gen);
-    });
-  }
-
-  // Allocates memory for all the tensors in the layer. Note that this is slow
-  // (non-parallel) and only used for a stand-alone layer.
-  void AllocateForTest(std::vector<MatOwner>& mat_owners) {
-    ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      // `backprop/` does not use row accessors and hence requires kPacked.
-      mat_owners.push_back(MatOwner());
-      mat_owners.back().AllocateFor(t.mat, MatPadding::kPacked);
-    });
-  }
-
   // Must be called after reading weights via `ForEachTensor`.
   // TODO: exporters should bake this into the weights already.
   // WARNING: called from multiple threads; `mat_owners` requires a lock.
@@ -330,8 +312,8 @@ struct LayerWeightsPtrs {
     // We only use this tensor for Gemma layers.
     if (layer_config.type != LayerAttentionType::kGemma) return;

-    // Files must have one or the other, and backprop/ allocates both.
-    HWY_ASSERT(attn_vec_einsum_w.HasPtr() || att_weights.HasPtr());
+    // Files must have one or the other.
+    HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr());
     // Done if we already read the transposed tensor.
    if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;
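The `||` to `^` tightening here, and in the two hunks below, changes "at least one present" into "exactly one present", which is only valid now that backprop/ no longer allocates both tensors. A trivial standalone illustration (names are ours):

#include <cassert>

// "At least one" vs "exactly one" source tensor.
int main() {
  const bool file_has_transposed = true;
  const bool file_has_original = false;
  assert(file_has_transposed || file_has_original);  // old: both allowed
  assert(file_has_transposed ^ file_has_original);   // new: exactly one
  return 0;
}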
@@ -373,11 +355,10 @@ struct LayerWeightsPtrs {
     // Used for Gemma and Griffin layers; FFWVit uses different tensors.
     if (layer_config.type == LayerAttentionType::kVit) return;

-    // Files have both or neither of w1 and w2, and backprop/ allocates both.
+    // Files have both or neither of w1 and w2.
     HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
-    // w is mutually exclusive with w1 and w2 in the file, but backprop/
-    // allocates both, so we can only rule out both being null.
-    HWY_ASSERT(gating_einsum_w.HasPtr() || gating_einsum_w1.HasPtr());
+    // w is mutually exclusive with w1 and w2 in the file.
+    HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
     // Done if we already read split tensors. Note that they are not
     // necessarily the same type.
     if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
@@ -397,7 +378,7 @@ struct LayerWeightsPtrs {
     gating_einsum_w2.SetType(gating_einsum_w.GetType());
     gating_einsum_w1.SetScale(gating_einsum_w.Scale());
     gating_einsum_w2.SetScale(gating_einsum_w.Scale());
-    // Do not invalidate gating_einsum_w: backprop/ calls this repeatedly.
+    gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
   }

   // For attention, which might not have a w2. Fast, only updates pointers.
@@ -405,9 +386,8 @@ struct LayerWeightsPtrs {
     // We only use this tensor for Gemma layers.
     if (layer_config.type != LayerAttentionType::kGemma) return;

-    // w is mutually exclusive with w1 in the file, but backprop/ allocates
-    // both, so we can only rule out both being null.
-    HWY_ASSERT(qkv_einsum_w.HasPtr() || qkv_einsum_w1.HasPtr());
+    // w is mutually exclusive with w1 in the file.
+    HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
     // Done if we already read split tensors. Note that w2 does not exist for
     // MHA, and otherwise might not be the same type.
     if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
@@ -436,7 +416,7 @@ struct LayerWeightsPtrs {
     qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
     qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
     qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
-    // Do not invalidate qkv_einsum_w: backprop/ calls this repeatedly.
+    qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
   }
 };
@@ -567,13 +547,6 @@ struct ModelWeightsPtrs {
     });
   }

-  void RandInit(float stddev, std::mt19937& gen) {
-    ForEachTensor(nullptr, nullptr, [stddev, &gen](const TensorArgs& t) {
-      if (!t.mat.HasPtr()) return;
-      gcpp::RandInit(t.mat, stddev, gen);
-    });
-  }
-
   // Copies only the allocated tensors in `*this` from tensors in `other`.
   void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
     ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
@@ -584,27 +557,6 @@ struct ModelWeightsPtrs {
     });
   }

-  // Instead of reading, only allocates memory for all tensors. Used by
-  // `optimizer.cc` via the `Gemma` constructor without weights.
-  void AllocateForTest(std::vector<MatOwner>& mat_owners,
-                       hwy::ThreadPool& pool) {
-    // First get a list of all the tensors.
-    std::vector<MatPtr*> all_mat;
-    all_mat.reserve(10 * c_layers.size());
-    ForEachTensor(nullptr, nullptr, [&all_mat](const TensorArgs& t) {
-      all_mat.push_back(&t.mat);
-    });
-
-    const size_t start = mat_owners.size();
-    mat_owners.resize(start + all_mat.size());
-
-    // Allocate in parallel because faulting in large tensors is slow.
-    pool.Run(0, all_mat.size(), [&](uint64_t task, size_t /*thread*/) {
-      // `backprop/` does not use row accessors and hence requires kPacked.
-      mat_owners[start + task].AllocateFor(*all_mat[task], MatPadding::kPacked);
-    });
-  }
-
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Must be called after reading and
   // updating the attention weights.
@@ -640,8 +592,6 @@ class WeightsOwner {
       return func(sfp_weights_, std::forward<TArgs>(args)...);
     } else if (weight_type_ == Type::kNUQ) {
       return func(nuq_weights_, std::forward<TArgs>(args)...);
     } else if (weight_type_ == Type::kF32) {
       return func(float_weights_, std::forward<TArgs>(args)...);
     } else if (weight_type_ == Type::kBF16) {
       return func(bf16_weights_, std::forward<TArgs>(args)...);
     }
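For readers unfamiliar with this pattern, a standalone sketch of tag dispatch over differently-typed members in the style of `CallT` (class and member names here are ours, not the repo's): one enum tag selects which typed pointer a generic functor receives.

#include <cstdio>
#include <memory>

// Standalone sketch of WeightsOwner::CallT-style dispatch.
enum class Tag { kF32, kI32 };

struct Holder {
  Tag tag = Tag::kF32;
  std::unique_ptr<float> f32 = std::make_unique<float>(1.5f);
  std::unique_ptr<int> i32 = std::make_unique<int>(7);

  template <class Func>
  decltype(auto) CallT(const Func& func) {
    if (tag == Tag::kF32) return func(f32);
    return func(i32);
  }
};

int main() {
  Holder h;
  // The generic lambda is instantiated for whichever member the tag selects.
  h.CallT([](const auto& p) { std::printf("%g\n", static_cast<double>(*p)); });
  return 0;
}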
@@ -653,23 +603,6 @@ class WeightsOwner {
   // Adds one blob for each tensor's data and returns all serialized MatPtr.
   std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;

-  // For backprop/:
-
-  // Only allocates; must subsequently call `ZeroInit` or `RandInit`.
-  void AllocateForTest(const ModelConfig& config, hwy::ThreadPool& pool);
-  void ZeroInit();
-  void RandInit(float stddev, std::mt19937& gen);  // F32 or F64 only.
-  void LogWeightStatsF32();
-
-  ModelWeightsPtrs<float>* GetF32() const {
-    HWY_ASSERT(weight_type_ == Type::kF32);
-    return float_weights_.get();
-  }
-
-  // Usually taken care of by `ReadFromBlobs`, but must also be called by
-  // `optimize_test, which updates the attention weights from which this copies.
-  void Fixup(hwy::ThreadPool& pool);
-
  private:
   Type weight_type_;
@@ -677,8 +610,10 @@ class WeightsOwner {
   // of `CallT` so that can be const.
   void AllocatePointer(const ModelConfig& config);

+  // Called by `ReadFromBlobs`.
+  void Fixup(hwy::ThreadPool& pool);
+
   // Only one is non-null, determined by `weight_type_`.
   std::unique_ptr<ModelWeightsPtrs<float>> float_weights_;
   std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
   std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
   std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
@@ -1018,14 +1018,14 @@ struct TestShortDotsT {
   const PackedSpan<T> v = vectors.Span();

   MatStorageT<double> bufs("bufs", num);
-  double* HWY_RESTRICT buf = bufs.Packed();
+  double* HWY_RESTRICT buf = bufs.Row(0);

   for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
-    GenerateWellConditionedInputs(num, raw_w.Packed(), rng, w, work);
-    GenerateWellConditionedInputs(num, raw_v.Packed(), rng, v, work);
+    GenerateWellConditionedInputs(num, raw_w.Row(0), rng, w, work);
+    GenerateWellConditionedInputs(num, raw_v.Row(0), rng, v, work);

     const float dot_exact =
-        ExactDot(raw_w.Packed(), raw_v.Packed(), num, buf);
+        ExactDot(raw_w.PackedScale1(), raw_v.PackedScale1(), num, buf);
     float dots[kVariants];
     for (size_t variant = 0; variant < kVariants; ++variant) {
       // Here Packed is not always float, so we must not call kDouble.
@@ -387,7 +387,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
 */

 // `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-// This overload is called from backprop/ and if kUseHalfRope.
+// This overload is called if kUseHalfRope.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, int pos) {
@@ -35,7 +35,7 @@ static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
         static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
     // Replacing with expf(ln(1E4) * freq_exponents) changes results
     // noticeably.
-    inv_timescale.Packed()[dim] =
+    inv_timescale.Row(0)[dim] =
         static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
   }
   return inv_timescale;
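To make the loop above concrete, a standalone sketch of the timescale table it fills (variable names are ours; 10000 is the customary RoPE base frequency): entry `dim` holds `1 / base^(2*dim/rope_dim)`, the inverse wavelength used to rotate each pair of channels.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Standalone sketch of the inv_timescale table: one entry per channel pair.
std::vector<float> InvTimescales(size_t rope_dim, double base_frequency) {
  std::vector<float> inv(rope_dim / 2);
  for (size_t dim = 0; dim < inv.size(); ++dim) {
    const double freq_exponent =
        static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
    inv[dim] =
        static_cast<float>(1.0 / std::pow(base_frequency, freq_exponent));
  }
  return inv;
}

int main() {
  const auto inv = InvTimescales(/*rope_dim=*/8, /*base_frequency=*/10000.0);
  for (float f : inv) std::printf("%g\n", f);  // prints 1, 0.1, 0.01, 0.001
  return 0;
}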
@@ -399,7 +399,7 @@ void TestRopeAndMulBy() {
   auto random_float = [&r, &gen] { return r(gen); };

   for (int i = 0; i < dim_qkv; ++i) {
-    x.Packed()[i] = random_float();
+    x.Row(0)[i] = random_float();
   }

   const float qmul = ChooseQueryScale(config);
@@ -417,25 +417,25 @@ void TestRopeAndMulBy() {
   // Rope'd Q embeddings
   CopyMat(x, qactual);
   CopyMat(x, qexpected);
-  ScalarRopeAndMulBy(qmul, qexpected.Packed(), dim_qkv,
-                     inv_timescale.Packed(), pos);
-  RopeAndMulBy(qmul, qactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
+  ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
+                     pos);
+  RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);

   for (int i = 0; i < dim_qkv; ++i) {
-    EXPECT_NEAR(qactual.Packed()[i], qexpected.Packed()[i], 1e-4)
-        << "qIndex:" << i << "qInput:" << qactual.Packed()[i];
+    EXPECT_NEAR(qactual.Row(0)[i], qexpected.Row(0)[i], 1e-4)
+        << "qIndex:" << i << "qInput:" << qactual.Row(0)[i];
   }

   // Rope'd K embeddings
   CopyMat(x, kactual);
   CopyMat(x, kexpected);
-  ScalarRopeAndMulBy(kmul, kexpected.Packed(), dim_qkv,
-                     inv_timescale.Packed(), pos);
-  RopeAndMulBy(kmul, kactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
+  ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
+                     pos);
+  RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);

   for (int i = 0; i < dim_qkv; ++i) {
-    EXPECT_NEAR(kactual.Packed()[i], kexpected.Packed()[i], 1e-4)
-        << "kIndex:" << i << "kInput:" << kactual.Packed()[i];
+    EXPECT_NEAR(kactual.Row(0)[i], kexpected.Row(0)[i], 1e-4)
+        << "kIndex:" << i << "kInput:" << kactual.Row(0)[i];
   }
 }
@@ -53,9 +53,7 @@ PYBIND11_MODULE(configs, py_module) {
       .value("kF32", Type::kF32)
       .value("kBF16", Type::kBF16)
       .value("kSFP", Type::kSFP)
-      .value("kNUQ", Type::kNUQ)
-      .value("kF64", Type::kF64)
-      .value("kC64", Type::kC64);
+      .value("kNUQ", Type::kNUQ);

   enum_<LayerAttentionType>(py_module, "LayerAttentionType")
       .value("kGemma", LayerAttentionType::kGemma)
util/mat.cc (31 changed lines)
@@ -18,8 +18,6 @@
 #include <stddef.h>
 #include <stdint.h>

-#include <random>
-
 #include "util/threading_context.h"
 #include "hwy/base.h"
 #include "hwy/per_target.h"  // VectorBytes
@@ -62,35 +60,6 @@ void ZeroInit(MatPtr& mat) {
   }
 }

-void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
-  PROFILER_FUNC;
-  HWY_ASSERT_M(mat.HasPtr(), mat.Name());
-  // Only generates float/double for use by backprop/.
-  HWY_ASSERT(mat.GetType() == Type::kF32 || mat.GetType() == Type::kF64);
-  mat.SetScale(1.0f);
-
-  std::normal_distribution<float> dist(0.0, stddev);
-  if (mat.GetType() == Type::kF32) {
-    MatPtrT<float> mat_f(mat);
-
-    for (size_t r = 0; r < mat.Rows(); ++r) {
-      float* HWY_RESTRICT row = mat_f.Row(r);
-      for (size_t c = 0; c < mat.Cols(); ++c) {
-        row[c] = dist(gen);
-      }
-    }
-  } else {
-    MatPtrT<double> mat_d(mat);
-
-    for (size_t r = 0; r < mat.Rows(); ++r) {
-      double* HWY_RESTRICT row = mat_d.Row(r);
-      for (size_t c = 0; c < mat.Cols(); ++c) {
-        row[c] = dist(gen);
-      }
-    }
-  }
-}
-
 size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
               size_t line_bytes) {
   switch (padding) {
util/mat.h (40 changed lines)
@@ -20,7 +20,6 @@
 #include <stddef.h>
 #include <stdint.h>

-#include <random>
 #include <string>

 // IWYU pragma: begin_exports
@@ -229,21 +228,11 @@ class MatPtrT : public MatPtr {
   MatPtrT(const MatPtrT& other) = default;
   MatPtrT& operator=(const MatPtrT& other) = default;

-  // Returns the entire tensor for use by `backprop/*`. Verifies layout is
-  // `kPacked`. Preferably call `Row` instead, which works for either layout.
-  MatT* Packed() {
-    HWY_DASSERT_M(IsPacked(), name_.c_str());
-    return HWY_RCAST_ALIGNED(MatT*, ptr_);
-  }
-  const MatT* Packed() const {
-    HWY_DASSERT_M(IsPacked(), name_.c_str());
-    return HWY_RCAST_ALIGNED(const MatT*, ptr_);
-  }
-  // As `Packed()`, plus checks the scale is 1.0 because callers will ignore it.
-  // This is typically used for `MatMul` bias vectors and norm weights.
+  // Returns the entire tensor after checking the scale is 1.0 because callers
+  // will ignore it. Used for `MatMul` bias vectors and norm weights.
   const MatT* PackedScale1() const {
     HWY_DASSERT(Scale() == 1.0f);
-    return Packed();
+    return HWY_RCAST_ALIGNED(const MatT*, ptr_);
   }

   MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
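A standalone sketch of why `Row(r)` is the preferred accessor (a toy struct of ours, not the repo's `MatPtrT`): with padding, row r starts at r * stride rather than r * cols, so code that indexes one whole-tensor "packed" pointer breaks as soon as stride != cols, whereas row-wise access works for either layout.

#include <cstddef>
#include <vector>

// Toy row-major matrix with optional row padding.
struct Mat {
  size_t rows, cols, stride;  // stride >= cols when rows are padded
  std::vector<float> data;    // rows * stride elements
  float* Row(size_t r) { return data.data() + r * stride; }
};

float SumAll(Mat& m) {
  float sum = 0.0f;
  for (size_t r = 0; r < m.rows; ++r) {
    float* row = m.Row(r);
    for (size_t c = 0; c < m.cols; ++c) sum += row[c];  // skips the padding
  }
  return sum;
}

int main() {
  // 2x3 matrix stored with stride 4: one padding element per row.
  Mat m{2, 3, 4, std::vector<float>(2 * 4, 1.0f)};
  return SumAll(m) == 6.0f ? 0 : 1;  // sums the 6 real elements only
}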
@@ -268,8 +257,7 @@ class MatPtrT : public MatPtr {
 };

 // Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
-// optional `args`. This supports all types used as weights, which excludes
-// `kC64` and `kF64` (used only in `backprop/`).
+// optional `args`. This supports all types used as weights.
 template <class Func, typename... Args>
 decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
                             Args&&... args) {
@@ -334,8 +322,6 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,

 void CopyMat(const MatPtr& from, MatPtr& to);
 void ZeroInit(MatPtr& mat);
-// F32/F64 only.
-void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);

 // Our tensors are always row-major. This enum indicates how much (if any)
 // padding comes after each row.
@@ -346,9 +332,6 @@ enum class MatPadding {
   // `BlobStore`) are not padded, which also extends to memory-mapped tensors.
   // However, `BlobStore` is able to insert padding via row-wise I/O when
   // reading from disk via `Mode::kRead`.
-  //
-  // `backprop/*` also requires this layout because it indexes directly into
-  // the storage instead of calling `Row()`.
   kPacked,
   // Enough to round up to an odd number of cache lines, which can reduce
   // cache conflict misses or 4K aliasing.
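A standalone sketch of the "odd number of cache lines" policy described above (our own arithmetic, not the repo's `Stride()`): round the row up to whole cache lines, then bump even counts to odd so consecutive rows land in different cache sets.

#include <cstddef>
#include <cstdio>

// Row stride in bytes under the odd-cache-line padding policy.
size_t OddLineStrideBytes(size_t cols, size_t element_bytes,
                          size_t line_bytes) {
  const size_t row_bytes = cols * element_bytes;
  size_t lines = (row_bytes + line_bytes - 1) / line_bytes;  // round up
  if (lines % 2 == 0) ++lines;  // odd line counts reduce 4K aliasing
  return lines * line_bytes;
}

int main() {
  // 512 floats = 2048 bytes = 32 lines of 64 bytes; padded to 33 = 2112.
  std::printf("%zu\n", OddLineStrideBytes(512, sizeof(float), 64));
  return 0;
}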
@@ -378,9 +361,9 @@ class MatOwner {
   AlignedPtr<uint8_t[]> storage_;
 };

-// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
-// tests to allocate and access tensors of a known type. By contrast, the
-// heterogeneous model weights are owned by vectors of `MatOwner`.
+// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by tests to allocate
+// and access tensors of a known type. By contrast, the heterogeneous model
+// weights are owned by vectors of `MatOwner`.
 template <typename MatT>
 class MatStorageT : public MatPtrT<MatT> {
  public:
@@ -393,7 +376,7 @@ class MatStorageT : public MatPtrT<MatT> {
       : MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {}
   ~MatStorageT() = default;

-  // Allow move for backprop/activations.
+  // Allow move for KVCache.
   MatStorageT(MatStorageT&&) = default;
   MatStorageT& operator=(MatStorageT&&) = default;
@@ -401,13 +384,6 @@ class MatStorageT : public MatPtrT<MatT> {
   MatOwner owner_;
 };

-// Helper factory function for use by `backprop/` to avoid specifying the
-// `MatPadding` argument everywhere.
-template <typename T>
-MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
-  return MatStorageT<T>(name, Extents2D(rows, cols), MatPadding::kPacked);
-}
-
 // Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
 // seekable (non-NUQ) T.
 #pragma pack(push, 1)  // power of two size
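With the factory gone, call sites spell out the constructor the removed helper wrapped. A hypothetical migration sketch (the "grad" tensor name and 4x256 shape are made up; we assume the types live in the usual gcpp namespace, and the constructor signature is the one visible in the removed helper above):

#include "util/mat.h"  // MatStorageT, Extents2D, MatPadding

// Before (removed helper):
//   auto grad = MakePacked<float>("grad", 4, 256);
// After: construct directly, naming the padding policy.
gcpp::MatStorageT<float> MakeGrad() {
  return gcpp::MatStorageT<float>("grad", gcpp::Extents2D(4, 256),
                                  gcpp::MatPadding::kPacked);
}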