Internal change

PiperOrigin-RevId: 803083229
This commit is contained in:
Jan Wassenberg 2025-09-04 10:30:42 -07:00 committed by Copybara-Service
parent afd82376a5
commit 5d1693e806
5 changed files with 19 additions and 22 deletions

View File

@ -539,14 +539,14 @@ cc_library(
":weights",
"//compression:compress",
"//compression:types",
"//io:blob_store",
"//io",
"//io:blob_store",
"//paligemma:image",
"@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)

View File

@ -421,7 +421,7 @@ struct ModelConfig : public IFields {
}
size_t KVCacheCols() const {
size_t num_layers = layer_configs.size();
const size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize();
}

View File

@ -556,17 +556,14 @@ static void GenerateT(const ModelConfig& config,
const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
{
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token,
activations, qbatch, /*update_pos=*/true, env, non_eos,
timing_info);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, /*update_pos=*/true, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}
}
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config,

View File

@ -226,13 +226,14 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
// ideally already happen in the importer. Called by `ReadFromBlobs`.
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
// TODO: use 1D parallel-for helper function
hwy::ThreadPool& pool = ctx.pools.Pool();
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
const size_t cluster_idx = 0;
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
});
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
});
}

View File

@ -147,8 +147,7 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("image_size", &VitConfig::image_size)
.def_readwrite("layer_configs", &VitConfig::layer_configs);
class_<InternalModelConfig>(py_module, "InternalModelConfig")
.def(init<>());
class_<InternalModelConfig>(py_module, "InternalModelConfig").def(init<>());
class_<ModelConfig>(py_module, "ModelConfig")
.def(init<>())