diff --git a/BUILD.bazel b/BUILD.bazel index 62f2f5c..cbfb342 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -539,14 +539,14 @@ cc_library( ":weights", "//compression:compress", "//compression:types", - "//io:blob_store", "//io", + "//io:blob_store", "//paligemma:image", "@highway//:hwy", - "@highway//hwy/contrib/sort:vqsort", "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", + "@highway//hwy/contrib/sort:vqsort", ], ) diff --git a/gemma/configs.h b/gemma/configs.h index a1cd902..0c93e30 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -421,7 +421,7 @@ struct ModelConfig : public IFields { } size_t KVCacheCols() const { - size_t num_layers = layer_configs.size(); + const size_t num_layers = layer_configs.size(); return num_layers * layer_configs[0].CacheLayerSize(); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a0949fe..a7e73ca 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -556,16 +556,13 @@ static void GenerateT(const ModelConfig& config, const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx); - { - timing_info.generate_start = hwy::platform::Now(); - for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { - Transformer(config, runtime_config, weights, activations, qbatch, env); - SampleAndStream(config, runtime_config, weights, sample_token, - activations, qbatch, /*update_pos=*/true, env, non_eos, - timing_info); - } - timing_info.NotifyGenerateDone(); + timing_info.generate_start = hwy::platform::Now(); + for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { + Transformer(config, runtime_config, weights, activations, qbatch, env); + SampleAndStream(config, runtime_config, weights, sample_token, activations, + qbatch, /*update_pos=*/true, env, non_eos, timing_info); } + timing_info.NotifyGenerateDone(); } void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, diff --git a/gemma/weights.cc b/gemma/weights.cc index ca1cebc..3d1d43e 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -226,15 +226,16 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) { // ideally already happen in the importer. Called by `ReadFromBlobs`. void WeightsPtrs::Fixup(std::vector& mat_owners, ThreadingContext& ctx) { - // TODO: use 1D parallel-for helper function - hwy::ThreadPool& pool = ctx.pools.Pool(); - pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - GetLayer(layer)->Fixup(mat_owners, ctx.allocator); - }); + const size_t cluster_idx = 0; + ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, + [&](uint64_t layer, size_t /*worker*/) { + GetLayer(layer)->Fixup(mat_owners, ctx.allocator); + }); - pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - VitLayer(layer)->Fixup(mat_owners, ctx.allocator); - }); + ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, + [&](uint64_t layer, size_t /*worker*/) { + VitLayer(layer)->Fixup(mat_owners, ctx.allocator); + }); } std::vector WeightsPtrs::AddTensorDataToWriter( diff --git a/python/configs.cc b/python/configs.cc index 36cd314..f8121bf 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -147,8 +147,7 @@ PYBIND11_MODULE(configs, py_module) { .def_readwrite("image_size", &VitConfig::image_size) .def_readwrite("layer_configs", &VitConfig::layer_configs); - class_(py_module, "InternalModelConfig") - .def(init<>()); + class_(py_module, "InternalModelConfig").def(init<>()); class_(py_module, "ModelConfig") .def(init<>())