diff --git a/gemma/activations.h b/gemma/activations.h index 21e5e58..71523e4 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -179,7 +179,7 @@ struct Activations { MatStorageT x; // input MatStorageT x_bf; // output of final RMSNorm, input to EmbeddingMatmul - MatStorageT logits; + MatStorageT logits; // TODO: BF16 after Softmax supports that. MatStorageT sampled; // batch_size x 3 (padded) // Gated FFW diff --git a/gemma/gemma.h b/gemma/gemma.h index 2f06ab8..491999d 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -127,7 +127,7 @@ class QBatch { max_size_(max_size), queries_(queries), size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) { - HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`. + HWY_ASSERT(max_size_ <= kMaxBatchSize); HWY_DASSERT(size_ != 0); HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 65dc185..74deb78 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -797,9 +797,10 @@ class MMImpl { return View(A, 0, 0, A.Cols()); } else { // Always decompress. To reduce code size/compile time, we no longer - // support a separate F32 kernel; most A are already BF16. - const StridedViewBF A_view = - args.env->storage[args.options.cluster_idx].A(A.Extents()); + // support a separate F32 kernel; most A are already BF16. We also only + // have a single MMStorage. + HWY_ASSERT(args.options.cluster_idx == 0); + const StridedViewBF A_view = args.env->storage.A(A.Extents()); DecompressA(A, A_view, args); return A_view; } diff --git a/ops/matmul.cc b/ops/matmul.cc index 00330e5..66ce0df 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -346,14 +346,10 @@ std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)(); } -MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { - // Create storage per cluster. This only applies to in-cluster parallelism. - // For nested and sequential parallelism, a single MMStorage is used. +MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) { const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); per_cluster.resize(num_clusters); - storage.reserve(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { - storage.push_back(MMStorage(ctx.allocator)); row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } diff --git a/ops/matmul.h b/ops/matmul.h index 641dad9..c86ecc3 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -306,11 +306,11 @@ using StridedViewD = StridedView; class MMStorage { public: // Compile-time bounds on matrix columns to enable pre-allocating storage - // and reusing it across `MatMul` calls. - static constexpr size_t kMaxK = 64 * 1024; + // and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B. + static constexpr size_t kMaxK = 36 * 1024; MMStorage(const Allocator& allocator) - // 0.5 GiB. Must be padded, see `DoDecompressA`. + // 288 MiB. Must be padded, see `DoDecompressA`. : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd) {} @@ -673,7 +673,7 @@ struct MatMulEnv { // Whether to print the best config immediately after autotuning finished. bool print_best = false; - std::vector storage; + MMStorage storage; struct PerCluster { MMKeys keys; diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 18ee40f..0c6bd50 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -614,6 +614,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( } // See below for a specialized version for top-1 sampling. +// TODO: support bf16 logits using Decompress2. static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, const size_t worker, float temperature = 1.0f) { diff --git a/util/basics.h b/util/basics.h index 7b1c7d3..0211a0e 100644 --- a/util/basics.h +++ b/util/basics.h @@ -30,7 +30,7 @@ namespace gcpp { -// TODO: extend to 16k after updating non_eos. +// For hwy::BitSet4096. Note that KVs are extremely large for such batches. HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };