diff --git a/BUILD.bazel b/BUILD.bazel
index a9631dc..ce78055 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -135,6 +135,8 @@ cc_test(
     name = "flash_attention_test",
     srcs = ["gemma/flash_attention_test.cc"],
     deps = [
+        ":activations",
+        ":attention",
         ":configs",
         ":gemma_args",
         ":gemma_lib",
@@ -439,6 +441,7 @@ cc_test(
     # for test_suite.
     tags = ["ops_tests"],
     deps = [
+        ":activations",
         ":allocator",
         ":basics",
         ":configs",
@@ -537,6 +540,21 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "activations",
+    hdrs = ["gemma/activations.h"],
+    deps = [
+        ":basics",
+        ":configs",
+        ":gemma_args",
+        ":kv_cache",
+        ":mat",
+        ":ops",
+        ":tensor_stats",
+        ":threading_context",
+    ],
+)
+
 cc_library(
     name = "query",
     hdrs = ["gemma/query.h"],
@@ -573,31 +591,21 @@ cc_test(
 )
 
 cc_library(
-    name = "gemma_lib",
+    name = "attention",
     srcs = [
         "gemma/attention.cc",
         "gemma/flash_attention.cc",
-        "gemma/gemma.cc",
-        "gemma/tensor_stats.cc",
-        "gemma/vit.cc",
     ],
     hdrs = [
-        "gemma/activations.h",
         "gemma/attention.h",
         "gemma/flash_attention.h",
         "gemma/flash_structs.h",
-        "gemma/gemma.h",
-        "gemma/tensor_stats.h",
-        "gemma/vit.h",
     ],
-    exec_properties = {
-        # Avoid linker OOMs when building with sanitizer instrumentation.
-        "mem": "28g",
-    },
     textual_hdrs = [
         "gemma/gemma-inl.h",
     ],
     deps = [
+        ":activations",
         ":allocator",
         ":basics",
         ":configs",
@@ -606,10 +614,71 @@ cc_library(
         ":mat",
         ":matmul",
         ":matmul_env",
+        ":ops",
+        ":query",
+        ":tensor_stats",
+        ":threading",
+        ":threading_context",
+        ":weights",
+        ":zones",
+        "//compression:compress",
+        "//compression:types",
+        "//io",
+        "@highway//:bit_set",
+        "@highway//:hwy",
+        "@highway//:nanobenchmark",  # timer
+        "@highway//:profiler",
+        "@highway//:thread_pool",
+        "@highway//hwy/contrib/sort:vqsort",
+    ],
+)
+
+cc_library(
+    name = "tensor_stats",
+    srcs = ["gemma/tensor_stats.cc"],
+    hdrs = ["gemma/tensor_stats.h"],
+    deps = [
+        ":basics",
+        ":mat",
+        ":ops",
+        ":threading_context",
+        ":zones",
+        "//compression:compress",
+        "//io",
+        "@highway//:hwy",
+        "@highway//:profiler",
+        "@highway//:stats",
+    ],
+)
+
+cc_library(
+    name = "gemma_lib",
+    srcs = [
+        "gemma/gemma.cc",
+        "gemma/vit.cc",
+    ],
+    hdrs = [
+        "gemma/gemma.h",
+        "gemma/vit.h",
+    ],
+    exec_properties = {
+        # Avoid linker OOMs when building with sanitizer instrumentation.
+        "mem": "28g",
+    },
+    deps = [
+        ":activations",
+        ":allocator",
+        ":attention",
+        ":basics",
+        ":configs",
+        ":gemma_args",
+        ":kv_cache",
+        ":mat",
+        ":matmul_env",
         ":model_store",
         ":ops",
         ":query",
-        ":threading",
+        ":tensor_stats",
         ":threading_context",
         ":weights",
         ":zones",
@@ -618,7 +687,6 @@ cc_library(
         "//io",
         "//io:blob_store",
         "//paligemma:image",
-        "@highway//:bit_set",
         "@highway//:hwy",
         "@highway//:nanobenchmark",  # timer
         "@highway//:profiler",
@@ -636,6 +704,8 @@ cc_test(
     # MatMulEnvs are up to 20GB large.
tags = ["requires-mem:28g"], deps = [ + ":activations", + ":attention", ":configs", ":gemma_args", ":gemma_lib", diff --git a/gemma/attention.cc b/gemma/attention.cc index eccfd25..b5dc5e1 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -26,7 +26,6 @@ #include "gemma/activations.h" #include "gemma/configs.h" // kMaxQKVDim -#include "gemma/gemma.h" #include "gemma/weights.h" #include "util/threading.h" #include "util/threading_context.h" diff --git a/gemma/attention.h b/gemma/attention.h index 60e6823..71411b2 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -20,7 +20,10 @@ #include -#include "gemma/gemma.h" +#include "gemma/activations.h" +#include "gemma/query.h" +#include "gemma/weights.h" +#include "ops/matmul.h" #include "hwy/highway.h" namespace gcpp { diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 671efb4..b1d079e 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -35,7 +35,6 @@ #include "gemma/activations.h" #include "gemma/configs.h" // kMaxQKVDim -#include "gemma/gemma.h" #include "util/threading.h" #include "hwy/profiler.h" @@ -114,8 +113,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, // Find the token position in the query and calculate // the range of cache positions to attend to. constexpr size_t offset = 0; // placeholder, do not remove - const size_t pos = - qbatch.Pos(qi) + batch_idx + offset; + const size_t pos = qbatch.Pos(qi) + batch_idx + offset; float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim; // Apply rope and scaling to Q. if (query_norm_scale.HasPtr()) { diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index b8a70ea..8f3ec21 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -23,7 +23,7 @@ #include #include "gemma/flash_structs.h" -#include "gemma/gemma.h" +#include "gemma/query.h" #include "hwy/highway.h" namespace gcpp {