// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Transformer components shared between vit.cc and attention.cc.

#include <stddef.h>
#include <stdint.h>

#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/tensor_stats.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h"
#include "util/threading.h"
#include "util/zones.h"
#include "hwy/profiler.h"
// Include guard (still compiled once per target).
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
#undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
#else
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
#endif
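// Background: hwy/foreach_target.h recompiles translation units (and thus
// the -inl headers they include) once per enabled SIMD target, flipping
// HWY_TARGET_TOGGLE on each pass. The guard above therefore admits this file
// once per target rather than once overall, which is why a plain
// #pragma once would not work here.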
#include "hwy/highway.h"
|
|
// After highway.h
|
|
#include "ops/ops-inl.h"
|
|
|
|
HWY_BEFORE_NAMESPACE();
|
|
namespace gcpp {
|
|
namespace HWY_NAMESPACE {
|
|
|
|
// For use by Vit even if !GEMMA_FUSED_FFN.
template <typename T1, typename T2>
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
                const T2* HWY_RESTRICT c2, const size_t count,
                ThreadingContext& ctx, const size_t worker) {
  GCPP_ZONE(ctx, worker, Zones::kGenActivation);
  namespace hn = hwy::HWY_NAMESPACE;
  using DF = hn::ScalableTag<float>;
  using VF = hn::Vec<DF>;
  // Only ActivationType::Gelu is currently implemented.
  if (c2 == nullptr) {  // No multiplier, just Gelu.
    Gelu(c1, count);
    return;
  }
  // Has multiplier: Gelu(c1) * c2.
  Decompress1AndCompressInplace(DF(), c1, count, c2, /*p1_ofs=*/0,
                                [](DF df, VF v1, VF v2) HWY_ATTR -> VF {
                                  return hn::Mul(v2, Gelu(df, v1));
                                });
}
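// For reference, a scalar sketch of the gated path above (`ScalarGelu` is a
// hypothetical stand-in for the vectorized `Gelu`; the real code fuses
// decompression, Gelu, and the multiply into a single SIMD pass):
//
//   for (size_t i = 0; i < count; ++i) {
//     c1[i] = ScalarGelu(c1[i]) * c2[i];  // GeGLU: gate * value
//   }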
// No C2 multiplier - used by Vit.
template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1,
                       ThreadingContext& ctx, size_t cluster_idx = 0,
                       Parallelism parallelism = Parallelism::kFlat) {
  using T = typename Mat::T;
  ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
              Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
                // Pass a typed null pointer so deduction of T2 succeeds.
                Activation(activation, c1.Row(task),
                           static_cast<const T*>(nullptr), c1.Cols(), ctx,
                           worker);
              });
}
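// Hypothetical usage sketch (assumes `hidden` is a float matrix exposing the
// Rows()/Cols()/Row() interface used above, e.g. a MatPtrT<float>):
//
//   ActivationBatched(ActivationType::Gelu, hidden, ctx);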
#if GEMMA_FUSED_FFN

// Called during `TwoMatMul`.
static inline void Activation(ActivationType activation, const RowPtrsBF C1,
                              const IndexRange range_r,
                              const IndexRange range_c, const StridedViewBF C2,
                              ThreadingContext& ctx, const size_t worker) {
  GCPP_ZONE(ctx, worker, Zones::kGenActivationFused);

  const size_t cols = range_c.Num();
  HWY_DASSERT(C2.Cols() == cols);

  namespace hn = hwy::HWY_NAMESPACE;
  using DF = hn::ScalableTag<float>;
  using VF = hn::Vec<DF>;
  // Only ActivationType::Gelu is currently implemented.
  // Gated: Gelu(c1) * c2. Note that C1 rows are absolute (offset by
  // range_r.begin()), whereas C2 rows are relative to the current tile.
  for (size_t ir = 0; ir < range_r.Num(); ++ir) {
    Decompress1AndCompressInplace(
        DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir),
        /*p1_ofs=*/0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF {
          return hn::Mul(v2, Gelu(df, v1));
        });
  }
}

#endif  // GEMMA_FUSED_FFN
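// Design note: in the fused path, `TwoMatMul` invokes the `Activation`
// overload above on each output tile as soon as it is computed, rather than
// in a separate pass over the full C1/C2 matrices after both matmuls finish;
// presumably this keeps the tile hot in cache.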
// Only used if !GEMMA_FUSED_FFN, but defined anyway so that call sites can
// check via `if constexpr` rather than #if, which interferes with code
// folding.
template <class Mat1, class Mat2>
HWY_NOINLINE void ActivationBatched(
    ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
    size_t cluster_idx = 0, Parallelism parallelism = Parallelism::kFlat) {
  if (c2 && c2->HasPtr()) {
    // Only dereference c2 after the null check above.
    HWY_DASSERT(c1.SameShape(*c2));
    ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
                Callers::kActivationBatched,
                [&](uint64_t task, size_t worker) {
                  Activation(activation, c1.Row(task), c2->Row(task),
                             c1.Cols(), ctx, worker);
                });
  } else {  // No multiplier.
    ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
                Callers::kActivationBatched,
                [&](uint64_t task, size_t worker) {
                  Activation(activation, c1.Row(task),
                             static_cast<const typename Mat2::T*>(nullptr),
                             c1.Cols(), ctx, worker);
                });
  }
}
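// Hypothetical call-site sketch for the comment above: because the overload
// is always defined, callers can branch at compile time without #if:
//
//   if constexpr (!GEMMA_FUSED_FFN) {
//     ActivationBatched(activation, C1, &C2, ctx);
//   }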
template <typename T2, class LayerWeights>
HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
                                     MatPtrT<float>& HWY_RESTRICT x,
                                     const LayerWeights& layer,
                                     bool is_attention, ThreadingContext& ctx) {
  // Only ResidualType::Add is implemented: x += other.
  AddFromBatched(other, x, ctx);
}
template <typename InOutT>
void PostNorm(PostNormType post_norm, const MatPtr& weights,
              MatPtrT<InOutT>& inout, ThreadingContext& ctx) {
  HWY_DASSERT(weights.Rows() == 1);
  if (post_norm == PostNormType::Scale) {
    RMSNormInplaceBatched(weights, inout, ctx);
  }
}
static inline void FFWNoVit(const LayerWeightsPtrs& layer,
                            Activations& activations, MatMulEnv& env) {
  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenFFW);
  const LayerConfig& layer_config = layer.layer_config;

  HWY_DASSERT(!layer_config.ff_biases);  // Biases are only used in Vit.

  activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out,
                              env.ctx);

#if GEMMA_FUSED_FFN
  // Apply the gated activation per output tile, during `TwoMatMul`.
  const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
                         StridedViewBF C2, size_t worker) {
    Activation(layer_config.activation, C1, range_r, range_c, C2, env.ctx,
               worker);
  };
  MMOptions options;
  options.SetFunc(fused);
  CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1,
                layer.gating_einsum_w2, env, activations.C1, options);
#else
  // Compute the hidden layer activations.
  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env,
             activations.C1);
  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env,
             activations.C2);
  // Activation (Gelu), multiplied by the gate if C2 is present. Stores the
  // result in C1.
  ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
                    env.ctx);
#endif

  activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx);

  // Hidden layer -> output layer.
  CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);

  activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx);
}
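// In summary, FFWNoVit computes, for each token row x (cf. FeedForward in
// gemma/modules.py):
//   hidden = Gelu(x @ W_gate) * (x @ W_up)   // gated Gelu ("GeGLU")
//   out    = hidden @ W_down
// where W_gate/W_up stand for `gating_einsum_w1`/`gating_einsum_w2` and
// W_down for `linear_w` (these names are illustrative only).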
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`att_sums`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                AttentionActivationsPtrs& activations,
                                MatMulEnv& env) {
  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
  const LayerConfig& layer_config = layer.layer_config;
  (void)layer_config;  // Unused if HWY_DASSERT compiles to nothing.
  // att_weights and att_out are concatenated heads, each of length
  // layer_config.qkv_dim. Thus the [num_interleaved, layer_config.model_dim]
  // matmul output is the sum over heads. Compare gemma/modules.py:
  //   attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
              layer_config.qkv_dim != 0);
  CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
             activations.att_sums);
}
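// Conceptual shape sketch for the matmul above (heads concatenated per row):
//   att_out:     [num_interleaved, heads * qkv_dim]
//   att_weights: projects heads * qkv_dim -> model_dim
//   att_sums:    [num_interleaved, model_dim], i.e. the per-head projections
//                summed over heads, matching the 'BTNH,NHD->BTD' einsum.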
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_