Internal change

PiperOrigin-RevId: 803083229
This commit is contained in:
Jan Wassenberg 2025-09-04 10:30:42 -07:00 committed by Copybara-Service
parent afd82376a5
commit 5d1693e806
5 changed files with 19 additions and 22 deletions

View File

@ -539,14 +539,14 @@ cc_library(
":weights", ":weights",
"//compression:compress", "//compression:compress",
"//compression:types", "//compression:types",
"//io:blob_store",
"//io", "//io",
"//io:blob_store",
"//paligemma:image", "//paligemma:image",
"@highway//:hwy", "@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
"@highway//:nanobenchmark", # timer "@highway//:nanobenchmark", # timer
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool", "@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
], ],
) )

View File

@ -421,7 +421,7 @@ struct ModelConfig : public IFields {
} }
size_t KVCacheCols() const { size_t KVCacheCols() const {
size_t num_layers = layer_configs.size(); const size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize(); return num_layers * layer_configs[0].CacheLayerSize();
} }

View File

@ -556,17 +556,14 @@ static void GenerateT(const ModelConfig& config,
const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx); const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
{
timing_info.generate_start = hwy::platform::Now(); timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env); Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, SampleAndStream(config, runtime_config, weights, sample_token, activations,
activations, qbatch, /*update_pos=*/true, env, non_eos, qbatch, /*update_pos=*/true, env, non_eos, timing_info);
timing_info);
} }
timing_info.NotifyGenerateDone(); timing_info.NotifyGenerateDone();
} }
}
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config, const ModelConfig& config,

View File

@ -226,13 +226,14 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
// ideally already happen in the importer. Called by `ReadFromBlobs`. // ideally already happen in the importer. Called by `ReadFromBlobs`.
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners, void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) { ThreadingContext& ctx) {
// TODO: use 1D parallel-for helper function const size_t cluster_idx = 0;
hwy::ThreadPool& pool = ctx.pools.Pool(); ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator); GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
}); });
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx.allocator); VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
}); });
} }

View File

@ -147,8 +147,7 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("image_size", &VitConfig::image_size) .def_readwrite("image_size", &VitConfig::image_size)
.def_readwrite("layer_configs", &VitConfig::layer_configs); .def_readwrite("layer_configs", &VitConfig::layer_configs);
class_<InternalModelConfig>(py_module, "InternalModelConfig") class_<InternalModelConfig>(py_module, "InternalModelConfig").def(init<>());
.def(init<>());
class_<ModelConfig>(py_module, "ModelConfig") class_<ModelConfig>(py_module, "ModelConfig")
.def(init<>()) .def(init<>())