Fix RowT issue and improve Griffin (currently still broken)

Use type-safe MatPtrT via dynamic_cast, avoid/remove unsafe RowT
activations: Griffin tensors are now padded
Griffin: add batching support, fix conv1d_cache allocation
weights: bundle to TensorToRead, add kNoPad flag, fix SplitW1
const-correct fix for ForEachTensor
blob_store: move BlobIO2 to .cc and rename BlobIO
PiperOrigin-RevId: 760610094
This commit is contained in:
Jan Wassenberg 2025-05-19 07:01:29 -07:00 committed by Copybara-Service
parent d6cfabc2c1
commit cb188d4a0e
13 changed files with 218 additions and 211 deletions

View File

@ -27,6 +27,8 @@ namespace gcpp {
namespace {
using MatPtrF = MatPtrT<float>;
// Split into two classes so that ForEachTensor only requires two "other"
// arguments. This is anyway useful for locality, because `grad` only feeds
// into `grad_m` and `grad_v` here.
@ -40,12 +42,11 @@ class AdamUpdateMV {
norm1_(1.0 / (1.0 - std::pow(beta1, t))),
norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}
void operator()(const MatPtr& grad, const MatPtr& grad_m,
const MatPtr& grad_v) {
void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) {
for (size_t r = 0; r < grad.Rows(); ++r) {
const float* HWY_RESTRICT g = grad.RowT<float>(r);
float* HWY_RESTRICT m = grad_m.MutableRowT<float>(r);
float* HWY_RESTRICT v = grad_v.MutableRowT<float>(r);
const float* HWY_RESTRICT g = grad.Row(r);
float* HWY_RESTRICT m = grad_m.Row(r);
float* HWY_RESTRICT v = grad_v.Row(r);
for (size_t c = 0; c < grad.Cols(); ++c) {
m[c] *= beta1_;
m[c] += cbeta1_ * g[c];
@ -73,11 +74,12 @@ class AdamUpdateW {
norm2_(1.0 / (1.0 - std::pow(beta2, t))),
epsilon_(epsilon) {}
void operator()(MatPtr& weights, const MatPtr& grad_m, const MatPtr& grad_v) {
void operator()(MatPtrF& weights, const MatPtrF& grad_m,
const MatPtrF& grad_v) {
for (size_t r = 0; r < weights.Rows(); ++r) {
float* HWY_RESTRICT w = weights.RowT<float>(r);
const float* HWY_RESTRICT m = grad_m.RowT<float>(r);
const float* HWY_RESTRICT v = grad_v.RowT<float>(r);
float* HWY_RESTRICT w = weights.Row(r);
const float* HWY_RESTRICT m = grad_m.Row(r);
const float* HWY_RESTRICT v = grad_v.Row(r);
for (size_t c = 0; c < weights.Cols(); ++c) {
const float mhat = m[c] * norm1_;
const float vhat = v[c] * norm2_;
@ -100,12 +102,18 @@ void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
AdamUpdateMV update_mv(beta1, beta2, t);
grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
update_mv(t.mat, *t.other_mat1, *t.other_mat2);
const MatPtrF grad_f(t.mat);
MatPtrF grad_m_f(*t.other_mat1);
MatPtrF grad_v_f(*t.other_mat2);
update_mv(grad_f, grad_m_f, grad_v_f);
});
AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
update_w(t.mat, *t.other_mat1, *t.other_mat2);
MatPtrF weights_f(t.mat);
const MatPtrF grad_m_f(*t.other_mat1);
const MatPtrF grad_v_f(*t.other_mat2);
update_w(weights_f, grad_m_f, grad_v_f);
});
}

View File

@ -33,8 +33,7 @@ struct Activations {
layer_config(config.layer_configs[0]),
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()),
is_griffin(layer_config.type ==
LayerAttentionType::kGriffinRecurrentBlock),
is_griffin(config.model == Model::GRIFFIN_2B),
x("x", Extents2D(batch_size, config.model_dim), pad_),
q("q",
@ -58,21 +57,18 @@ struct Activations {
C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
// No padding for Griffin because it does not always use Row().
griffin_x("griffin_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
pad_),
griffin_y("griffin_y",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
pad_),
griffin_gate_x(
"griffin_gate_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
griffin_multiplier(
"griffin_mul",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
inv_timescale(CreateInvTimescale(
ThreadingContext::Get().allocator, layer_config.qkv_dim,

View File

@ -21,7 +21,6 @@
#include <stdio.h>
#include <algorithm> // std::min
#include <memory> // std::make_unique
#include <vector>
#include "gemma/activations.h"
@ -69,30 +68,44 @@ namespace HWY_NAMESPACE {
// Different functions use different naming conventions for the number of
// tokens. Functions that are query-independent, such as RMSNorm*, call the
// count `num_interleaved`. Functions that are query-dependent, such as
// `Attention`, use separate `num_tokens` and `num_queries`.
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
// TODO: add batch query support for Griffin (QueriesPos).
template <typename T>
HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
size_t layer, Activations& activations,
HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
size_t num_tokens, size_t griffin_layer,
Activations& activations,
const LayerWeightsPtrs<T>* layer_weights,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin");
KVCache& kv_cache = kv_caches[0];
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D df;
const size_t model_dim = layer_weights->layer_config.model_dim;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t heads = layer_weights->layer_config.heads;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
const size_t num_queries = queries_pos.size();
const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
const size_t num_interleaved = num_tokens * num_queries;
// X / Y linear layers.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
// TODO: MatMul
HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = activations.griffin_y.Row(r);
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
activations.pre_att_rms_out.Row(batch_idx),
activations.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
@ -100,18 +113,19 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
}
// Conv1D.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
HWY_FULL(float) df;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] =
kv_cache.conv1d_cache.Row(layer) +
kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@ -119,7 +133,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
auto accum1 = hn::Zero(df);
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
auto wv0 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
@ -136,17 +149,20 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
}
// RGLRU
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer);
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
float* HWY_RESTRICT rnn_state =
kv_caches[query_idx].rglru_cache.Row(griffin_layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop(
layer_weights->griffin.gate_w, kMatrixSize * head,
@ -166,7 +182,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
@ -186,17 +201,18 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
hn::Store(pre_out, df, x + head_offset + i);
}
});
}
} // interleaved_idx
// Final linear layer.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
float* out_ptr = activations.att_sums.Row(batch_idx);
// TODO: MatMul
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
float* out_ptr = activations.att_sums.Row(r);
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
pool);
}
}
} // GriffinRecurrent
// Wrapper class; holds arguments in member variables to shorten call sites.
template <typename T>
@ -219,11 +235,7 @@ class GemmaAttention {
activations_.weights_config.attention_window_sizes[layer] ==
activations_.seq_len;
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer &&
(activations_.weights_config.model == Model::GEMMA3_4B ||
activations_.weights_config.model == Model::GEMMA3_12B ||
activations_.weights_config.model == Model::GEMMA3_27B ||
activations_.weights_config.model == Model::GEMMA3_1B)) {
if (is_global_layer && IsVLM(activations_.weights_config.model)) {
inv_timescale = activations_.inv_timescale_global.Packed();
}
// PostQKType::Rope
@ -454,7 +466,6 @@ class GemmaAttention {
SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
pos, start_pos, last_pos);
});
}
@ -587,14 +598,12 @@ HWY_NOINLINE void Attention(
GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
activations, layer_weights, div_seq_len, kv_caches)();
} else {
// Only reached if the model is Griffin.
// The kv_caches are allocated only for the griffin layers, so we need to
// map the layer index to the griffin layer index.
auto type = layer_weights->layer_config.type;
size_t layer_of_type =
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
// so map `layer` to the Griffin layer index.
const size_t griffin_layer =
activations.weights_config.NumLayersOfTypeBefore(type, layer);
HWY_ASSERT(queries_pos.size() == 1);
GriffinRecurrent(queries_pos[0], num_tokens, layer_of_type, activations,
GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
layer_weights, kv_caches);
}
}
@ -1056,7 +1065,6 @@ HWY_NOINLINE void Prefill(
// threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul.
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
HWY_DASSERT(max_tbatch_size <= activations.x.Rows());
// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) {
@ -1397,6 +1405,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const KVCaches& kv_caches,
TimingInfo& timing_info) {
HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t i = 0; i < kv_caches.size(); ++i) {
if (queries_pos_in[i] == 0) {
@ -1510,17 +1519,10 @@ void GenerateBatchT(const ModelStore& model,
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() >= num_queries);
// Griffin does not support query batching.
size_t max_qbatch_size = runtime_config.decode_qbatch_size;
for (const LayerConfig& layer_config : model.Config().layer_configs) {
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
max_qbatch_size = 1;
break;
}
}
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(model.Config(), max_batch_size, env);
for (size_t qbatch_start = 0; qbatch_start < num_queries;
@ -1528,7 +1530,6 @@ void GenerateBatchT(const ModelStore& model,
// Generate one batch of tokens from `qbatch_size` queries.
const size_t qbatch_size =
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
activations.SetBatchSize(qbatch_size);
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);

View File

@ -25,8 +25,9 @@
namespace gcpp {
void KVCache::ZeroGriffinCache() {
if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
if (griffin_layers == 0) return;
ZeroInit(conv1d_cache);
ZeroInit(rglru_cache);
}
static size_t GriffinConv1dCols(const ModelConfig& config) {
@ -34,7 +35,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
for (const auto& layer_config : config.layer_configs) {
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
}
return conv1d_width == 0 ? 0 : conv1d_width - 1;
// The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
// hence allocate conv1d_width * model_dim total columns.
return conv1d_width * config.model_dim;
}
// prefill_tbatch_size is the maximum number of tokens from one query to
@ -42,12 +45,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
: griffin_layers(
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
griffin_conv1d_cols(GriffinConv1dCols(config)),
// TODO(patrickms): Add query batching support for Griffin.
conv1d_cache(
"conv1d_cache",
Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
MatPadding::kOdd),
conv1d_cache("conv1d_cache",
Extents2D(griffin_layers, GriffinConv1dCols(config)),
MatPadding::kOdd),
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
MatPadding::kOdd) {
// TODO: move to MatStorageT.

View File

@ -31,7 +31,6 @@ struct KVCache {
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
size_t griffin_layers = 0;
size_t griffin_conv1d_cols = 0;
// griffin_layers, griffin_conv1d_cols * config.model_dim
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim

View File

@ -97,21 +97,27 @@ void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
SplitW1NUQ(layer_config);
}
struct TensorToRead {
MatPtr* mat;
BlobRange range;
// Some tensors opt out of padding via kNoPad flags.
MatPadding padding;
};
// Allocates multiple in parallel and binds to NUMA nodes.
static void AllocateAndBindAll(const std::vector<MatPtr*>& mats,
MatPadding padding,
static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
std::vector<MatOwner>& owners,
hwy::ThreadPool& pool) {
const size_t start = owners.size();
owners.resize(start + mats.size());
owners.resize(start + tensors.size());
MMParallel parallel(ThreadingContext::Get());
// Allocate in parallel because faulting in large tensors is slow.
pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
owners[start + task].AllocateFor(*mats[task], padding);
pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
// TODO(janwas): MatMul outputs will later also be BF16.
BindB(*mats[task], sizeof(float), parallel);
BindB(*tensors[task].mat, sizeof(float), parallel);
});
}
@ -155,55 +161,54 @@ MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
return MapPtr();
}
static void MapAll(const std::vector<MatPtr*>& mats,
const std::vector<BlobRange>& ranges, const MapPtr& mapped) {
static void MapAll(const std::vector<TensorToRead>& tensors,
const MapPtr& mapped) {
PROFILER_ZONE("Startup.Weights.Map");
for (size_t i = 0; i < mats.size(); ++i) {
for (size_t i = 0; i < tensors.size(); ++i) {
// SetPtr does not change the stride, but it is expected to be packed
// because that is what Compress() writes to the file.
const size_t mat_bytes = mats[i]->PackedBytes();
const size_t mat_bytes = tensors[i].mat->PackedBytes();
// Ensure blob size matches that computed from metadata.
HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
HWY_ASSERT_M(mat_bytes == tensors[i].range.bytes, tensors[i].mat->Name());
mats[i]->SetPtr(const_cast<uint8_t*>(mapped.get() + ranges[i].offset),
mats[i]->Stride());
tensors[i].mat->SetPtr(
const_cast<uint8_t*>(mapped.get() + tensors[i].range.offset),
tensors[i].mat->Stride());
}
}
std::vector<IOBatch> MakeBatches(const std::vector<BlobRange>& ranges,
const std::vector<MatPtr*>& mats,
std::vector<IOBatch> MakeBatches(const std::vector<TensorToRead>& tensors,
const uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.MakeBatches");
// Batches must be contiguous but blobs are padded, hence at least one
// batch per tensor, and more when tensor rows exceed the batch size.
std::vector<IOBatch> batches;
batches.reserve(mats.size());
batches.reserve(tensors.size());
for (size_t i = 0; i < mats.size(); ++i) {
uint64_t offset = ranges[i].offset;
HWY_ASSERT(ranges[i].End() <= file_bytes);
for (size_t i = 0; i < tensors.size(); ++i) {
const BlobRange& range = tensors[i].range;
MatPtr& mat = *tensors[i].mat;
uint64_t offset = range.offset;
HWY_ASSERT(range.End() <= file_bytes);
batches.emplace_back(offset, ranges[i].key_idx);
const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes();
// Caution, `RowT` requires knowledge of the actual type. We instead use
// the first row, which is the same for any type, and advance the *byte*
// pointer by the *byte* stride.
const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
uint8_t* row = mats[i]->RowT<uint8_t>(0);
for (size_t r = 0; r < mats[i]->Rows(); ++r) {
if (!batches.back().Add(row, file_bytes_per_row)) { // Full batch.
batches.emplace_back(offset, ranges[i].key_idx);
batches.emplace_back(offset, range.key_idx);
const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
uint8_t* row_bytes = mat.RowBytes(0);
for (size_t r = 0; r < mat.Rows(); ++r) {
if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch.
batches.emplace_back(offset, range.key_idx);
// Adding to an empty batch is always successful.
HWY_ASSERT(batches.back().Add(row, file_bytes_per_row));
HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
}
offset += file_bytes_per_row;
row += mem_stride_bytes;
row_bytes += mem_stride_bytes;
// Keep the in-memory row padding uninitialized so msan detects any use.
}
HWY_ASSERT(offset == ranges[i].End());
HWY_ASSERT(offset == range.End());
}
HWY_ASSERT(batches.size() >= mats.size());
HWY_ASSERT(batches.size() >= tensors.size());
return batches;
}
@ -228,16 +233,14 @@ static void ReadBatches(const BlobReader& reader,
}
// Aborts on error.
static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges, Tristate map,
static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
BlobReader& reader, Tristate map,
std::vector<MatOwner>& mat_owners,
const MatPadding padding, hwy::ThreadPool& pool) {
HWY_ASSERT(mats.size() == ranges.size());
hwy::ThreadPool& pool) {
if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
if (mapped) {
MapAll(mats, ranges, mapped);
MapAll(tensors, mapped);
return;
}
} // otherwise fall through to read mode
@ -245,32 +248,32 @@ static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
{
PROFILER_ZONE("Startup.Weights.Allocate");
// NOTE: this changes the stride of `mats`!
AllocateAndBindAll(mats, padding, mat_owners, pool);
AllocateAndBindAll(tensors, mat_owners, pool);
}
const std::vector<IOBatch> batches =
MakeBatches(ranges, mats, reader.file_bytes());
MakeBatches(tensors, reader.file_bytes());
ReadBatches(reader, batches, pool);
}
void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
Tristate map, hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from.
std::vector<MatPtr*> mats;
std::vector<BlobRange> ranges;
// Padding is inserted when reading row by row, except for NUQ tensors.
const MatPadding padding = MatPadding::kOdd;
std::vector<TensorToRead> tensors;
AllocatePointer(model.Config());
// Enumerate all weights (negligible cost).
CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
const MatPadding padding = (t.flags & TensorArgs::kNoPad)
? MatPadding::kPacked
: MatPadding::kOdd;
size_t key_idx;
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
mats.push_back(&t.mat);
ranges.push_back(reader.Range(key_idx));
tensors.push_back({.mat = &t.mat,
.range = reader.Range(key_idx),
.padding = padding});
return;
}
if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found.
@ -278,7 +281,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
});
});
MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool);
MapOrReadAll(tensors, reader, map, mat_owners_, pool);
Fixup(pool);
}
@ -338,9 +341,9 @@ void WeightsOwner::LogWeightStatsF32() {
printf("[scale=%f] ", t.mat.Scale());
}
hwy::Stats stats;
HWY_ASSERT(t.mat.GetType() == Type::kF32);
const MatPtrT<float> mat_f(t.mat);
for (size_t r = 0; r < t.mat.Rows(); ++r) {
const float* HWY_RESTRICT row = t.mat.RowT<float>(r);
const float* HWY_RESTRICT row = mat_f.Row(r);
for (size_t c = 0; c < t.mat.Cols(); ++c) {
stats.Notify(row[c]);
}

View File

@ -43,24 +43,27 @@ struct TensorArgs {
// name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
// (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
// `flags` is a combination of zero or more `Flags`.
TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
int flags)
TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
: mat(mat),
other_mat1(other_mat1),
other_mat2(other_mat2),
flags(flags) {}
MatPtr& mat;
const MatPtr* other_mat1; // either/both can be nullptr.
const MatPtr* other_mat2;
MatPtr* other_mat1; // either/both can be nullptr.
MatPtr* other_mat2;
// TODO: freestanding enum class instead? These are mutually exclusive.
enum Flags {
// Read the tensor from the file and abort if it is not found.
// Default: Read the tensor from the file and abort if it is not found.
kMustRead = 0,
// Not an error if the tensor is not present in the file. For example,
// the _w1/_w2 tensors are not always present.
kMaybeRead = 1,
// Avoid padding tensor rows when reading. Used for some Griffin tensors
// whose index computations do not use Row() accessors.
kNoPad = 2,
};
const int flags;
};
@ -214,8 +217,8 @@ struct LayerWeightsPtrs {
// can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
// Public because also called by `WeightsPtrs`.
template <class Func>
void ForEachTensor(const LayerWeightsPtrs<Weight>* other1,
const LayerWeightsPtrs<Weight>* other2, Func func) {
void ForEachTensor(LayerWeightsPtrs<Weight>* other1,
LayerWeightsPtrs<Weight>* other2, Func func) {
if (layer_config.type == LayerAttentionType::kVit) {
// MHA.
func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
@ -248,9 +251,9 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
func(TENSOR_ARGS(griffin.conv_w, kMustRead));
func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
func(TENSOR_ARGS(griffin.gate_w, kMustRead));
func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
func(TENSOR_ARGS(griffin.a, kMustRead));
}
@ -363,8 +366,8 @@ struct LayerWeightsPtrs {
// For FFN. Fast, only updates pointers.
void SplitW1() {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
if (layer_config.type == LayerAttentionType::kVit) return;
// Files have both or neither of w1 and w2, and backprop/ allocates both.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
@ -514,10 +517,10 @@ struct ModelWeightsPtrs {
// used to copy from another set of weights. Public because called by tests
// and `WeightsOwner`.
template <class Func>
void ForEachTensor(const ModelWeightsPtrs<Weight>* other1,
const ModelWeightsPtrs<Weight>* other2, Func func) {
const LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
const LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
void ForEachTensor(ModelWeightsPtrs<Weight>* other1,
ModelWeightsPtrs<Weight>* other2, Func func) {
LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
func(TENSOR_ARGS(final_norm_scale, kMustRead));
@ -569,11 +572,12 @@ struct ModelWeightsPtrs {
// Copies only the allocated tensors in `*this` from tensors in `other`.
void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
ForEachTensor(&other, nullptr, [](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
CopyMat(*t.other_mat1, t.mat);
});
ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
[](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
CopyMat(*t.other_mat1, t.mat);
});
}
// Instead of reading, only allocates memory for all tensors. Used by

View File

@ -84,6 +84,14 @@ struct Header { // standard layout class
static_assert(sizeof(Header) == 16);
} // namespace
// A write I/O request, each serviced by one thread in a pool.
struct BlobIO {
BlobIO(BlobRange range, void* data) : range(range), data(data) {}
BlobRange range;
void* data; // Read-only for writes.
};
// Little-endian on-disk representation: a fixed-size `Header`, then a padded
// variable-length 'directory' of blob keys and their offset/sizes, then the
// 'payload' of each blob's data with padding in between, followed by padding to
@ -238,11 +246,11 @@ class BlobStore {
return true; // all OK
}
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const {
const size_t key_idx = 0; // not actually associated with a key/blob
writes.emplace_back(
BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
// members are const and BlobIO2 requires non-const pointers, and they
// members are const and BlobIO requires non-const pointers, and they
// are not modified by file writes.
const_cast<Header*>(&header_));
writes.emplace_back(
@ -314,7 +322,7 @@ BlobReader::BlobReader(const Path& blob_path)
// Split into chunks for load-balancing even if blob sizes vary.
static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
uint8_t* data, std::vector<BlobIO2>& writes) {
uint8_t* data, std::vector<BlobIO>& writes) {
constexpr size_t kChunkBytes = 4 * 1024 * 1024;
const uint64_t end = offset + bytes;
// Split into whole chunks and possibly one remainder.
@ -336,7 +344,7 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
static void EnqueueWritesForBlobs(const BlobStore& bs,
const hwy::Span<const uint8_t> blobs[],
std::vector<uint8_t>& zeros,
std::vector<BlobIO2>& writes) {
std::vector<BlobIO>& writes) {
// All-zero buffer used to write padding to the file without copying the
// input blobs.
static constexpr uint8_t kZeros[kBlobAlign] = {0};
@ -388,7 +396,7 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
HWY_ASSERT(num_blobs != 0);
HWY_ASSERT(num_blobs == blobs_.size());
std::vector<BlobIO2> writes;
std::vector<BlobIO> writes;
writes.reserve(16384);
const BlobStore bs(num_blobs, keys_.data(), blobs_.data());

View File

@ -44,16 +44,6 @@ struct BlobRange {
size_t key_idx;
};
// A read or write I/O request, each serviced by one thread in a pool.
struct BlobIO2 {
BlobIO2(BlobRange range, void* data) : range(range), data(data) {}
BlobRange range;
void* data; // Modified only if a read request. Read-only for writes.
};
class BlobStore;
// Reads `BlobStore` header, converts keys to strings and creates a hash map for
// faster lookups.
// TODO(janwas): rename to BlobFinder or similar.

View File

@ -425,7 +425,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
}
void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
Allocator& allocator = parallel.allocator();
if (!allocator.ShouldBind()) return;
if (B.Rows() == 1) return;
@ -437,8 +437,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx);
uintptr_t begin =
reinterpret_cast<uintptr_t>(B.RowT<uint8_t>(rows_b.begin()));
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
// B row padding is less than the page size, so only bind the subset that
// is page-aligned.
@ -451,7 +450,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
}
// C is BF16/float, or double for partial
void BindC(const MatPtr& C, MMParallel& parallel) {
void BindC(MatPtr& C, MMParallel& parallel) {
Allocator& allocator = parallel.allocator();
if (!allocator.ShouldBind()) return;
@ -470,8 +469,7 @@ void BindC(const MatPtr& C, MMParallel& parallel) {
const size_t node = parallel.Node(pkg_idx);
for (size_t im = 0; im < C.Rows(); ++im) {
ok &= allocator.BindMemory(C.MutableRowT<uint8_t>(im) + begin,
end - begin, node);
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
}
}
if (HWY_UNLIKELY(!ok)) {

View File

@ -172,9 +172,9 @@ class MMParallel {
ThreadingContext& ctx_;
};
void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial.
void BindC(const MatPtr& C, MMParallel& parallel);
void BindC(MatPtr& C, MMParallel& parallel);
// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).

View File

@ -41,8 +41,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
}
const size_t row_bytes = to.Cols() * to.ElementBytes();
for (size_t r = 0; r < to.Rows(); ++r) {
const uint8_t* from_row = from.RowT<uint8_t>(r);
uint8_t* to_row = to.RowT<uint8_t>(r);
const uint8_t* from_row = from.RowBytes(r);
uint8_t* to_row = to.RowBytes(r);
hwy::CopyBytes(from_row, to_row, row_bytes);
}
}
@ -58,7 +58,7 @@ void ZeroInit(MatPtr& mat) {
}
const size_t row_bytes = mat.Cols() * mat.ElementBytes();
for (size_t r = 0; r < mat.Rows(); ++r) {
hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes);
hwy::ZeroBytes(mat.RowBytes(r), row_bytes);
}
}
@ -71,15 +71,19 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
std::normal_distribution<float> dist(0.0, stddev);
if (mat.GetType() == Type::kF32) {
MatPtrT<float> mat_f(mat);
for (size_t r = 0; r < mat.Rows(); ++r) {
float* HWY_RESTRICT row = mat.RowT<float>(r);
float* HWY_RESTRICT row = mat_f.Row(r);
for (size_t c = 0; c < mat.Cols(); ++c) {
row[c] = dist(gen);
}
}
} else {
MatPtrT<double> mat_d(mat);
for (size_t r = 0; r < mat.Rows(); ++r) {
double* HWY_RESTRICT row = mat.RowT<double>(r);
double* HWY_RESTRICT row = mat_d.Row(r);
for (size_t c = 0; c < mat.Cols(); ++c) {
row[c] = dist(gen);
}

View File

@ -91,21 +91,14 @@ class MatPtr : public IFields {
return num_elements_ * element_bytes_;
}
// Works for any kind of padding.
template <typename T>
T* MutableRowT(size_t row) const {
// Works for any kind of padding and element type.
uint8_t* RowBytes(size_t row) {
HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
return static_cast<uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
}
template <typename T>
T* RowT(size_t row) {
const uint8_t* RowBytes(size_t row) const {
HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
}
template <typename T>
const T* RowT(size_t row) const {
HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
return static_cast<const uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
}
Type GetType() const { return type_; }
@ -209,9 +202,8 @@ class MatPtr : public IFields {
float scale_ = 1.0f; // multiplier for each value, for MatMul.
};
// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
// type-aware accessors (`RowT`), this class is more convenient when accessing
// elements, and ensures the template argument and `Type` are consistent.
// Non-type erased version of `MatPtr`: provides type-safe `Row()` and ensures
// the template argument and `Type` are consistent.
template <typename MatT>
class MatPtrT : public MatPtr {
public:
@ -227,7 +219,9 @@ class MatPtrT : public MatPtr {
: MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}
// Copying allowed because the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {}
MatPtrT(const MatPtr& other) : MatPtr(other) {
HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
}
MatPtrT& operator=(const MatPtr& other) {
MatPtr::operator=(other);
return *this;
@ -252,22 +246,24 @@ class MatPtrT : public MatPtr {
return Packed();
}
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
MatT* Row(size_t row) { return this->RowT<MatT>(row); }
MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
const MatT* Row(size_t row) const {
return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
}
PackedSpan<const MatT> PaddedSpan() const {
return MakeConstSpan(Row(0), Rows() * Stride());
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
}
// For `compress-inl.h` functions, which assume contiguous streams and thus
// require packed layout.
PackedSpan<const MatT> Span() const {
HWY_ASSERT(IsPacked());
return MakeConstSpan(Row(0), num_elements_);
}
PackedSpan<MatT> Span() {
HWY_ASSERT(IsPacked());
return MakeSpan(Row(0), num_elements_);
return MakeSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
}
PackedSpan<const MatT> Span() const {
HWY_ASSERT(IsPacked());
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
}
};