Separate monolithic gemma_lib library into more specific cc_library targets.

Creates new cc_library targets for :attention, :tensor_stats and :activations. Eliminates cyclic dependencies between these libraries.

PiperOrigin-RevId: 845238905
This commit is contained in:
Balazs Racz 2025-12-16 06:22:24 -08:00 committed by Copybara-Service
parent baa69dfb78
commit 85e2e8ae7f
5 changed files with 90 additions and 20 deletions

View File

@ -135,6 +135,8 @@ cc_test(
name = "flash_attention_test",
srcs = ["gemma/flash_attention_test.cc"],
deps = [
":activations",
":attention",
":configs",
":gemma_args",
":gemma_lib",
@ -439,6 +441,7 @@ cc_test(
# for test_suite.
tags = ["ops_tests"],
deps = [
":activations",
":allocator",
":basics",
":configs",
@ -537,6 +540,21 @@ cc_test(
],
)
cc_library(
name = "activations",
hdrs = ["gemma/activations.h"],
deps = [
":basics",
":configs",
":gemma_args",
":kv_cache",
":mat",
":ops",
":tensor_stats",
":threading_context",
],
)
cc_library(
name = "query",
hdrs = ["gemma/query.h"],
@ -573,31 +591,21 @@ cc_test(
)
cc_library(
name = "gemma_lib",
name = "attention",
srcs = [
"gemma/attention.cc",
"gemma/flash_attention.cc",
"gemma/gemma.cc",
"gemma/tensor_stats.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/activations.h",
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/flash_structs.h",
"gemma/gemma.h",
"gemma/tensor_stats.h",
"gemma/vit.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
textual_hdrs = [
"gemma/gemma-inl.h",
],
deps = [
":activations",
":allocator",
":basics",
":configs",
@ -606,10 +614,71 @@ cc_library(
":mat",
":matmul",
":matmul_env",
":ops",
":query",
":tensor_stats",
":threading",
":threading_context",
":weights",
":zones",
"//compression:compress",
"//compression:types",
"//io",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)
cc_library(
name = "tensor_stats",
srcs = ["gemma/tensor_stats.cc"],
hdrs = ["gemma/tensor_stats.h"],
deps = [
":basics",
":mat",
":ops",
":threading_context",
":zones",
"//compression:compress",
"//io",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:stats",
],
)
cc_library(
name = "gemma_lib",
srcs = [
"gemma/gemma.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/gemma.h",
"gemma/vit.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":activations",
":allocator",
":attention",
":basics",
":configs",
":gemma_args",
":kv_cache",
":mat",
":matmul_env",
":model_store",
":ops",
":query",
":threading",
":tensor_stats",
":threading_context",
":weights",
":zones",
@ -618,7 +687,6 @@ cc_library(
"//io",
"//io:blob_store",
"//paligemma:image",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
@ -636,6 +704,8 @@ cc_test(
# MatMulEnvs are up to 20GB large.
tags = ["requires-mem:28g"],
deps = [
":activations",
":attention",
":configs",
":gemma_args",
":gemma_lib",

View File

@ -26,7 +26,6 @@
#include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "util/threading_context.h"

View File

@ -20,7 +20,10 @@
#include <stddef.h>
#include "gemma/gemma.h"
#include "gemma/activations.h"
#include "gemma/query.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "hwy/highway.h"
namespace gcpp {

View File

@ -35,7 +35,6 @@
#include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h"
#include "util/threading.h"
#include "hwy/profiler.h"
@ -114,8 +113,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
// Find the token position in the query and calculate
// the range of cache positions to attend to.
constexpr size_t offset = 0; // placeholder, do not remove
const size_t pos =
qbatch.Pos(qi) + batch_idx + offset;
const size_t pos = qbatch.Pos(qi) + batch_idx + offset;
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
// Apply rope and scaling to Q.
if (query_norm_scale.HasPtr()) {

View File

@ -23,7 +23,7 @@
#include <cstdint>
#include "gemma/flash_structs.h"
#include "gemma/gemma.h"
#include "gemma/query.h"
#include "hwy/highway.h"
namespace gcpp {