Fix the link error when building `compress_weights` with Clang on macOS

RangerUFO 2025-02-09 00:13:25 +08:00
parent b18bd781f6
commit 3a5a6dbcad
2 changed files with 46 additions and 37 deletions
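Neither the commit message nor the diff states why the link error happened, so the following is a reading of the change rather than something the author wrote. Before this commit, the generic LayerWeightsPtrs<Weight>::Reshape was a template member whose body lived only in the .cc file; any other translation unit (such as the compress_weights tool) that calls Reshape sees only the declaration, so the linker has to find the needed instantiation in the weights object file, and if it was never emitted there the symbol stays undefined. Whether that surfaces depends on the toolchain, which would explain a failure specific to Clang on macOS. A sketch of the failing shape, with illustrative file labels:

  // header (before the change): declaration only, body in one .cc file.
  template <class Weight>
  struct LayerWeightsPtrs {
    void Reshape(MatStorage* storage);
  };

  // some other .cc (e.g. the compress_weights tool): this call needs
  // LayerWeightsPtrs<W>::Reshape, but there is no body here to instantiate,
  // so the linker reports an undefined symbol unless the weights .cc
  // happened to emit exactly this instantiation.
  layer.Reshape(&storage);

After the change, the generic body is defined inline in the header (so every caller instantiates it locally), and only the NuqStream variant remains in the .cc as an explicit specialization.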


@@ -257,8 +257,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
 }
 }
-template <class Weight>
-void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
+template <>
+void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
   if (attn_vec_einsum_w.data() == nullptr) return;
   const size_t model_dim = layer_config.model_dim;
@@ -271,48 +271,34 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
     att_weights.SetPtr(*storage);
   }
-  if (hwy::IsSame<Weight, NuqStream>()) {
-    const hwy::HWY_NAMESPACE::ScalableTag<float> df;
-    hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
-        hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
-    hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
-        hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
-    HWY_NAMESPACE::DecompressAndZeroPad(
-        df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0,
-        attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
-    for (size_t m = 0; m < model_dim; ++m) {
-      float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
-      for (size_t h = 0; h < heads; ++h) {
-        hwy::CopyBytes(
-            attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim,
-            out_row + h * qkv_dim, qkv_dim * sizeof(float));
-      }
-    }
-    CompressWorkingSet work;
-    hwy::ThreadPool pool(0);
-    HWY_NAMESPACE::Compress(
-        att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
-        MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
-        /*packed_ofs=*/0, pool);
-    att_weights.set_scale(attn_vec_einsum_w.scale());
-    return;
-  }
+  const hwy::HWY_NAMESPACE::ScalableTag<float> df;
+  hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
+      hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+  hwy::AlignedFreeUniquePtr<float[]> att_weights_tmp =
+      hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
+  HWY_NAMESPACE::DecompressAndZeroPad(
+      df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0,
+      attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
   for (size_t m = 0; m < model_dim; ++m) {
-    Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
+    float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim;
     for (size_t h = 0; h < heads; ++h) {
       hwy::CopyBytes(
-          attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
-          out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
+          attn_vec_einsum_w_tmp.get() + h * model_dim * qkv_dim + m * qkv_dim,
+          out_row + h * qkv_dim, qkv_dim * sizeof(float));
     }
   }
+  CompressWorkingSet work;
+  hwy::ThreadPool pool(0);
+  HWY_NAMESPACE::Compress(
+      att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
+      MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
+      /*packed_ofs=*/0, pool);
   att_weights.set_scale(attn_vec_einsum_w.scale());
 }
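In the NuqStream specialization above, the weights cannot be reordered with per-row CopyBytes the way the generic path does it, presumably because NUQ packs values into groups rather than storing one fixed-size element per weight; the tensor is therefore decompressed to float, reordered, and recompressed. The reordering itself is the same [heads, model_dim, qkv_dim] -> [model_dim, heads * qkv_dim] transpose in both paths. As a standalone sketch on plain floats (a hypothetical helper, not code from the repository), the index mapping is:

  // attn_vec_einsum_w: [heads][model_dim][qkv_dim]
  // att_weights:       [model_dim][heads * qkv_dim]
  void ReshapeAttWeights(const float* attn_vec_einsum_w, float* att_weights,
                         size_t heads, size_t model_dim, size_t qkv_dim) {
    for (size_t m = 0; m < model_dim; ++m) {
      for (size_t h = 0; h < heads; ++h) {
        for (size_t q = 0; q < qkv_dim; ++q) {
          att_weights[(m * heads + h) * qkv_dim + q] =
              attn_vec_einsum_w[(h * model_dim + m) * qkv_dim + q];
        }
      }
    }
  }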


@@ -179,7 +179,30 @@ struct LayerWeightsPtrs {
   // Initializes att_weights from attn_vec_einsum_w, hence this must be called
   // after loading weights via ForEachTensor.
   // TODO: update compression/convert_weights to bake this in.
-  void Reshape(MatStorage* storage);
+  void Reshape(MatStorage* storage) {
+    static_assert(!hwy::IsSame<Weight, NuqStream>());
+    if (attn_vec_einsum_w.data() == nullptr) return;
+    const size_t model_dim = layer_config.model_dim;
+    const size_t heads = layer_config.heads;
+    const size_t qkv_dim = layer_config.qkv_dim;
+    // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
+    if (storage != nullptr) {
+      storage->Allocate();
+      att_weights.SetPtr(*storage);
+    }
+    for (size_t m = 0; m < model_dim; ++m) {
+      Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
+      for (size_t h = 0; h < heads; ++h) {
+        hwy::CopyBytes(
+            attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
+            out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
+      }
+    }
+    att_weights.set_scale(attn_vec_einsum_w.scale());
+  }
   // Used by ForEachTensor for per-layer tensors.
 #define GEMMA_CALL_FUNC(member) \
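
The static_assert added to the header presumably guards the generic byte-copy path: it should never be instantiated for NuqStream, whose packed representation would be copied with the wrong stride, and the NuqStream case is meant to come from the explicit specialization kept in the .cc file. A self-contained illustration of how the two definitions coexist, using stand-in names (Ptrs, Nuq) rather than the real types:

  #include <cstdio>
  #include <type_traits>

  struct Nuq {};  // stands in for NuqStream

  template <class W>
  struct Ptrs {
    // Generic path, defined inline as in the header after this commit.
    void Reshape() {
      static_assert(!std::is_same_v<W, Nuq>, "use the specialization");
      std::puts("byte-copy reshape");
    }
  };

  // In the real code the NuqStream specialization lives in a .cc file next to
  // the SIMD helpers it needs; here it is defined in the same file so the
  // sketch compiles on its own. The primary template's body is never
  // instantiated for Nuq, so its static_assert does not fire.
  template <>
  void Ptrs<Nuq>::Reshape() { std::puts("decompress, reshape, recompress"); }

  int main() {
    Ptrs<float>{}.Reshape();  // instantiates the inline generic body
    Ptrs<Nuq>{}.Reshape();    // uses the explicit specialization
  }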