mirror of https://github.com/google/gemma.cpp.git
Separate monolithic gemma_lib library into more specific cc_library targets.
Creates new cc_library targets for :attention, :tensor_stats and :activations. Eliminates cyclic dependencies between these libraries. PiperOrigin-RevId: 845238905
This commit is contained in:
parent
baa69dfb78
commit
85e2e8ae7f
98
BUILD.bazel
98
BUILD.bazel
|
|
@ -135,6 +135,8 @@ cc_test(
|
|||
name = "flash_attention_test",
|
||||
srcs = ["gemma/flash_attention_test.cc"],
|
||||
deps = [
|
||||
":activations",
|
||||
":attention",
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
|
|
@ -439,6 +441,7 @@ cc_test(
|
|||
# for test_suite.
|
||||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":activations",
|
||||
":allocator",
|
||||
":basics",
|
||||
":configs",
|
||||
|
|
@ -537,6 +540,21 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "activations",
|
||||
hdrs = ["gemma/activations.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
":ops",
|
||||
":tensor_stats",
|
||||
":threading_context",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "query",
|
||||
hdrs = ["gemma/query.h"],
|
||||
|
|
@ -573,31 +591,21 @@ cc_test(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
name = "attention",
|
||||
srcs = [
|
||||
"gemma/attention.cc",
|
||||
"gemma/flash_attention.cc",
|
||||
"gemma/gemma.cc",
|
||||
"gemma/tensor_stats.cc",
|
||||
"gemma/vit.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/activations.h",
|
||||
"gemma/attention.h",
|
||||
"gemma/flash_attention.h",
|
||||
"gemma/flash_structs.h",
|
||||
"gemma/gemma.h",
|
||||
"gemma/tensor_stats.h",
|
||||
"gemma/vit.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
textual_hdrs = [
|
||||
"gemma/gemma-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":activations",
|
||||
":allocator",
|
||||
":basics",
|
||||
":configs",
|
||||
|
|
@ -606,10 +614,71 @@ cc_library(
|
|||
":mat",
|
||||
":matmul",
|
||||
":matmul_env",
|
||||
":ops",
|
||||
":query",
|
||||
":tensor_stats",
|
||||
":threading",
|
||||
":threading_context",
|
||||
":weights",
|
||||
":zones",
|
||||
"//compression:compress",
|
||||
"//compression:types",
|
||||
"//io",
|
||||
"@highway//:bit_set",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_stats",
|
||||
srcs = ["gemma/tensor_stats.cc"],
|
||||
hdrs = ["gemma/tensor_stats.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
":mat",
|
||||
":ops",
|
||||
":threading_context",
|
||||
":zones",
|
||||
"//compression:compress",
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
srcs = [
|
||||
"gemma/gemma.cc",
|
||||
"gemma/vit.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/gemma.h",
|
||||
"gemma/vit.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":activations",
|
||||
":allocator",
|
||||
":attention",
|
||||
":basics",
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
":matmul_env",
|
||||
":model_store",
|
||||
":ops",
|
||||
":query",
|
||||
":threading",
|
||||
":tensor_stats",
|
||||
":threading_context",
|
||||
":weights",
|
||||
":zones",
|
||||
|
|
@ -618,7 +687,6 @@ cc_library(
|
|||
"//io",
|
||||
"//io:blob_store",
|
||||
"//paligemma:image",
|
||||
"@highway//:bit_set",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
|
|
@ -636,6 +704,8 @@ cc_test(
|
|||
# MatMulEnvs are up to 20GB large.
|
||||
tags = ["requires-mem:28g"],
|
||||
deps = [
|
||||
":activations",
|
||||
":attention",
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h" // kMaxQKVDim
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
|
|
|
|||
|
|
@ -20,7 +20,10 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/query.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@
|
|||
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h" // kMaxQKVDim
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
|
|
@ -114,8 +113,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
|||
// Find the token position in the query and calculate
|
||||
// the range of cache positions to attend to.
|
||||
constexpr size_t offset = 0; // placeholder, do not remove
|
||||
const size_t pos =
|
||||
qbatch.Pos(qi) + batch_idx + offset;
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx + offset;
|
||||
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
|
||||
// Apply rope and scaling to Q.
|
||||
if (query_norm_scale.HasPtr()) {
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
#include <cstdint>
|
||||
|
||||
#include "gemma/flash_structs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/query.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
|
|||
Loading…
Reference in New Issue