mirror of https://github.com/google/gemma.cpp.git
Fix RowT issue and improve Griffin (currently still broken)
Use type-safe MatPtrT via dynamic_cast, avoid/remove unsafe RowT
activations: Griffin tensors are now padded
Griffin: add batching support, fix conv1d_cache allocation
weights: bundle to TensorToRead, add kNoPad flag, fix SplitW1
const-correct fix for ForEachTensor
blob_store: move BlobIO2 to .cc and rename BlobIO

PiperOrigin-RevId: 760610094

This commit is contained in:
  parent d6cfabc2c1
  commit cb188d4a0e
@ -27,6 +27,8 @@ namespace gcpp {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using MatPtrF = MatPtrT<float>;
|
||||||
|
|
||||||
// Split into two classes so that ForEachTensor only requires two "other"
|
// Split into two classes so that ForEachTensor only requires two "other"
|
||||||
// arguments. This is anyway useful for locality, because `grad` only feeds
|
// arguments. This is anyway useful for locality, because `grad` only feeds
|
||||||
// into `grad_m` and `grad_v` here.
|
// into `grad_m` and `grad_v` here.
|
||||||
|
|
@ -40,12 +42,11 @@ class AdamUpdateMV {
|
||||||
norm1_(1.0 / (1.0 - std::pow(beta1, t))),
|
norm1_(1.0 / (1.0 - std::pow(beta1, t))),
|
||||||
norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}
|
norm2_(1.0 / (1.0 - std::pow(beta2, t))) {}
|
||||||
|
|
||||||
void operator()(const MatPtr& grad, const MatPtr& grad_m,
|
void operator()(const MatPtrF& grad, MatPtrF& grad_m, MatPtrF& grad_v) {
|
||||||
const MatPtr& grad_v) {
|
|
||||||
for (size_t r = 0; r < grad.Rows(); ++r) {
|
for (size_t r = 0; r < grad.Rows(); ++r) {
|
||||||
const float* HWY_RESTRICT g = grad.RowT<float>(r);
|
const float* HWY_RESTRICT g = grad.Row(r);
|
||||||
float* HWY_RESTRICT m = grad_m.MutableRowT<float>(r);
|
float* HWY_RESTRICT m = grad_m.Row(r);
|
||||||
float* HWY_RESTRICT v = grad_v.MutableRowT<float>(r);
|
float* HWY_RESTRICT v = grad_v.Row(r);
|
||||||
for (size_t c = 0; c < grad.Cols(); ++c) {
|
for (size_t c = 0; c < grad.Cols(); ++c) {
|
||||||
m[c] *= beta1_;
|
m[c] *= beta1_;
|
||||||
m[c] += cbeta1_ * g[c];
|
m[c] += cbeta1_ * g[c];
|
||||||
|
|
@ -73,11 +74,12 @@ class AdamUpdateW {
|
||||||
norm2_(1.0 / (1.0 - std::pow(beta2, t))),
|
norm2_(1.0 / (1.0 - std::pow(beta2, t))),
|
||||||
epsilon_(epsilon) {}
|
epsilon_(epsilon) {}
|
||||||
|
|
||||||
void operator()(MatPtr& weights, const MatPtr& grad_m, const MatPtr& grad_v) {
|
void operator()(MatPtrF& weights, const MatPtrF& grad_m,
|
||||||
|
const MatPtrF& grad_v) {
|
||||||
for (size_t r = 0; r < weights.Rows(); ++r) {
|
for (size_t r = 0; r < weights.Rows(); ++r) {
|
||||||
float* HWY_RESTRICT w = weights.RowT<float>(r);
|
float* HWY_RESTRICT w = weights.Row(r);
|
||||||
const float* HWY_RESTRICT m = grad_m.RowT<float>(r);
|
const float* HWY_RESTRICT m = grad_m.Row(r);
|
||||||
const float* HWY_RESTRICT v = grad_v.RowT<float>(r);
|
const float* HWY_RESTRICT v = grad_v.Row(r);
|
||||||
for (size_t c = 0; c < weights.Cols(); ++c) {
|
for (size_t c = 0; c < weights.Cols(); ++c) {
|
||||||
const float mhat = m[c] * norm1_;
|
const float mhat = m[c] * norm1_;
|
||||||
const float vhat = v[c] * norm2_;
|
const float vhat = v[c] * norm2_;
|
||||||
|
|
@ -100,12 +102,18 @@ void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
|
||||||
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
|
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
|
||||||
AdamUpdateMV update_mv(beta1, beta2, t);
|
AdamUpdateMV update_mv(beta1, beta2, t);
|
||||||
grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
|
grad->ForEachTensor(grad_m, grad_v, [&update_mv](const TensorArgs& t) {
|
||||||
update_mv(t.mat, *t.other_mat1, *t.other_mat2);
|
const MatPtrF grad_f(t.mat);
|
||||||
|
MatPtrF grad_m_f(*t.other_mat1);
|
||||||
|
MatPtrF grad_v_f(*t.other_mat2);
|
||||||
|
update_mv(grad_f, grad_m_f, grad_v_f);
|
||||||
});
|
});
|
||||||
|
|
||||||
AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
|
AdamUpdateW update_w(alpha, beta1, beta2, epsilon, t);
|
||||||
weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
|
weights->ForEachTensor(grad_m, grad_v, [&update_w](const TensorArgs& t) {
|
||||||
update_w(t.mat, *t.other_mat1, *t.other_mat2);
|
MatPtrF weights_f(t.mat);
|
||||||
|
const MatPtrF grad_m_f(*t.other_mat1);
|
||||||
|
const MatPtrF grad_v_f(*t.other_mat2);
|
||||||
|
update_w(weights_f, grad_m_f, grad_v_f);
|
||||||
});
|
});
|
||||||
}
|
}
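Note: for readers unfamiliar with Adam, the two functors above together implement the standard per-element update. A self-contained sketch follows; names are illustrative, and the final weight-update line is the usual Adam rule, assumed for the part of the hunk not shown here.

#include <cmath>
#include <cstddef>

// Sketch only: one row of the step that AdamUpdateMV (moments) and
// AdamUpdateW (weights) perform together. norm1 = 1/(1 - beta1^t) and
// norm2 = 1/(1 - beta2^t), matching norm1_/norm2_ above.
void AdamStepRow(const float* g, float* m, float* v, float* w, size_t cols,
                 float alpha, float beta1, float beta2, float epsilon,
                 float norm1, float norm2) {
  for (size_t c = 0; c < cols; ++c) {
    m[c] = beta1 * m[c] + (1.0f - beta1) * g[c];         // first moment
    v[c] = beta2 * v[c] + (1.0f - beta2) * g[c] * g[c];  // second moment
    const float mhat = m[c] * norm1;  // bias-corrected estimates
    const float vhat = v[c] * norm2;
    w[c] -= alpha * mhat / (std::sqrt(vhat) + epsilon);  // assumed update rule
  }
}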
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,7 @@ struct Activations {
|
||||||
layer_config(config.layer_configs[0]),
|
layer_config(config.layer_configs[0]),
|
||||||
seq_len(config.seq_len),
|
seq_len(config.seq_len),
|
||||||
cache_pos_size(config.CachePosSize()),
|
cache_pos_size(config.CachePosSize()),
|
||||||
is_griffin(layer_config.type ==
|
is_griffin(config.model == Model::GRIFFIN_2B),
|
||||||
LayerAttentionType::kGriffinRecurrentBlock),
|
|
||||||
|
|
||||||
x("x", Extents2D(batch_size, config.model_dim), pad_),
|
x("x", Extents2D(batch_size, config.model_dim), pad_),
|
||||||
q("q",
|
q("q",
|
||||||
|
|
@ -58,21 +57,18 @@ struct Activations {
|
||||||
C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
|
C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
|
||||||
ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
|
ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
|
||||||
|
|
||||||
// No padding for Griffin because it does not always use Row().
|
|
||||||
griffin_x("griffin_x",
|
griffin_x("griffin_x",
|
||||||
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
|
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
|
||||||
MatPadding::kPacked),
|
pad_),
|
||||||
griffin_y("griffin_y",
|
griffin_y("griffin_y",
|
||||||
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
|
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
|
||||||
MatPadding::kPacked),
|
pad_),
|
||||||
griffin_gate_x(
|
griffin_gate_x(
|
||||||
"griffin_gate_x",
|
"griffin_gate_x",
|
||||||
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
|
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
|
||||||
MatPadding::kPacked),
|
|
||||||
griffin_multiplier(
|
griffin_multiplier(
|
||||||
"griffin_mul",
|
"griffin_mul",
|
||||||
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
|
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
|
||||||
MatPadding::kPacked),
|
|
||||||
|
|
||||||
inv_timescale(CreateInvTimescale(
|
inv_timescale(CreateInvTimescale(
|
||||||
ThreadingContext::Get().allocator, layer_config.qkv_dim,
|
ThreadingContext::Get().allocator, layer_config.qkv_dim,
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,6 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <algorithm> // std::min
|
#include <algorithm> // std::min
|
||||||
#include <memory> // std::make_unique
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
|
|
@ -69,30 +68,44 @@ namespace HWY_NAMESPACE {
|
||||||
// Different functions use different naming conventions for the number of
|
// Different functions use different naming conventions for the number of
|
||||||
// tokens. Functions that are query-independent, such as RMSNorm*, call the
|
// tokens. Functions that are query-independent, such as RMSNorm*, call the
|
||||||
// count `num_interleaved`. Functions that are query-dependent, such as
|
// count `num_interleaved`. Functions that are query-dependent, such as
|
||||||
// `Attention`, use separate `num_tokens` and `num_queries`.
|
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
|
||||||
|
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
|
||||||
|
|
||||||
// TODO: add batch query support for Griffin (QueriesPos).
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
|
||||||
size_t layer, Activations& activations,
|
size_t num_tokens, size_t griffin_layer,
|
||||||
|
Activations& activations,
|
||||||
const LayerWeightsPtrs<T>* layer_weights,
|
const LayerWeightsPtrs<T>* layer_weights,
|
||||||
const KVCaches& kv_caches) {
|
const KVCaches& kv_caches) {
|
||||||
PROFILER_ZONE("Gen.Griffin");
|
PROFILER_ZONE("Gen.Griffin");
|
||||||
KVCache& kv_cache = kv_caches[0];
|
|
||||||
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
|
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
|
const D df;
|
||||||
|
|
||||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||||
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
|
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
|
||||||
|
|
||||||
const size_t heads = layer_weights->layer_config.heads;
|
const size_t heads = layer_weights->layer_config.heads;
|
||||||
|
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
|
||||||
|
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
|
||||||
|
const size_t kHeadDim = model_dim / heads;
|
||||||
|
const size_t kMatrixSize = kHeadDim * kHeadDim;
|
||||||
|
|
||||||
|
const size_t num_queries = queries_pos.size();
|
||||||
|
const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
|
||||||
|
const size_t num_interleaved = num_tokens * num_queries;
|
||||||
|
|
||||||
// X / Y linear layers.
|
// X / Y linear layers.
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
// TODO: MatMul
|
||||||
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
|
HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
|
HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
|
||||||
|
for (size_t r = 0; r < num_interleaved; ++r) {
|
||||||
|
float* HWY_RESTRICT y = activations.griffin_y.Row(r);
|
||||||
|
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
|
||||||
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
|
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
|
||||||
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
|
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
|
||||||
activations.pre_att_rms_out.Row(batch_idx),
|
activations.pre_att_rms_out.Row(r),
|
||||||
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
|
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
|
||||||
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
|
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
|
||||||
/*out0=*/x, /*out1=*/y, pool);
|
/*out0=*/x, /*out1=*/y, pool);
|
||||||
|
|
@ -100,18 +113,19 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conv1D.
|
// Conv1D.
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
const size_t pos = batch_start + batch_idx;
|
++interleaved_idx) {
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
|
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
|
||||||
HWY_FULL(float) df;
|
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
|
||||||
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
|
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||||
|
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
|
||||||
|
|
||||||
// cache[i] = input at time t-i.
|
// cache[i] = input at time t-i.
|
||||||
float* HWY_RESTRICT cache[kMaxConv1DWidth];
|
float* HWY_RESTRICT cache[kMaxConv1DWidth];
|
||||||
cache[0] = x;
|
cache[0] = x;
|
||||||
for (size_t i = 1; i < conv_1d_width; i++) {
|
for (size_t i = 1; i < conv_1d_width; i++) {
|
||||||
cache[i] =
|
cache[i] =
|
||||||
kv_cache.conv1d_cache.Row(layer) +
|
kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
|
||||||
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
|
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
||||||
|
|
@ -119,7 +133,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
auto accum0 =
|
auto accum0 =
|
||||||
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
|
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
|
||||||
auto accum1 = hn::Zero(df);
|
auto accum1 = hn::Zero(df);
|
||||||
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
|
|
||||||
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
|
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
|
||||||
auto wv0 =
|
auto wv0 =
|
||||||
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
||||||
|
|
@ -136,17 +149,20 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
// RGLRU
|
// RGLRU
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
const size_t pos = batch_start + batch_idx;
|
++interleaved_idx) {
|
||||||
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
|
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
|
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
|
||||||
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx);
|
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||||
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
|
|
||||||
float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer);
|
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
|
||||||
|
float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
|
||||||
|
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
|
||||||
|
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
|
||||||
|
float* HWY_RESTRICT rnn_state =
|
||||||
|
kv_caches[query_idx].rglru_cache.Row(griffin_layer);
|
||||||
|
|
||||||
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
const size_t kHeadDim = model_dim / heads;
|
|
||||||
const size_t kMatrixSize = kHeadDim * kHeadDim;
|
|
||||||
size_t head_offset = head * kHeadDim;
|
size_t head_offset = head * kHeadDim;
|
||||||
TwoOfsMatVecAddLoop(
|
TwoOfsMatVecAddLoop(
|
||||||
layer_weights->griffin.gate_w, kMatrixSize * head,
|
layer_weights->griffin.gate_w, kMatrixSize * head,
|
||||||
|
|
@ -166,7 +182,6 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
||||||
fn_mul);
|
fn_mul);
|
||||||
// RNN scan
|
// RNN scan
|
||||||
HWY_FULL(float) df;
|
|
||||||
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
|
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
|
||||||
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
|
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
|
||||||
auto log_a = hn::Load(df, a + head_offset + i);
|
auto log_a = hn::Load(df, a + head_offset + i);
|
||||||
|
|
@ -186,17 +201,18 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
hn::Store(pre_out, df, x + head_offset + i);
|
hn::Store(pre_out, df, x + head_offset + i);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
} // interleaved_idx
|
||||||
|
|
||||||
// Final linear layer.
|
// Final linear layer.
|
||||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
// TODO: MatMul
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
|
for (size_t r = 0; r < num_interleaved; ++r) {
|
||||||
float* out_ptr = activations.att_sums.Row(batch_idx);
|
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
|
||||||
|
float* out_ptr = activations.att_sums.Row(r);
|
||||||
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
|
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
|
||||||
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
|
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
|
||||||
pool);
|
pool);
|
||||||
}
|
}
|
||||||
}
|
} // GriffinRecurrent
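Note on the batching change above: the X/Y and output loops now run over a flattened `interleaved_idx` covering (token, query) pairs, and the diff recovers the pair via `hwy::Divisor`. The decomposition itself is just integer division and remainder; a minimal sketch with illustrative names (not part of the diff):

// num_interleaved = num_tokens * num_queries, and the query index varies
// fastest: interleaved_idx = batch_idx * num_queries + query_idx.
struct SplitIdx {
  size_t query_idx;  // selects queries_pos[] and kv_caches[]
  size_t batch_idx;  // token offset within that query
};

inline SplitIdx SplitInterleaved(size_t interleaved_idx, size_t num_queries) {
  return {interleaved_idx % num_queries, interleaved_idx / num_queries};
}

// The per-token position is then queries_pos[query_idx] + batch_idx, as in
// the Conv1D and RGLRU loops above.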
|
||||||
|
|
||||||
// Wrapper class; holds arguments in member variables to shorten call sites.
|
// Wrapper class; holds arguments in member variables to shorten call sites.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
@ -219,11 +235,7 @@ class GemmaAttention {
|
||||||
activations_.weights_config.attention_window_sizes[layer] ==
|
activations_.weights_config.attention_window_sizes[layer] ==
|
||||||
activations_.seq_len;
|
activations_.seq_len;
|
||||||
// TODO: add a config flag instead of hardcoding the model.
|
// TODO: add a config flag instead of hardcoding the model.
|
||||||
if (is_global_layer &&
|
if (is_global_layer && IsVLM(activations_.weights_config.model)) {
|
||||||
(activations_.weights_config.model == Model::GEMMA3_4B ||
|
|
||||||
activations_.weights_config.model == Model::GEMMA3_12B ||
|
|
||||||
activations_.weights_config.model == Model::GEMMA3_27B ||
|
|
||||||
activations_.weights_config.model == Model::GEMMA3_1B)) {
|
|
||||||
inv_timescale = activations_.inv_timescale_global.Packed();
|
inv_timescale = activations_.inv_timescale_global.Packed();
|
||||||
}
|
}
|
||||||
// PostQKType::Rope
|
// PostQKType::Rope
|
||||||
|
|
@ -454,7 +466,6 @@ class GemmaAttention {
|
||||||
|
|
||||||
SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
|
SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
|
||||||
pos, start_pos, last_pos);
|
pos, start_pos, last_pos);
|
||||||
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -587,14 +598,12 @@ HWY_NOINLINE void Attention(
|
||||||
GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
|
GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
|
||||||
activations, layer_weights, div_seq_len, kv_caches)();
|
activations, layer_weights, div_seq_len, kv_caches)();
|
||||||
} else {
|
} else {
|
||||||
// Only reached if the model is Griffin.
|
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
||||||
// The kv_caches are allocated only for the griffin layers, so we need to
|
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
|
||||||
// map the layer index to the griffin layer index.
|
// so map `layer` to the Griffin layer index.
|
||||||
auto type = layer_weights->layer_config.type;
|
const size_t griffin_layer =
|
||||||
size_t layer_of_type =
|
|
||||||
activations.weights_config.NumLayersOfTypeBefore(type, layer);
|
activations.weights_config.NumLayersOfTypeBefore(type, layer);
|
||||||
HWY_ASSERT(queries_pos.size() == 1);
|
GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
|
||||||
GriffinRecurrent(queries_pos[0], num_tokens, layer_of_type, activations,
|
|
||||||
layer_weights, kv_caches);
|
layer_weights, kv_caches);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1056,7 +1065,6 @@ HWY_NOINLINE void Prefill(
|
||||||
// threads to parallelizing over queries, but for simplicity we assign them
|
// threads to parallelizing over queries, but for simplicity we assign them
|
||||||
// all to MatMul.
|
// all to MatMul.
|
||||||
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
||||||
HWY_DASSERT(max_tbatch_size <= activations.x.Rows());
|
|
||||||
|
|
||||||
// For each query. `qi` is within the batch, not the global query index.
|
// For each query. `qi` is within the batch, not the global query index.
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
|
|
@ -1397,6 +1405,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
|
||||||
const QueriesPos& queries_prefix_end,
|
const QueriesPos& queries_prefix_end,
|
||||||
const size_t query_idx_start, const KVCaches& kv_caches,
|
const size_t query_idx_start, const KVCaches& kv_caches,
|
||||||
TimingInfo& timing_info) {
|
TimingInfo& timing_info) {
|
||||||
|
HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
|
||||||
// Griffin assumes that the recurrent block cache is zero-initialized.
|
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||||
if (queries_pos_in[i] == 0) {
|
if (queries_pos_in[i] == 0) {
|
||||||
|
|
@ -1510,17 +1519,10 @@ void GenerateBatchT(const ModelStore& model,
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||||
HWY_ASSERT(kv_caches.size() >= num_queries);
|
HWY_ASSERT(kv_caches.size() >= num_queries);
|
||||||
// Griffin does not support query batching.
|
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
||||||
size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
|
||||||
for (const LayerConfig& layer_config : model.Config().layer_configs) {
|
|
||||||
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
|
||||||
max_qbatch_size = 1;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t max_batch_size =
|
const size_t max_batch_size =
|
||||||
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
|
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
|
||||||
|
|
||||||
Activations activations(model.Config(), max_batch_size, env);
|
Activations activations(model.Config(), max_batch_size, env);
|
||||||
|
|
||||||
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
||||||
|
|
@ -1528,7 +1530,6 @@ void GenerateBatchT(const ModelStore& model,
|
||||||
// Generate one batch of tokens from `qbatch_size` queries.
|
// Generate one batch of tokens from `qbatch_size` queries.
|
||||||
const size_t qbatch_size =
|
const size_t qbatch_size =
|
||||||
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
||||||
activations.SetBatchSize(qbatch_size);
|
|
||||||
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
|
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
|
||||||
qbatch_size);
|
qbatch_size);
|
||||||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,9 @@
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
void KVCache::ZeroGriffinCache() {
|
void KVCache::ZeroGriffinCache() {
|
||||||
if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
|
if (griffin_layers == 0) return;
|
||||||
if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
|
ZeroInit(conv1d_cache);
|
||||||
|
ZeroInit(rglru_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t GriffinConv1dCols(const ModelConfig& config) {
|
static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||||
|
|
@ -34,7 +35,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||||
for (const auto& layer_config : config.layer_configs) {
|
for (const auto& layer_config : config.layer_configs) {
|
||||||
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
|
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
|
||||||
}
|
}
|
||||||
return conv1d_width == 0 ? 0 : conv1d_width - 1;
|
// The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
|
||||||
|
// hence allocate conv1d_width * model_dim total columns.
|
||||||
|
return conv1d_width * config.model_dim;
|
||||||
}
|
}
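Note: to make the allocation above concrete, a small sketch (illustrative helper, not in the diff) of how GriffinRecurrent addresses one layer's conv1d_cache row. Slot 0 is the live input `x`; the previous conv1d_width - 1 inputs sit in model_dim-sized blocks selected modulo conv1d_width - 1, so conv1d_width * model_dim columns per row are sufficient.

// Assumes conv1d_width >= 2 (it is asserted to be even elsewhere).
// Returns the block holding the input from i steps ago, 1 <= i < conv1d_width,
// using the same index expression as GriffinRecurrent.
inline float* PastInputBlock(float* cache_row, size_t pos, size_t i,
                             size_t conv1d_width, size_t model_dim) {
  const size_t slot = (pos + conv1d_width - 1 - i) % (conv1d_width - 1);
  return cache_row + slot * model_dim;
}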
|
||||||
|
|
||||||
// prefill_tbatch_size is the maximum number of tokens from one query to
|
// prefill_tbatch_size is the maximum number of tokens from one query to
|
||||||
|
|
@ -42,12 +45,9 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||||
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
|
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
|
||||||
: griffin_layers(
|
: griffin_layers(
|
||||||
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
|
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
|
||||||
griffin_conv1d_cols(GriffinConv1dCols(config)),
|
conv1d_cache("conv1d_cache",
|
||||||
// TODO(patrickms): Add query batching support for Griffin.
|
Extents2D(griffin_layers, GriffinConv1dCols(config)),
|
||||||
conv1d_cache(
|
MatPadding::kOdd),
|
||||||
"conv1d_cache",
|
|
||||||
Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
|
|
||||||
MatPadding::kOdd),
|
|
||||||
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
|
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
|
||||||
MatPadding::kOdd) {
|
MatPadding::kOdd) {
|
||||||
// TODO: move to MatStorageT.
|
// TODO: move to MatStorageT.
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,6 @@ struct KVCache {
|
||||||
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||||
|
|
||||||
size_t griffin_layers = 0;
|
size_t griffin_layers = 0;
|
||||||
size_t griffin_conv1d_cols = 0;
|
|
||||||
// griffin_layers, griffin_conv1d_cols * config.model_dim
|
// griffin_layers, griffin_conv1d_cols * config.model_dim
|
||||||
MatStorageT<float> conv1d_cache;
|
MatStorageT<float> conv1d_cache;
|
||||||
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
|
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
|
||||||
|
|
|
||||||
105 gemma/weights.cc
|
|
@ -97,21 +97,27 @@ void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
|
||||||
SplitW1NUQ(layer_config);
|
SplitW1NUQ(layer_config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct TensorToRead {
|
||||||
|
MatPtr* mat;
|
||||||
|
BlobRange range;
|
||||||
|
// Some tensors opt out of padding via kNoPad flags.
|
||||||
|
MatPadding padding;
|
||||||
|
};
|
||||||
|
|
||||||
// Allocates multiple in parallel and binds to NUMA nodes.
|
// Allocates multiple in parallel and binds to NUMA nodes.
|
||||||
static void AllocateAndBindAll(const std::vector<MatPtr*>& mats,
|
static void AllocateAndBindAll(const std::vector<TensorToRead>& tensors,
|
||||||
MatPadding padding,
|
|
||||||
std::vector<MatOwner>& owners,
|
std::vector<MatOwner>& owners,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
const size_t start = owners.size();
|
const size_t start = owners.size();
|
||||||
owners.resize(start + mats.size());
|
owners.resize(start + tensors.size());
|
||||||
|
|
||||||
MMParallel parallel(ThreadingContext::Get());
|
MMParallel parallel(ThreadingContext::Get());
|
||||||
|
|
||||||
// Allocate in parallel because faulting in large tensors is slow.
|
// Allocate in parallel because faulting in large tensors is slow.
|
||||||
pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
|
pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
|
||||||
owners[start + task].AllocateFor(*mats[task], padding);
|
owners[start + task].AllocateFor(*tensors[task].mat, tensors[task].padding);
|
||||||
// TODO(janwas): MatMul outputs will later also be BF16.
|
// TODO(janwas): MatMul outputs will later also be BF16.
|
||||||
BindB(*mats[task], sizeof(float), parallel);
|
BindB(*tensors[task].mat, sizeof(float), parallel);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -155,55 +161,54 @@ MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
|
||||||
return MapPtr();
|
return MapPtr();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void MapAll(const std::vector<MatPtr*>& mats,
|
static void MapAll(const std::vector<TensorToRead>& tensors,
|
||||||
const std::vector<BlobRange>& ranges, const MapPtr& mapped) {
|
const MapPtr& mapped) {
|
||||||
PROFILER_ZONE("Startup.Weights.Map");
|
PROFILER_ZONE("Startup.Weights.Map");
|
||||||
for (size_t i = 0; i < mats.size(); ++i) {
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||||
// SetPtr does not change the stride, but it is expected to be packed
|
// SetPtr does not change the stride, but it is expected to be packed
|
||||||
// because that is what Compress() writes to the file.
|
// because that is what Compress() writes to the file.
|
||||||
const size_t mat_bytes = mats[i]->PackedBytes();
|
const size_t mat_bytes = tensors[i].mat->PackedBytes();
|
||||||
// Ensure blob size matches that computed from metadata.
|
// Ensure blob size matches that computed from metadata.
|
||||||
HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
|
HWY_ASSERT_M(mat_bytes == tensors[i].range.bytes, tensors[i].mat->Name());
|
||||||
|
|
||||||
mats[i]->SetPtr(const_cast<uint8_t*>(mapped.get() + ranges[i].offset),
|
tensors[i].mat->SetPtr(
|
||||||
mats[i]->Stride());
|
const_cast<uint8_t*>(mapped.get() + tensors[i].range.offset),
|
||||||
|
tensors[i].mat->Stride());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<IOBatch> MakeBatches(const std::vector<BlobRange>& ranges,
|
std::vector<IOBatch> MakeBatches(const std::vector<TensorToRead>& tensors,
|
||||||
const std::vector<MatPtr*>& mats,
|
|
||||||
const uint64_t file_bytes) {
|
const uint64_t file_bytes) {
|
||||||
PROFILER_ZONE("Startup.Weights.MakeBatches");
|
PROFILER_ZONE("Startup.Weights.MakeBatches");
|
||||||
// Batches must be contiguous but blobs are padded, hence at least one
|
// Batches must be contiguous but blobs are padded, hence at least one
|
||||||
// batch per tensor, and more when tensor rows exceed the batch size.
|
// batch per tensor, and more when tensor rows exceed the batch size.
|
||||||
std::vector<IOBatch> batches;
|
std::vector<IOBatch> batches;
|
||||||
batches.reserve(mats.size());
|
batches.reserve(tensors.size());
|
||||||
|
|
||||||
for (size_t i = 0; i < mats.size(); ++i) {
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||||
uint64_t offset = ranges[i].offset;
|
const BlobRange& range = tensors[i].range;
|
||||||
HWY_ASSERT(ranges[i].End() <= file_bytes);
|
MatPtr& mat = *tensors[i].mat;
|
||||||
|
uint64_t offset = range.offset;
|
||||||
|
HWY_ASSERT(range.End() <= file_bytes);
|
||||||
|
|
||||||
batches.emplace_back(offset, ranges[i].key_idx);
|
batches.emplace_back(offset, range.key_idx);
|
||||||
const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes();
|
const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes();
|
||||||
// Caution, `RowT` requires knowledge of the actual type. We instead use
|
const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes();
|
||||||
// the first row, which is the same for any type, and advance the *byte*
|
uint8_t* row_bytes = mat.RowBytes(0);
|
||||||
// pointer by the *byte* stride.
|
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||||
const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
|
if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch.
|
||||||
uint8_t* row = mats[i]->RowT<uint8_t>(0);
|
batches.emplace_back(offset, range.key_idx);
|
||||||
for (size_t r = 0; r < mats[i]->Rows(); ++r) {
|
|
||||||
if (!batches.back().Add(row, file_bytes_per_row)) { // Full batch.
|
|
||||||
batches.emplace_back(offset, ranges[i].key_idx);
|
|
||||||
// Adding to an empty batch is always successful.
|
// Adding to an empty batch is always successful.
|
||||||
HWY_ASSERT(batches.back().Add(row, file_bytes_per_row));
|
HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
|
||||||
}
|
}
|
||||||
offset += file_bytes_per_row;
|
offset += file_bytes_per_row;
|
||||||
row += mem_stride_bytes;
|
row_bytes += mem_stride_bytes;
|
||||||
// Keep the in-memory row padding uninitialized so msan detects any use.
|
// Keep the in-memory row padding uninitialized so msan detects any use.
|
||||||
}
|
}
|
||||||
HWY_ASSERT(offset == ranges[i].End());
|
HWY_ASSERT(offset == range.End());
|
||||||
}
|
}
|
||||||
|
|
||||||
HWY_ASSERT(batches.size() >= mats.size());
|
HWY_ASSERT(batches.size() >= tensors.size());
|
||||||
return batches;
|
return batches;
|
||||||
}
|
}
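Note: MakeBatches mixes two strides because the file stores rows packed while memory rows may be padded. A stripped-down sketch of the offset/pointer bookkeeping (names are illustrative, not from the diff):

// File rows are Cols() * ElementBytes() bytes, back to back; memory rows are
// Stride() * ElementBytes() apart. Only the packed prefix of each memory row
// is read, so the in-memory padding stays uninitialized (msan-checkable).
inline void WalkRows(uint64_t offset, uint8_t* row_bytes, size_t rows,
                     size_t cols, size_t stride, size_t element_bytes) {
  const size_t file_bytes_per_row = cols * element_bytes;
  const size_t mem_stride_bytes = stride * element_bytes;  // >= file row size
  for (size_t r = 0; r < rows; ++r) {
    // here: enqueue a read of file_bytes_per_row bytes at `offset` into
    // `row_bytes` (a new IOBatch is started whenever the current one is full)
    offset += file_bytes_per_row;
    row_bytes += mem_stride_bytes;
  }
}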
|
||||||
|
|
||||||
|
|
@ -228,16 +233,14 @@ static void ReadBatches(const BlobReader& reader,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Aborts on error.
|
// Aborts on error.
|
||||||
static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
|
static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
|
||||||
const std::vector<BlobRange>& ranges, Tristate map,
|
BlobReader& reader, Tristate map,
|
||||||
std::vector<MatOwner>& mat_owners,
|
std::vector<MatOwner>& mat_owners,
|
||||||
const MatPadding padding, hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
HWY_ASSERT(mats.size() == ranges.size());
|
|
||||||
|
|
||||||
if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
|
if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
|
||||||
MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
|
MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
|
||||||
if (mapped) {
|
if (mapped) {
|
||||||
MapAll(mats, ranges, mapped);
|
MapAll(tensors, mapped);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} // otherwise fall through to read mode
|
} // otherwise fall through to read mode
|
||||||
|
|
@ -245,32 +248,32 @@ static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Startup.Weights.Allocate");
|
PROFILER_ZONE("Startup.Weights.Allocate");
|
||||||
// NOTE: this changes the stride of `mats`!
|
// NOTE: this changes the stride of `mats`!
|
||||||
AllocateAndBindAll(mats, padding, mat_owners, pool);
|
AllocateAndBindAll(tensors, mat_owners, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<IOBatch> batches =
|
const std::vector<IOBatch> batches =
|
||||||
MakeBatches(ranges, mats, reader.file_bytes());
|
MakeBatches(tensors, reader.file_bytes());
|
||||||
ReadBatches(reader, batches, pool);
|
ReadBatches(reader, batches, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
|
void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
|
||||||
Tristate map, hwy::ThreadPool& pool) {
|
Tristate map, hwy::ThreadPool& pool) {
|
||||||
// List of tensors to read/map, and where from.
|
// List of tensors to read/map, and where from.
|
||||||
std::vector<MatPtr*> mats;
|
std::vector<TensorToRead> tensors;
|
||||||
std::vector<BlobRange> ranges;
|
|
||||||
|
|
||||||
// Padding is inserted when reading row by row, except for NUQ tensors.
|
|
||||||
const MatPadding padding = MatPadding::kOdd;
|
|
||||||
|
|
||||||
AllocatePointer(model.Config());
|
AllocatePointer(model.Config());
|
||||||
|
|
||||||
// Enumerate all weights (negligible cost).
|
// Enumerate all weights (negligible cost).
|
||||||
CallT([&](const auto& weights) {
|
CallT([&](const auto& weights) {
|
||||||
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
|
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
|
||||||
|
const MatPadding padding = (t.flags & TensorArgs::kNoPad)
|
||||||
|
? MatPadding::kPacked
|
||||||
|
: MatPadding::kOdd;
|
||||||
size_t key_idx;
|
size_t key_idx;
|
||||||
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
|
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
|
||||||
mats.push_back(&t.mat);
|
tensors.push_back({.mat = &t.mat,
|
||||||
ranges.push_back(reader.Range(key_idx));
|
.range = reader.Range(key_idx),
|
||||||
|
.padding = padding});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found.
|
if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found.
|
||||||
|
|
@ -278,7 +281,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool);
|
MapOrReadAll(tensors, reader, map, mat_owners_, pool);
|
||||||
|
|
||||||
Fixup(pool);
|
Fixup(pool);
|
||||||
}
|
}
|
||||||
|
|
@ -338,9 +341,9 @@ void WeightsOwner::LogWeightStatsF32() {
|
||||||
printf("[scale=%f] ", t.mat.Scale());
|
printf("[scale=%f] ", t.mat.Scale());
|
||||||
}
|
}
|
||||||
hwy::Stats stats;
|
hwy::Stats stats;
|
||||||
HWY_ASSERT(t.mat.GetType() == Type::kF32);
|
const MatPtrT<float> mat_f(t.mat);
|
||||||
for (size_t r = 0; r < t.mat.Rows(); ++r) {
|
for (size_t r = 0; r < t.mat.Rows(); ++r) {
|
||||||
const float* HWY_RESTRICT row = t.mat.RowT<float>(r);
|
const float* HWY_RESTRICT row = mat_f.Row(r);
|
||||||
for (size_t c = 0; c < t.mat.Cols(); ++c) {
|
for (size_t c = 0; c < t.mat.Cols(); ++c) {
|
||||||
stats.Notify(row[c]);
|
stats.Notify(row[c]);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -43,24 +43,27 @@ struct TensorArgs {
|
||||||
// name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
|
// name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
|
||||||
// (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
|
// (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
|
||||||
// `flags` is a combination of zero or more `Flags`.
|
// `flags` is a combination of zero or more `Flags`.
|
||||||
TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
|
TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
|
||||||
int flags)
|
|
||||||
: mat(mat),
|
: mat(mat),
|
||||||
other_mat1(other_mat1),
|
other_mat1(other_mat1),
|
||||||
other_mat2(other_mat2),
|
other_mat2(other_mat2),
|
||||||
flags(flags) {}
|
flags(flags) {}
|
||||||
|
|
||||||
MatPtr& mat;
|
MatPtr& mat;
|
||||||
const MatPtr* other_mat1; // either/both can be nullptr.
|
MatPtr* other_mat1; // either/both can be nullptr.
|
||||||
const MatPtr* other_mat2;
|
MatPtr* other_mat2;
|
||||||
|
|
||||||
// TODO: freestanding enum class instead? These are mutually exclusive.
|
|
||||||
enum Flags {
|
enum Flags {
|
||||||
// Read the tensor from the file and abort if it is not found.
|
// Default: Read the tensor from the file and abort if it is not found.
|
||||||
kMustRead = 0,
|
kMustRead = 0,
|
||||||
|
|
||||||
// Not an error if the tensor is not present in the file. For example,
|
// Not an error if the tensor is not present in the file. For example,
|
||||||
// the _w1/_w2 tensors are not always present.
|
// the _w1/_w2 tensors are not always present.
|
||||||
kMaybeRead = 1,
|
kMaybeRead = 1,
|
||||||
|
|
||||||
|
// Avoid padding tensor rows when reading. Used for some Griffin tensors
|
||||||
|
// whose index computations do not use Row() accessors.
|
||||||
|
kNoPad = 2,
|
||||||
};
|
};
|
||||||
const int flags;
|
const int flags;
|
||||||
};
|
};
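Note: kNoPad is consumed when building the read list; a one-line illustration of the mapping (hypothetical helper name, mirroring the ReadFromBlobs change later in this diff):

// Griffin's conv_w and gate_w are registered with kMustRead | kNoPad and
// therefore stay packed; everything else gets row padding (MatPadding::kOdd).
inline MatPadding PaddingFor(int flags) {
  return (flags & TensorArgs::kNoPad) ? MatPadding::kPacked : MatPadding::kOdd;
}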
|
||||||
|
|
@ -214,8 +217,8 @@ struct LayerWeightsPtrs {
|
||||||
// can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
|
// can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
|
||||||
// Public because also called by `WeightsPtrs`.
|
// Public because also called by `WeightsPtrs`.
|
||||||
template <class Func>
|
template <class Func>
|
||||||
void ForEachTensor(const LayerWeightsPtrs<Weight>* other1,
|
void ForEachTensor(LayerWeightsPtrs<Weight>* other1,
|
||||||
const LayerWeightsPtrs<Weight>* other2, Func func) {
|
LayerWeightsPtrs<Weight>* other2, Func func) {
|
||||||
if (layer_config.type == LayerAttentionType::kVit) {
|
if (layer_config.type == LayerAttentionType::kVit) {
|
||||||
// MHA.
|
// MHA.
|
||||||
func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
|
func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
|
||||||
|
|
@ -248,9 +251,9 @@ struct LayerWeightsPtrs {
|
||||||
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
|
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
|
||||||
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
|
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
|
||||||
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
|
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
|
||||||
func(TENSOR_ARGS(griffin.conv_w, kMustRead));
|
func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
|
||||||
func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
|
func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
|
||||||
func(TENSOR_ARGS(griffin.gate_w, kMustRead));
|
func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
|
||||||
func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
|
func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
|
||||||
func(TENSOR_ARGS(griffin.a, kMustRead));
|
func(TENSOR_ARGS(griffin.a, kMustRead));
|
||||||
}
|
}
|
||||||
|
|
@ -363,8 +366,8 @@ struct LayerWeightsPtrs {
|
||||||
|
|
||||||
// For FFN. Fast, only updates pointers.
|
// For FFN. Fast, only updates pointers.
|
||||||
void SplitW1() {
|
void SplitW1() {
|
||||||
// We only use this tensor for Gemma layers.
|
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
|
||||||
if (layer_config.type != LayerAttentionType::kGemma) return;
|
if (layer_config.type == LayerAttentionType::kVit) return;
|
||||||
|
|
||||||
// Files have both or neither of w1 and w2, and backprop/ allocates both.
|
// Files have both or neither of w1 and w2, and backprop/ allocates both.
|
||||||
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
|
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
|
||||||
|
|
@ -514,10 +517,10 @@ struct ModelWeightsPtrs {
|
||||||
// used to copy from another set of weights. Public because called by tests
|
// used to copy from another set of weights. Public because called by tests
|
||||||
// and `WeightsOwner`.
|
// and `WeightsOwner`.
|
||||||
template <class Func>
|
template <class Func>
|
||||||
void ForEachTensor(const ModelWeightsPtrs<Weight>* other1,
|
void ForEachTensor(ModelWeightsPtrs<Weight>* other1,
|
||||||
const ModelWeightsPtrs<Weight>* other2, Func func) {
|
ModelWeightsPtrs<Weight>* other2, Func func) {
|
||||||
const LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
|
LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
|
||||||
const LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
|
LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
|
||||||
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
|
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
|
||||||
func(TENSOR_ARGS(final_norm_scale, kMustRead));
|
func(TENSOR_ARGS(final_norm_scale, kMustRead));
|
||||||
|
|
||||||
|
|
@ -569,11 +572,12 @@ struct ModelWeightsPtrs {
|
||||||
|
|
||||||
// Copies only the allocated tensors in `*this` from tensors in `other`.
|
// Copies only the allocated tensors in `*this` from tensors in `other`.
|
||||||
void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
|
void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
|
||||||
ForEachTensor(&other, nullptr, [](const TensorArgs& t) {
|
ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
|
||||||
if (!t.mat.HasPtr()) return;
|
[](const TensorArgs& t) {
|
||||||
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
|
if (!t.mat.HasPtr()) return;
|
||||||
CopyMat(*t.other_mat1, t.mat);
|
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
|
||||||
});
|
CopyMat(*t.other_mat1, t.mat);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instead of reading, only allocates memory for all tensors. Used by
|
// Instead of reading, only allocates memory for all tensors. Used by
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,14 @@ struct Header { // standard layout class
|
||||||
static_assert(sizeof(Header) == 16);
|
static_assert(sizeof(Header) == 16);
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// A write I/O request, each serviced by one thread in a pool.
|
||||||
|
struct BlobIO {
|
||||||
|
BlobIO(BlobRange range, void* data) : range(range), data(data) {}
|
||||||
|
|
||||||
|
BlobRange range;
|
||||||
|
void* data; // Read-only for writes.
|
||||||
|
};
|
||||||
|
|
||||||
// Little-endian on-disk representation: a fixed-size `Header`, then a padded
|
// Little-endian on-disk representation: a fixed-size `Header`, then a padded
|
||||||
// variable-length 'directory' of blob keys and their offset/sizes, then the
|
// variable-length 'directory' of blob keys and their offset/sizes, then the
|
||||||
// 'payload' of each blob's data with padding in between, followed by padding to
|
// 'payload' of each blob's data with padding in between, followed by padding to
|
||||||
|
|
@ -238,11 +246,11 @@ class BlobStore {
|
||||||
return true; // all OK
|
return true; // all OK
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
|
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const {
|
||||||
const size_t key_idx = 0; // not actually associated with a key/blob
|
const size_t key_idx = 0; // not actually associated with a key/blob
|
||||||
writes.emplace_back(
|
writes.emplace_back(
|
||||||
BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
|
BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
|
||||||
// members are const and BlobIO2 requires non-const pointers, and they
|
// members are const and BlobIO requires non-const pointers, and they
|
||||||
// are not modified by file writes.
|
// are not modified by file writes.
|
||||||
const_cast<Header*>(&header_));
|
const_cast<Header*>(&header_));
|
||||||
writes.emplace_back(
|
writes.emplace_back(
|
||||||
|
|
@ -314,7 +322,7 @@ BlobReader::BlobReader(const Path& blob_path)
|
||||||
|
|
||||||
// Split into chunks for load-balancing even if blob sizes vary.
|
// Split into chunks for load-balancing even if blob sizes vary.
|
||||||
static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
|
static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
|
||||||
uint8_t* data, std::vector<BlobIO2>& writes) {
|
uint8_t* data, std::vector<BlobIO>& writes) {
|
||||||
constexpr size_t kChunkBytes = 4 * 1024 * 1024;
|
constexpr size_t kChunkBytes = 4 * 1024 * 1024;
|
||||||
const uint64_t end = offset + bytes;
|
const uint64_t end = offset + bytes;
|
||||||
// Split into whole chunks and possibly one remainder.
|
// Split into whole chunks and possibly one remainder.
|
||||||
|
|
@ -336,7 +344,7 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
|
||||||
static void EnqueueWritesForBlobs(const BlobStore& bs,
|
static void EnqueueWritesForBlobs(const BlobStore& bs,
|
||||||
const hwy::Span<const uint8_t> blobs[],
|
const hwy::Span<const uint8_t> blobs[],
|
||||||
std::vector<uint8_t>& zeros,
|
std::vector<uint8_t>& zeros,
|
||||||
std::vector<BlobIO2>& writes) {
|
std::vector<BlobIO>& writes) {
|
||||||
// All-zero buffer used to write padding to the file without copying the
|
// All-zero buffer used to write padding to the file without copying the
|
||||||
// input blobs.
|
// input blobs.
|
||||||
static constexpr uint8_t kZeros[kBlobAlign] = {0};
|
static constexpr uint8_t kZeros[kBlobAlign] = {0};
|
||||||
|
|
@ -388,7 +396,7 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
|
||||||
HWY_ASSERT(num_blobs != 0);
|
HWY_ASSERT(num_blobs != 0);
|
||||||
HWY_ASSERT(num_blobs == blobs_.size());
|
HWY_ASSERT(num_blobs == blobs_.size());
|
||||||
|
|
||||||
std::vector<BlobIO2> writes;
|
std::vector<BlobIO> writes;
|
||||||
writes.reserve(16384);
|
writes.reserve(16384);
|
||||||
|
|
||||||
const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
|
const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
|
||||||
|
|
|
||||||
|
|
@ -44,16 +44,6 @@ struct BlobRange {
|
||||||
size_t key_idx;
|
size_t key_idx;
|
||||||
};
|
};
|
||||||
|
|
||||||
// A read or write I/O request, each serviced by one thread in a pool.
|
|
||||||
struct BlobIO2 {
|
|
||||||
BlobIO2(BlobRange range, void* data) : range(range), data(data) {}
|
|
||||||
|
|
||||||
BlobRange range;
|
|
||||||
void* data; // Modified only if a read request. Read-only for writes.
|
|
||||||
};
|
|
||||||
|
|
||||||
class BlobStore;
|
|
||||||
|
|
||||||
// Reads `BlobStore` header, converts keys to strings and creates a hash map for
|
// Reads `BlobStore` header, converts keys to strings and creates a hash map for
|
||||||
// faster lookups.
|
// faster lookups.
|
||||||
// TODO(janwas): rename to BlobFinder or similar.
|
// TODO(janwas): rename to BlobFinder or similar.
|
||||||
|
|
|
||||||
|
|
@ -425,7 +425,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
|
||||||
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
|
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
|
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
|
||||||
Allocator& allocator = parallel.allocator();
|
Allocator& allocator = parallel.allocator();
|
||||||
if (!allocator.ShouldBind()) return;
|
if (!allocator.ShouldBind()) return;
|
||||||
if (B.Rows() == 1) return;
|
if (B.Rows() == 1) return;
|
||||||
|
|
@ -437,8 +437,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
|
||||||
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
|
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
|
||||||
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
|
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
|
||||||
const size_t node = parallel.Node(pkg_idx);
|
const size_t node = parallel.Node(pkg_idx);
|
||||||
uintptr_t begin =
|
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
|
||||||
reinterpret_cast<uintptr_t>(B.RowT<uint8_t>(rows_b.begin()));
|
|
||||||
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
|
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
|
||||||
// B row padding is less than the page size, so only bind the subset that
|
// B row padding is less than the page size, so only bind the subset that
|
||||||
// is page-aligned.
|
// is page-aligned.
|
||||||
|
|
@ -451,7 +450,7 @@ void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// C is BF16/float, or double for partial
|
// C is BF16/float, or double for partial
|
||||||
void BindC(const MatPtr& C, MMParallel& parallel) {
|
void BindC(MatPtr& C, MMParallel& parallel) {
|
||||||
Allocator& allocator = parallel.allocator();
|
Allocator& allocator = parallel.allocator();
|
||||||
if (!allocator.ShouldBind()) return;
|
if (!allocator.ShouldBind()) return;
|
||||||
|
|
||||||
|
|
@ -470,8 +469,7 @@ void BindC(const MatPtr& C, MMParallel& parallel) {
|
||||||
|
|
||||||
const size_t node = parallel.Node(pkg_idx);
|
const size_t node = parallel.Node(pkg_idx);
|
||||||
for (size_t im = 0; im < C.Rows(); ++im) {
|
for (size_t im = 0; im < C.Rows(); ++im) {
|
||||||
ok &= allocator.BindMemory(C.MutableRowT<uint8_t>(im) + begin,
|
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
|
||||||
end - begin, node);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (HWY_UNLIKELY(!ok)) {
|
if (HWY_UNLIKELY(!ok)) {
|
||||||
|
|
|
||||||
|
|
@ -172,9 +172,9 @@ class MMParallel {
|
||||||
ThreadingContext& ctx_;
|
ThreadingContext& ctx_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
|
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
|
||||||
// C is BF16/float, or double for partial.
|
// C is BF16/float, or double for partial.
|
||||||
void BindC(const MatPtr& C, MMParallel& parallel);
|
void BindC(MatPtr& C, MMParallel& parallel);
|
||||||
|
|
||||||
// Per-package storage for packed A, and one global C-shaped `partial` for
|
// Per-package storage for packed A, and one global C-shaped `partial` for
|
||||||
// accumulating partial dot products (sections of K).
|
// accumulating partial dot products (sections of K).
|
||||||
|
|
|
||||||
14 util/mat.cc
|
|
@ -41,8 +41,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
|
||||||
}
|
}
|
||||||
const size_t row_bytes = to.Cols() * to.ElementBytes();
|
const size_t row_bytes = to.Cols() * to.ElementBytes();
|
||||||
for (size_t r = 0; r < to.Rows(); ++r) {
|
for (size_t r = 0; r < to.Rows(); ++r) {
|
||||||
const uint8_t* from_row = from.RowT<uint8_t>(r);
|
const uint8_t* from_row = from.RowBytes(r);
|
||||||
uint8_t* to_row = to.RowT<uint8_t>(r);
|
uint8_t* to_row = to.RowBytes(r);
|
||||||
hwy::CopyBytes(from_row, to_row, row_bytes);
|
hwy::CopyBytes(from_row, to_row, row_bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -58,7 +58,7 @@ void ZeroInit(MatPtr& mat) {
|
||||||
}
|
}
|
||||||
const size_t row_bytes = mat.Cols() * mat.ElementBytes();
|
const size_t row_bytes = mat.Cols() * mat.ElementBytes();
|
||||||
for (size_t r = 0; r < mat.Rows(); ++r) {
|
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||||
hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes);
|
hwy::ZeroBytes(mat.RowBytes(r), row_bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -71,15 +71,19 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
|
||||||
|
|
||||||
std::normal_distribution<float> dist(0.0, stddev);
|
std::normal_distribution<float> dist(0.0, stddev);
|
||||||
if (mat.GetType() == Type::kF32) {
|
if (mat.GetType() == Type::kF32) {
|
||||||
|
MatPtrT<float> mat_f(mat);
|
||||||
|
|
||||||
for (size_t r = 0; r < mat.Rows(); ++r) {
|
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||||
float* HWY_RESTRICT row = mat.RowT<float>(r);
|
float* HWY_RESTRICT row = mat_f.Row(r);
|
||||||
for (size_t c = 0; c < mat.Cols(); ++c) {
|
for (size_t c = 0; c < mat.Cols(); ++c) {
|
||||||
row[c] = dist(gen);
|
row[c] = dist(gen);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
MatPtrT<double> mat_d(mat);
|
||||||
|
|
||||||
for (size_t r = 0; r < mat.Rows(); ++r) {
|
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||||
double* HWY_RESTRICT row = mat.RowT<double>(r);
|
double* HWY_RESTRICT row = mat_d.Row(r);
|
||||||
for (size_t c = 0; c < mat.Cols(); ++c) {
|
for (size_t c = 0; c < mat.Cols(); ++c) {
|
||||||
row[c] = dist(gen);
|
row[c] = dist(gen);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
44 util/mat.h
|
|
@ -91,21 +91,14 @@ class MatPtr : public IFields {
|
||||||
return num_elements_ * element_bytes_;
|
return num_elements_ * element_bytes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Works for any kind of padding.
|
// Works for any kind of padding and element type.
|
||||||
template <typename T>
|
uint8_t* RowBytes(size_t row) {
|
||||||
T* MutableRowT(size_t row) const {
|
|
||||||
HWY_DASSERT(row < Rows());
|
HWY_DASSERT(row < Rows());
|
||||||
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
|
return static_cast<uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
|
||||||
}
|
}
|
||||||
template <typename T>
|
const uint8_t* RowBytes(size_t row) const {
|
||||||
T* RowT(size_t row) {
|
|
||||||
HWY_DASSERT(row < Rows());
|
HWY_DASSERT(row < Rows());
|
||||||
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
|
return static_cast<const uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
|
||||||
}
|
|
||||||
template <typename T>
|
|
||||||
const T* RowT(size_t row) const {
|
|
||||||
HWY_DASSERT(row < Rows());
|
|
||||||
return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type GetType() const { return type_; }
|
Type GetType() const { return type_; }
|
||||||
|
|
@ -209,9 +202,8 @@ class MatPtr : public IFields {
|
||||||
float scale_ = 1.0f; // multiplier for each value, for MatMul.
|
float scale_ = 1.0f; // multiplier for each value, for MatMul.
|
||||||
};
|
};
|
||||||
|
|
||||||
// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
|
// Non-type erased version of `MatPtr`: provides type-safe `Row()` and ensures
|
||||||
// type-aware accessors (`RowT`), this class is more convenient when accessing
|
// the template argument and `Type` are consistent.
|
||||||
// elements, and ensures the template argument and `Type` are consistent.
|
|
||||||
template <typename MatT>
|
template <typename MatT>
|
||||||
class MatPtrT : public MatPtr {
|
class MatPtrT : public MatPtr {
|
||||||
public:
|
public:
|
||||||
|
|
@ -227,7 +219,9 @@ class MatPtrT : public MatPtr {
|
||||||
: MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}
|
: MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}
|
||||||
|
|
||||||
// Copying allowed because the metadata is small.
|
// Copying allowed because the metadata is small.
|
||||||
MatPtrT(const MatPtr& other) : MatPtr(other) {}
|
MatPtrT(const MatPtr& other) : MatPtr(other) {
|
||||||
|
HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
|
||||||
|
}
|
||||||
MatPtrT& operator=(const MatPtr& other) {
|
MatPtrT& operator=(const MatPtr& other) {
|
||||||
MatPtr::operator=(other);
|
MatPtr::operator=(other);
|
||||||
return *this;
|
return *this;
|
||||||
|
|
@ -252,22 +246,24 @@ class MatPtrT : public MatPtr {
|
||||||
return Packed();
|
return Packed();
|
||||||
}
|
}
|
||||||
|
|
||||||
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
|
MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
|
||||||
MatT* Row(size_t row) { return this->RowT<MatT>(row); }
|
const MatT* Row(size_t row) const {
|
||||||
|
return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
|
||||||
|
}
|
||||||
|
|
||||||
PackedSpan<const MatT> PaddedSpan() const {
|
PackedSpan<const MatT> PaddedSpan() const {
|
||||||
return MakeConstSpan(Row(0), Rows() * Stride());
|
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), Rows() * Stride());
|
||||||
}
|
}
|
||||||
|
|
||||||
// For `compress-inl.h` functions, which assume contiguous streams and thus
|
// For `compress-inl.h` functions, which assume contiguous streams and thus
|
||||||
// require packed layout.
|
// require packed layout.
|
||||||
PackedSpan<const MatT> Span() const {
|
|
||||||
HWY_ASSERT(IsPacked());
|
|
||||||
return MakeConstSpan(Row(0), num_elements_);
|
|
||||||
}
|
|
||||||
PackedSpan<MatT> Span() {
|
PackedSpan<MatT> Span() {
|
||||||
HWY_ASSERT(IsPacked());
|
HWY_ASSERT(IsPacked());
|
||||||
return MakeSpan(Row(0), num_elements_);
|
return MakeSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
|
||||||
|
}
|
||||||
|
PackedSpan<const MatT> Span() const {
|
||||||
|
HWY_ASSERT(IsPacked());
|
||||||
|
return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
|
||||||
}
|
}
|
||||||
};
|
};
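Note: putting the mat.h changes together, a short usage sketch (hypothetical function, following the LogWeightStatsF32 pattern in this diff). Callers that previously wrote mat.RowT<float>(r) now build a typed view once; the MatPtrT constructor asserts that the stored Type matches the template argument, and RowBytes() remains for type-erased byte access.

void ScaleRowsF32(MatPtr& any_mat, float factor) {
  HWY_ASSERT(any_mat.GetType() == Type::kF32);
  MatPtrT<float> mat_f(any_mat);  // copies metadata; checks the type
  for (size_t r = 0; r < mat_f.Rows(); ++r) {
    float* HWY_RESTRICT row = mat_f.Row(r);  // typed, stride-aware
    for (size_t c = 0; c < mat_f.Cols(); ++c) row[c] *= factor;
  }
}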
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue