diff --git a/BUILD.bazel b/BUILD.bazel index 491939a..130f18f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -608,7 +608,9 @@ cc_library( ], deps = [ ":activations", + ":basics", ":configs", + ":kv_cache", ":mat", ":matmul", ":matmul_env", diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 2129a63..ba985fc 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -22,10 +22,14 @@ #include #include #include +#include #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "gemma/flash_structs.h" +#include "gemma/kv_cache.h" +#include "gemma/query.h" +#include "util/basics.h" #include "util/threading_context.h" #include "util/zones.h" #include "hwy/base.h" diff --git a/util/test_util.h b/util/test_util.h index f0c37f9..19342e4 100644 --- a/util/test_util.h +++ b/util/test_util.h @@ -115,7 +115,7 @@ template void PrintMatPtr(MatPtrT mat) { for (int i = 0; i < mat.Rows(); ++i) { for (int j = 0; j < mat.Cols(); ++j) { - std::cerr << mat.Row(i)[j] << " ,"; + std::cerr << hwy::ConvertScalarTo(mat.Row(i)[j]) << " ,"; } std::cerr << std::endl; }