mirror of https://github.com/google/gemma.cpp.git
Decouple MatMul from gemma-inl: precompile for all input types
Call MatMulStatic instead of MatMul. Also fix a build error due to Highway's Lanes not being constexpr.

PiperOrigin-RevId: 763777269
This commit is contained in:
parent 421a2ab8ac
commit 627cc04db9
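The bulk of the change is mechanical: call sites in gemma-inl.h switch from the MatMul pulled in via ops/matmul-inl.h to the precompiled MatMulStatic overloads declared in ops/matmul_static.h. A representative before/after taken from the FFW path in the diff below (whitespace approximate, shown only for orientation):

// Before: instantiates MatMul from ops/matmul-inl.h in every gemma translation unit.
MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
       RowPtrFromMat(activations.ffw_out));

// After: calls an overload compiled once per weight type in the new "matmul"
// library (ops/matmul_static_{bf16,f32,nuq,sfp}.cc).
MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
             *activations.env, RowPtrFromMat(activations.ffw_out));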
BUILD.bazel (33 changed lines)
@@ -295,6 +295,35 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "matmul",
+    srcs = [
+        # single-file build time is ~30sec for msan, hence shard.
+        "ops/matmul_static_bf16.cc",
+        "ops/matmul_static_f32.cc",
+        "ops/matmul_static_nuq.cc",
+        "ops/matmul_static_sfp.cc",
+    ],
+    hdrs = [
+        "ops/matmul_static.h",
+    ],
+    textual_hdrs = [
+        "ops/matmul_static-inl.h",
+        "ops/matmul-inl.h",
+    ],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":ops",
+        ":threading_context",
+        "//compression:compress",
+        "//compression:types",
+        "@highway//:hwy",
+        "@highway//:profiler",
+        "@highway//:timer",
+    ],
+)
+
 cc_test(
     name = "dot_test",
     size = "small",
@@ -373,6 +402,7 @@ cc_test(
     deps = [
         ":basics",
         ":mat",
+        ":matmul",
        ":ops",
         ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
@@ -380,6 +410,7 @@ cc_test(
         "//compression:test_util",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:nanobenchmark",
         "@highway//:thread_pool",
     ],
 )
@@ -463,6 +494,7 @@ cc_library(
         ":gemma_args",
         ":kv_cache",
         ":mat",
+        ":matmul",
         ":model_store",
         ":ops",
         ":tokenizer",
@@ -578,6 +610,7 @@ cc_test(
         "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:profiler",
     ],
 )
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6 EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(highway)
 
 ## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -95,6 +95,10 @@ set(SOURCES
   io/io.cc
   io/io.h
   ops/dot-inl.h
+  ops/matmul_static_bf16.cc
+  ops/matmul_static_f32.cc
+  ops/matmul_static_nuq.cc
+  ops/matmul_static_sfp.cc
   ops/matmul-inl.h
   ops/matmul.cc
   ops/matmul.h
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
 # Require a more recent version.
 git_override(
     module_name = "highway",
-    commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
+    commit = "12d9fa908e0c1d3346c298d472584687a24e4ce6",
     remote = "https://github.com/google/highway",
 )
@@ -17,7 +17,7 @@ project(hello_world)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -17,7 +17,7 @@ project(simplified_gemma)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -29,6 +29,7 @@
 #include "gemma/gemma.h"
 #include "gemma/kv_cache.h"
 #include "gemma/weights.h"
+#include "ops/matmul_static.h"
 #include "paligemma/image.h"
 #include "util/mat.h"
 #include "util/threading_context.h"
@@ -48,7 +49,6 @@
 
 #include "hwy/highway.h"
 // After highway.h
-#include "ops/matmul-inl.h"
 #include "ops/matvec-inl.h"
 #include "ops/ops-inl.h"
 
@@ -266,8 +266,9 @@ class GemmaAttention {
     // computed in the second MatMul.
     const size_t w1_rows = heads * layer_config_.QStride();
     HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
-    MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
-           /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));
+    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
+                 /*add=*/nullptr, *activations_.env,
+                 RowPtrFromMat(activations_.q));
 
     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
@@ -285,7 +286,7 @@ class GemmaAttention {
       float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
       RowPtrF kv_rows(kv, w_rows_kv_cols);
       kv_rows.SetStride(cache_pos_size_);
-      MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
+      MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
                    /*add=*/nullptr, *activations_.env, kv_rows);
     } else {
       // Proceed row by row because there will be wraparound.
@@ -489,7 +490,7 @@ class GemmaAttention {
         layer_weights_.layer_config.softmax_attn_output_biases
             ? layer_weights_.attention_output_biases.PackedScale1()
            : nullptr;
-    MatMul(activations_.att_out, layer_weights_.att_weights, add,
+    MatMulStatic(activations_.att_out, layer_weights_.att_weights, add,
                 *activations_.env, RowPtrFromMat(activations_.att_sums));
   }
 
@@ -625,9 +626,9 @@ class VitAttention {
     auto& qkv = activations_.q;
     HWY_ASSERT(qkv.Rows() == num_tokens_);
     HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
-    MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
-           layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
-           RowPtrFromMat(qkv));
+    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
+                 layer_weights_.vit.qkv_einsum_b.PackedScale1(),
+                 *activations_.env, RowPtrFromMat(qkv));
   }
 
   // TODO(philculliton): transition fully to MatMul.
@@ -667,7 +668,7 @@ class VitAttention {
     });
 
     // this produces C, a (num_tokens_, seq_len) matrix of dot products
-    MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
+    MatMulStatic(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
 
     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
       float* HWY_RESTRICT c = C.Row(task);
@@ -734,7 +735,7 @@ class VitAttention {
     // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
     auto att_sums = RowPtrFromMat(activations_.att_sums);
-    MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
+    MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
                  *activations_.env, att_sums);
   }
 
@@ -826,18 +827,18 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
       add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, bias1,
-         *activations.env, RowPtrFromMat(activations.C1));
-  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, bias2,
-         *activations.env, RowPtrFromMat(activations.C2));
+  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
+               bias1, *activations.env, RowPtrFromMat(activations.C1));
+  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
+               bias2, *activations.env, RowPtrFromMat(activations.C2));
 
   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1,
                     &activations.C2);
 
   // Hidden layer -> output layer.
-  MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
-         RowPtrFromMat(activations.ffw_out));
+  MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
+               *activations.env, RowPtrFromMat(activations.ffw_out));
 }
 
 // Same as FFWNoVit, but with different layer_weights members and no second
@@ -854,14 +855,14 @@ HWY_NOINLINE void FFWVit(Activations& activations,
       add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
-         *activations.env, RowPtrFromMat(activations.C1));
+  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w,
+               bias1, *activations.env, RowPtrFromMat(activations.C1));
 
   // Activation (Gelu), store in C1.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1);
 
   // Hidden layer -> output layer.
-  MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
-         *activations.env, RowPtrFromMat(activations.ffw_out));
+  MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias,
+               *activations.env, RowPtrFromMat(activations.ffw_out));
 }
 
@@ -1175,7 +1176,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
   // kPatchSize), MatPadding::kPacked);
   // [Get patches]
-  // MatMul(
+  // MatMulStatic(
   //     MatFromBatch(kVitSeqLen, image_patches),
   //     MatFromWeights(weights.vit_img_embedding_kernel),
   //     weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
@@ -1226,7 +1227,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
   }
 
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
-  MatMul(activations.x, weights.vit_img_head_kernel,
+  MatMulStatic(activations.x, weights.vit_img_head_kernel,
                weights.vit_img_head_bias.PackedScale1(), *activations.env,
                RowPtrFromMat(image_tokens));
 }
@@ -1366,7 +1367,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
   {
     PROFILER_ZONE("Gen.EmbeddingMatmul");
     // Compute logits from last layer activations.
-    MatMul(activations.x, weights.embedder_input_embedding,
+    MatMulStatic(activations.x, weights.embedder_input_embedding,
                  /*add=*/nullptr, *activations.env,
                  RowPtrFromMat(activations.logits));
   }
 
@@ -15,6 +15,7 @@
 
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc"
+
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #define GEMMA_TYPE hwy::bfloat16_t
 #include "gemma/gemma-inl.h"
@@ -136,7 +136,7 @@ hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
   using TF = hn::TFromD<DF>;
   // For non-scalable wide vectors, reduce loop iterations below by recursing
   // once or twice for halves of 256-bit or 512-bit vectors.
-  if constexpr (!HWY_HAVE_SCALABLE) {
+  if constexpr (HWY_HAVE_CONSTEXPR_LANES) {
     if constexpr (hn::Lanes(df) > 16 / sizeof(TF)) {
       const hn::Half<DF> dfh;
       using VFH = hn::Vec<decltype(dfh)>;
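For context on the hunk above: the newer Highway revision referenced in this commit only guarantees that hn::Lanes() is usable in constant expressions when HWY_HAVE_CONSTEXPR_LANES is nonzero; on scalable targets the lane count is a runtime value, so it cannot appear in `if constexpr`. A minimal sketch of the guard pattern, assuming that macro is available for the selected target; the helper name and structure are made up for illustration:

#include <stddef.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Hypothetical helper: recurse to the half-width vector only when the lane
// count is known at compile time; otherwise use the full vector.
template <class DF>
size_t UsableLanes(DF df) {
#if HWY_HAVE_CONSTEXPR_LANES
  // Lanes() is constexpr here, so it may appear in `if constexpr`.
  if constexpr (hn::Lanes(df) > 16 / sizeof(hn::TFromD<DF>)) {
    return hn::Lanes(hn::Half<DF>());
  }
#endif
  return hn::Lanes(df);  // Runtime query is always valid.
}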
@@ -0,0 +1,60 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/detect_compiler_arch.h"  // HWY_IDE
+
+#ifndef GEMMA_MATMUL_TB
+#if HWY_IDE
+// Provide a definition so the IDE does not complain.
+#define GEMMA_MATMUL_TB float
+#else
+#error "Only include from matmul_static_*.cc, which define GEMMA_MATMUL_TB"
+#endif  // HWY_IDE
+#endif  // GEMMA_MATMUL_TB
+
+// Passed to GEMMA_MATMUL_FOREACH_AC; defines one overload for one target.
+#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC)                             \
+  MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B,    \
+                         const float* HWY_RESTRICT add, MatMulEnv& env, \
+                         const RowPtr<TC>& C) {                         \
+    return MatMul(A, B, add, env, C);                                   \
+  }
+
+#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \
+    defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
+#undef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
+#else
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
+#endif
+
+#include "hwy/highway.h"
+// After highway.h
+#include "ops/matmul-inl.h"
+#include "ops/matmul_static.h"  // includes highway.h!
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+
+// Ignore warning that we are defining a function in a header; this is only
+// included from matmul_static_*.cc.
+GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB)  // NOLINT
+
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
@@ -0,0 +1,59 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
+#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
+
+// Declares overloads of MatMulStatic for all SIMD targets and input types.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "ops/matmul.h"  // IWYU pragma: keep, b/420428845
+#include "hwy/highway.h"
+
+// Invokes GEMMA_X(TA, TB, TC) for all combinations of F32 or BF16.
+#define GEMMA_MATMUL_FOREACH_AC(GEMMA_X, TB) \
+  GEMMA_X(float, TB, float)                  \
+  GEMMA_X(float, TB, BF16)                   \
+  GEMMA_X(BF16, TB, float)                   \
+  GEMMA_X(BF16, TB, BF16)
+
+// Passed to GEMMA_MATMUL_FOREACH_AC; declares one overload for one target.
+#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC)                               \
+  MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B,    \
+                         const float* HWY_RESTRICT add, MatMulEnv& env, \
+                         const RowPtr<TC>& C);
+
+// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
+#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE)                  \
+  namespace NAMESPACE {                                       \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16)        \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float)       \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream)   \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream)   \
+  /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
+  }  // namespace NAMESPACE
+
+namespace gcpp {
+
+// MatMul function declarations for each SIMD target. Allows direct call from
+// the per-target namespace. We may later replace this with dynamic dispatch if
+// the overhead is acceptable.
+HWY_VISIT_TARGETS(GEMMA_MATMUL_DECL)
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
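For readers unfamiliar with these X-macros, here is roughly what one invocation expands to (hand-expanded sketch, not part of the commit): the four TA/TC combinations of float and BF16 are crossed with the weight type TB, here SfpStream, inside one per-target namespace named by HWY_VISIT_TARGETS.

// GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) declares:
MMPerKey* MatMulStatic(const MatPtrT<float>& A, const MatPtrT<SfpStream>& B,
                       const float* HWY_RESTRICT add, MatMulEnv& env,
                       const RowPtr<float>& C);
MMPerKey* MatMulStatic(const MatPtrT<float>& A, const MatPtrT<SfpStream>& B,
                       const float* HWY_RESTRICT add, MatMulEnv& env,
                       const RowPtr<BF16>& C);
MMPerKey* MatMulStatic(const MatPtrT<BF16>& A, const MatPtrT<SfpStream>& B,
                       const float* HWY_RESTRICT add, MatMulEnv& env,
                       const RowPtr<float>& C);
MMPerKey* MatMulStatic(const MatPtrT<BF16>& A, const MatPtrT<SfpStream>& B,
                       const float* HWY_RESTRICT add, MatMulEnv& env,
                       const RowPtr<BF16>& C);
// The matching definitions come from ops/matmul_static-inl.h, compiled once
// per weight type in the four matmul_static_*.cc shards below.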
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_bf16.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB BF16
+#include "ops/matmul_static-inl.h"
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_f32.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB float
+#include "ops/matmul_static-inl.h"
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_nuq.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB NuqStream
+#include "ops/matmul_static-inl.h"
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_sfp.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB SfpStream
+#include "ops/matmul_static-inl.h"
@@ -18,13 +18,18 @@
 #include "hwy/detect_compiler_arch.h"  // IWYU pragma: keep
 #ifndef HWY_DISABLED_TARGETS
 // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
-// double-precision support, and older x86 to speed up builds.
+// double-precision support.
 #if HWY_ARCH_ARM_V7
 #define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
 #else
-#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SSSE3 | HWY_SSE4)
-#endif
-#endif
+#define HWY_DISABLED_TARGETS (HWY_SCALAR)
+#endif  // HWY_ARCH_ARM_V7
+#endif  // HWY_DISABLED_TARGETS
+// matmul_static is not built as a test, hence does not define MatMulStatic for
+// worse-than-baseline targets (to speed up builds), so we skip them here, too.
+#ifndef HWY_SKIP_NON_BEST_BASELINE
+#define HWY_SKIP_NON_BEST_BASELINE
+#endif  // HWY_SKIP_NON_BEST_BASELINE
 
 #include <stddef.h>
 #include <stdio.h>
@@ -35,6 +40,7 @@
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/nanobenchmark.h"  // Unpredictable1
 
 // clang-format off
 #undef HWY_TARGET_INCLUDE
@@ -46,7 +52,7 @@
 #include "compression/compress-inl.h"
 #include "compression/test_util-inl.h"
 #include "ops/dot-inl.h"
-#include "ops/matmul-inl.h"
+#include "ops/matmul_static.h"  // also textual
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -234,7 +240,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   MatMulSlow(a, b_trans, add_row, env, C_slow);
   // A few reps to get coverage of the various autotuned code paths.
   for (size_t rep = 0; rep < 16; ++rep) {
-    MMPerKey* per_key = MatMul(a, b_trans, add_row, env, C);
+    MMPerKey* per_key = MatMulStatic(a, b_trans, add_row, env, C);
     AssertClose(a, b_trans, C_slow, C, line);
     if (per_key->autotune.Best()) break;
   }
@@ -258,12 +264,12 @@ void TestTiny() {
   MatMulEnv env(ThreadingContext::Get());
   NestedPools& pools = env.ctx.pools;
 
-#if GEMMA_DISABLE_TOPOLOGY
+    if constexpr (GEMMA_DISABLE_TOPOLOGY) {
       if (max_packages == 2) break;  // we only have one package
-#else
+    } else {
       // If less than the limit, we have already tested all num_packages.
       if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
-#endif
+    }
     fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
             env.ctx.topology.TopologyString(), pools.PinString());
 
@@ -282,8 +288,10 @@ void TestTiny() {
 
 void TestAllMatMul() {
   // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
-  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
-      HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
+  // Add Unpredictable1 to prevent erroneous "unreachable code" warning.
+  if (hwy::Unpredictable1() == 1 &&
+      (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
+       HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2)) {
     return;
   }
 
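On the hwy::Unpredictable1() guard added above: per the new comment, its purpose is to avoid an "unreachable code" warning when the HWY_TARGET comparison folds to a compile-time constant. A minimal sketch of the pattern, illustrative only (the function name is hypothetical):

#include "hwy/highway.h"        // HWY_TARGET, HWY_EMU128
#include "hwy/nanobenchmark.h"  // hwy::Unpredictable1

void MaybeSkip() {
  // Unpredictable1() always returns 1, but the compiler cannot prove it, so
  // the code after the `if` is not flagged as unreachable even when the
  // target comparison is true at compile time.
  if (hwy::Unpredictable1() == 1 && HWY_TARGET == HWY_EMU128) {
    return;
  }
  // ... run the test body ...
}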