mirror of https://github.com/google/gemma.cpp.git
Separate monolithic gemma_lib library into more specific cc_library targets.
Creates new cc_library targets for :attention, :tensor_stats and :activations. Eliminates cyclic dependencies between these libraries.

PiperOrigin-RevId: 845238905
This commit is contained in:
parent
baa69dfb78
commit
85e2e8ae7f
98
BUILD.bazel
98
BUILD.bazel
|
|
@ -135,6 +135,8 @@ cc_test(
|
||||||
name = "flash_attention_test",
|
name = "flash_attention_test",
|
||||||
srcs = ["gemma/flash_attention_test.cc"],
|
srcs = ["gemma/flash_attention_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":activations",
|
||||||
|
":attention",
|
||||||
":configs",
|
":configs",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
|
|
@ -439,6 +441,7 @@ cc_test(
|
||||||
# for test_suite.
|
# for test_suite.
|
||||||
tags = ["ops_tests"],
|
tags = ["ops_tests"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":activations",
|
||||||
":allocator",
|
":allocator",
|
||||||
":basics",
|
":basics",
|
||||||
":configs",
|
":configs",
|
||||||
|
|
@ -537,6 +540,21 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "activations",
|
||||||
|
hdrs = ["gemma/activations.h"],
|
||||||
|
deps = [
|
||||||
|
":basics",
|
||||||
|
":configs",
|
||||||
|
":gemma_args",
|
||||||
|
":kv_cache",
|
||||||
|
":mat",
|
||||||
|
":ops",
|
||||||
|
":tensor_stats",
|
||||||
|
":threading_context",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "query",
|
name = "query",
|
||||||
hdrs = ["gemma/query.h"],
|
hdrs = ["gemma/query.h"],
|
||||||
|
|
@ -573,31 +591,21 @@ cc_test(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gemma_lib",
|
name = "attention",
|
||||||
srcs = [
|
srcs = [
|
||||||
"gemma/attention.cc",
|
"gemma/attention.cc",
|
||||||
"gemma/flash_attention.cc",
|
"gemma/flash_attention.cc",
|
||||||
"gemma/gemma.cc",
|
|
||||||
"gemma/tensor_stats.cc",
|
|
||||||
"gemma/vit.cc",
|
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"gemma/activations.h",
|
|
||||||
"gemma/attention.h",
|
"gemma/attention.h",
|
||||||
"gemma/flash_attention.h",
|
"gemma/flash_attention.h",
|
||||||
"gemma/flash_structs.h",
|
"gemma/flash_structs.h",
|
||||||
"gemma/gemma.h",
|
|
||||||
"gemma/tensor_stats.h",
|
|
||||||
"gemma/vit.h",
|
|
||||||
],
|
],
|
||||||
exec_properties = {
|
|
||||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
|
||||||
"mem": "28g",
|
|
||||||
},
|
|
||||||
textual_hdrs = [
|
textual_hdrs = [
|
||||||
"gemma/gemma-inl.h",
|
"gemma/gemma-inl.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":activations",
|
||||||
":allocator",
|
":allocator",
|
||||||
":basics",
|
":basics",
|
||||||
":configs",
|
":configs",
|
||||||
|
|
@ -606,10 +614,71 @@ cc_library(
|
||||||
":mat",
|
":mat",
|
||||||
":matmul",
|
":matmul",
|
||||||
":matmul_env",
|
":matmul_env",
|
||||||
|
":ops",
|
||||||
|
":query",
|
||||||
|
":tensor_stats",
|
||||||
|
":threading",
|
||||||
|
":threading_context",
|
||||||
|
":weights",
|
||||||
|
":zones",
|
||||||
|
"//compression:compress",
|
||||||
|
"//compression:types",
|
||||||
|
"//io",
|
||||||
|
"@highway//:bit_set",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:nanobenchmark", # timer
|
||||||
|
"@highway//:profiler",
|
||||||
|
"@highway//:thread_pool",
|
||||||
|
"@highway//hwy/contrib/sort:vqsort",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_stats",
|
||||||
|
srcs = ["gemma/tensor_stats.cc"],
|
||||||
|
hdrs = ["gemma/tensor_stats.h"],
|
||||||
|
deps = [
|
||||||
|
":basics",
|
||||||
|
":mat",
|
||||||
|
":ops",
|
||||||
|
":threading_context",
|
||||||
|
":zones",
|
||||||
|
"//compression:compress",
|
||||||
|
"//io",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:profiler",
|
||||||
|
"@highway//:stats",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gemma_lib",
|
||||||
|
srcs = [
|
||||||
|
"gemma/gemma.cc",
|
||||||
|
"gemma/vit.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"gemma/gemma.h",
|
||||||
|
"gemma/vit.h",
|
||||||
|
],
|
||||||
|
exec_properties = {
|
||||||
|
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||||
|
"mem": "28g",
|
||||||
|
},
|
||||||
|
deps = [
|
||||||
|
":activations",
|
||||||
|
":allocator",
|
||||||
|
":attention",
|
||||||
|
":basics",
|
||||||
|
":configs",
|
||||||
|
":gemma_args",
|
||||||
|
":kv_cache",
|
||||||
|
":mat",
|
||||||
|
":matmul_env",
|
||||||
":model_store",
|
":model_store",
|
||||||
":ops",
|
":ops",
|
||||||
":query",
|
":query",
|
||||||
":threading",
|
":tensor_stats",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
":weights",
|
":weights",
|
||||||
":zones",
|
":zones",
|
||||||
|
|
@ -618,7 +687,6 @@ cc_library(
|
||||||
"//io",
|
"//io",
|
||||||
"//io:blob_store",
|
"//io:blob_store",
|
||||||
"//paligemma:image",
|
"//paligemma:image",
|
||||||
"@highway//:bit_set",
|
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:nanobenchmark", # timer
|
"@highway//:nanobenchmark", # timer
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
|
|
@ -636,6 +704,8 @@ cc_test(
|
||||||
# MatMulEnvs are up to 20GB large.
|
# MatMulEnvs are up to 20GB large.
|
||||||
tags = ["requires-mem:28g"],
|
tags = ["requires-mem:28g"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":activations",
|
||||||
|
":attention",
|
||||||
":configs",
|
":configs",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@
|
||||||
|
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/configs.h" // kMaxQKVDim
|
#include "gemma/configs.h" // kMaxQKVDim
|
||||||
#include "gemma/gemma.h"
|
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
#include "util/threading.h"
|
#include "util/threading.h"
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,10 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/activations.h"
|
||||||
|
#include "gemma/query.h"
|
||||||
|
#include "gemma/weights.h"
|
||||||
|
#include "ops/matmul.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,6 @@
|
||||||
|
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/configs.h" // kMaxQKVDim
|
#include "gemma/configs.h" // kMaxQKVDim
|
||||||
#include "gemma/gemma.h"
|
|
||||||
#include "util/threading.h"
|
#include "util/threading.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
|
|
@ -114,8 +113,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||||
// Find the token position in the query and calculate
|
// Find the token position in the query and calculate
|
||||||
// the range of cache positions to attend to.
|
// the range of cache positions to attend to.
|
||||||
constexpr size_t offset = 0; // placeholder, do not remove
|
constexpr size_t offset = 0; // placeholder, do not remove
|
||||||
const size_t pos =
|
const size_t pos = qbatch.Pos(qi) + batch_idx + offset;
|
||||||
qbatch.Pos(qi) + batch_idx + offset;
|
|
||||||
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
|
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
if (query_norm_scale.HasPtr()) {
|
if (query_norm_scale.HasPtr()) {
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#include "gemma/flash_structs.h"
|
#include "gemma/flash_structs.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/query.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue