Internal change

PiperOrigin-RevId: 803083229
This commit is contained in:
Jan Wassenberg 2025-09-04 10:30:42 -07:00 committed by Copybara-Service
parent afd82376a5
commit 5d1693e806
5 changed files with 19 additions and 22 deletions

View File

@ -539,14 +539,14 @@ cc_library(
":weights", ":weights",
"//compression:compress", "//compression:compress",
"//compression:types", "//compression:types",
"//io:blob_store",
"//io", "//io",
"//io:blob_store",
"//paligemma:image", "//paligemma:image",
"@highway//:hwy", "@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
"@highway//:nanobenchmark", # timer "@highway//:nanobenchmark", # timer
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool", "@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
], ],
) )

View File

@ -421,7 +421,7 @@ struct ModelConfig : public IFields {
} }
size_t KVCacheCols() const { size_t KVCacheCols() const {
size_t num_layers = layer_configs.size(); const size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize(); return num_layers * layer_configs[0].CacheLayerSize();
} }

View File

@ -556,16 +556,13 @@ static void GenerateT(const ModelConfig& config,
const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx); const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
{ timing_info.generate_start = hwy::platform::Now();
timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { Transformer(config, runtime_config, weights, activations, qbatch, env);
Transformer(config, runtime_config, weights, activations, qbatch, env); SampleAndStream(config, runtime_config, weights, sample_token, activations,
SampleAndStream(config, runtime_config, weights, sample_token, qbatch, /*update_pos=*/true, env, non_eos, timing_info);
activations, qbatch, /*update_pos=*/true, env, non_eos,
timing_info);
}
timing_info.NotifyGenerateDone();
} }
timing_info.NotifyGenerateDone();
} }
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,

View File

@ -226,15 +226,16 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
// ideally already happen in the importer. Called by `ReadFromBlobs`. // ideally already happen in the importer. Called by `ReadFromBlobs`.
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners, void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) { ThreadingContext& ctx) {
// TODO: use 1D parallel-for helper function const size_t cluster_idx = 0;
hwy::ThreadPool& pool = ctx.pools.Pool(); ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator); GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
}); });
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
VitLayer(layer)->Fixup(mat_owners, ctx.allocator); [&](uint64_t layer, size_t /*worker*/) {
}); VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
});
} }
std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter( std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter(

View File

@ -147,8 +147,7 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("image_size", &VitConfig::image_size) .def_readwrite("image_size", &VitConfig::image_size)
.def_readwrite("layer_configs", &VitConfig::layer_configs); .def_readwrite("layer_configs", &VitConfig::layer_configs);
class_<InternalModelConfig>(py_module, "InternalModelConfig") class_<InternalModelConfig>(py_module, "InternalModelConfig").def(init<>());
.def(init<>());
class_<ModelConfig>(py_module, "ModelConfig") class_<ModelConfig>(py_module, "ModelConfig")
.def(init<>()) .def(init<>())