mirror of https://github.com/google/gemma.cpp.git
Memory use reduction: smaller/single MMStorage
PiperOrigin-RevId: 804865029
This commit is contained in:
parent
06e5da1e22
commit
a5ab99e4ba
|
|
@ -179,7 +179,7 @@ struct Activations {
|
||||||
|
|
||||||
MatStorageT<float> x; // input
|
MatStorageT<float> x; // input
|
||||||
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
||||||
MatStorageT<float> logits;
|
MatStorageT<float> logits; // TODO: BF16 after Softmax supports that.
|
||||||
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
|
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
|
||||||
|
|
||||||
// Gated FFW
|
// Gated FFW
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,7 @@ class QBatch {
|
||||||
max_size_(max_size),
|
max_size_(max_size),
|
||||||
queries_(queries),
|
queries_(queries),
|
||||||
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
|
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
|
||||||
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
|
HWY_ASSERT(max_size_ <= kMaxBatchSize);
|
||||||
HWY_DASSERT(size_ != 0);
|
HWY_DASSERT(size_ != 0);
|
||||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -797,9 +797,10 @@ class MMImpl {
|
||||||
return View(A, 0, 0, A.Cols());
|
return View(A, 0, 0, A.Cols());
|
||||||
} else {
|
} else {
|
||||||
// Always decompress. To reduce code size/compile time, we no longer
|
// Always decompress. To reduce code size/compile time, we no longer
|
||||||
// support a separate F32 kernel; most A are already BF16.
|
// support a separate F32 kernel; most A are already BF16. We also only
|
||||||
const StridedViewBF A_view =
|
// have a single MMStorage.
|
||||||
args.env->storage[args.options.cluster_idx].A(A.Extents());
|
HWY_ASSERT(args.options.cluster_idx == 0);
|
||||||
|
const StridedViewBF A_view = args.env->storage.A(A.Extents());
|
||||||
DecompressA(A, A_view, args);
|
DecompressA(A, A_view, args);
|
||||||
return A_view;
|
return A_view;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -346,14 +346,10 @@ std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
|
||||||
return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
|
return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
|
||||||
}
|
}
|
||||||
|
|
||||||
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
|
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) {
|
||||||
// Create storage per cluster. This only applies to in-cluster parallelism.
|
|
||||||
// For nested and sequential parallelism, a single MMStorage is used.
|
|
||||||
const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
|
const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
|
||||||
per_cluster.resize(num_clusters);
|
per_cluster.resize(num_clusters);
|
||||||
storage.reserve(num_clusters);
|
|
||||||
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
|
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
|
||||||
storage.push_back(MMStorage(ctx.allocator));
|
|
||||||
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize)); // C
|
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize)); // C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -306,11 +306,11 @@ using StridedViewD = StridedView<double>;
|
||||||
class MMStorage {
|
class MMStorage {
|
||||||
public:
|
public:
|
||||||
// Compile-time bounds on matrix columns to enable pre-allocating storage
|
// Compile-time bounds on matrix columns to enable pre-allocating storage
|
||||||
// and reusing it across `MatMul` calls.
|
// and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
|
||||||
static constexpr size_t kMaxK = 64 * 1024;
|
static constexpr size_t kMaxK = 36 * 1024;
|
||||||
|
|
||||||
MMStorage(const Allocator& allocator)
|
MMStorage(const Allocator& allocator)
|
||||||
// 0.5 GiB. Must be padded, see `DoDecompressA`.
|
// 288 MiB. Must be padded, see `DoDecompressA`.
|
||||||
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
|
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
|
||||||
MatPadding::kOdd) {}
|
MatPadding::kOdd) {}
|
||||||
|
|
||||||
|
|
@ -673,7 +673,7 @@ struct MatMulEnv {
|
||||||
// Whether to print the best config immediately after autotuning finished.
|
// Whether to print the best config immediately after autotuning finished.
|
||||||
bool print_best = false;
|
bool print_best = false;
|
||||||
|
|
||||||
std::vector<MMStorage> storage;
|
MMStorage storage;
|
||||||
|
|
||||||
struct PerCluster {
|
struct PerCluster {
|
||||||
MMKeys keys;
|
MMKeys keys;
|
||||||
|
|
|
||||||
|
|
@ -614,6 +614,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
}
|
}
|
||||||
|
|
||||||
// See below for a specialized version for top-1 sampling.
|
// See below for a specialized version for top-1 sampling.
|
||||||
|
// TODO: support bf16 logits using Decompress2.
|
||||||
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
||||||
const size_t worker,
|
const size_t worker,
|
||||||
float temperature = 1.0f) {
|
float temperature = 1.0f) {
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// TODO: extend to 16k after updating non_eos.
|
// For hwy::BitSet4096. Note that KVs are extremely large for such batches.
|
||||||
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
|
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
|
||||||
|
|
||||||
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
|
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue