mirror of https://github.com/google/gemma.cpp.git
Memory use reduction: smaller/single MMStorage
PiperOrigin-RevId: 804865029
This commit is contained in:
parent
06e5da1e22
commit
a5ab99e4ba
|
|
@ -179,7 +179,7 @@ struct Activations {
|
|||
|
||||
MatStorageT<float> x; // input
|
||||
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
||||
MatStorageT<float> logits;
|
||||
MatStorageT<float> logits; // TODO: BF16 after Softmax supports that.
|
||||
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
|
||||
|
||||
// Gated FFW
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class QBatch {
|
|||
max_size_(max_size),
|
||||
queries_(queries),
|
||||
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
|
||||
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
|
||||
HWY_ASSERT(max_size_ <= kMaxBatchSize);
|
||||
HWY_DASSERT(size_ != 0);
|
||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -797,9 +797,10 @@ class MMImpl {
|
|||
return View(A, 0, 0, A.Cols());
|
||||
} else {
|
||||
// Always decompress. To reduce code size/compile time, we no longer
|
||||
// support a separate F32 kernel; most A are already BF16.
|
||||
const StridedViewBF A_view =
|
||||
args.env->storage[args.options.cluster_idx].A(A.Extents());
|
||||
// support a separate F32 kernel; most A are already BF16. We also only
|
||||
// have a single MMStorage.
|
||||
HWY_ASSERT(args.options.cluster_idx == 0);
|
||||
const StridedViewBF A_view = args.env->storage.A(A.Extents());
|
||||
DecompressA(A, A_view, args);
|
||||
return A_view;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -346,14 +346,10 @@ std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
|
|||
return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
|
||||
}
|
||||
|
||||
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
|
||||
// Create storage per cluster. This only applies to in-cluster parallelism.
|
||||
// For nested and sequential parallelism, a single MMStorage is used.
|
||||
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) {
|
||||
const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
|
||||
per_cluster.resize(num_clusters);
|
||||
storage.reserve(num_clusters);
|
||||
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
|
||||
storage.push_back(MMStorage(ctx.allocator));
|
||||
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize)); // C
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -306,11 +306,11 @@ using StridedViewD = StridedView<double>;
|
|||
class MMStorage {
|
||||
public:
|
||||
// Compile-time bounds on matrix columns to enable pre-allocating storage
|
||||
// and reusing it across `MatMul` calls.
|
||||
static constexpr size_t kMaxK = 64 * 1024;
|
||||
// and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
|
||||
static constexpr size_t kMaxK = 36 * 1024;
|
||||
|
||||
MMStorage(const Allocator& allocator)
|
||||
// 0.5 GiB. Must be padded, see `DoDecompressA`.
|
||||
// 288 MiB. Must be padded, see `DoDecompressA`.
|
||||
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
|
||||
MatPadding::kOdd) {}
|
||||
|
||||
|
|
@ -673,7 +673,7 @@ struct MatMulEnv {
|
|||
// Whether to print the best config immediately after autotuning finished.
|
||||
bool print_best = false;
|
||||
|
||||
std::vector<MMStorage> storage;
|
||||
MMStorage storage;
|
||||
|
||||
struct PerCluster {
|
||||
MMKeys keys;
|
||||
|
|
|
|||
|
|
@ -614,6 +614,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
|||
}
|
||||
|
||||
// See below for a specialized version for top-1 sampling.
|
||||
// TODO: support bf16 logits using Decompress2.
|
||||
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
||||
const size_t worker,
|
||||
float temperature = 1.0f) {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// TODO: extend to 16k after updating non_eos.
|
||||
// For hwy::BitSet4096. Note that KVs are extremely large for such batches.
|
||||
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
|
||||
|
||||
// Three-state value backed by a 32-bit signed integer: in addition to the two
// boolean states there is a distinct "default" state.
// NOTE(review): kDefault presumably means "not explicitly set; let the
// consumer decide" - confirm against the call sites.
enum class Tristate : int32_t {
  kFalse = 0,     // Same numeric value as `false`.
  kTrue = 1,      // Same numeric value as `true`.
  kDefault = -1,  // Negative, so it cannot collide with either boolean value.
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue