mirror of https://github.com/google/gemma.cpp.git
parent
afd82376a5
commit
5d1693e806
|
|
@ -539,14 +539,14 @@ cc_library(
|
|||
":weights",
|
||||
"//compression:compress",
|
||||
"//compression:types",
|
||||
"//io:blob_store",
|
||||
"//io",
|
||||
"//io:blob_store",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -421,7 +421,7 @@ struct ModelConfig : public IFields {
|
|||
}
|
||||
|
||||
size_t KVCacheCols() const {
|
||||
size_t num_layers = layer_configs.size();
|
||||
const size_t num_layers = layer_configs.size();
|
||||
return num_layers * layer_configs[0].CacheLayerSize();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -556,17 +556,14 @@ static void GenerateT(const ModelConfig& config,
|
|||
|
||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
|
||||
|
||||
{
|
||||
timing_info.generate_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
SampleAndStream(config, runtime_config, weights, sample_token,
|
||||
activations, qbatch, /*update_pos=*/true, env, non_eos,
|
||||
timing_info);
|
||||
SampleAndStream(config, runtime_config, weights, sample_token, activations,
|
||||
qbatch, /*update_pos=*/true, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const ModelConfig& config,
|
||||
|
|
|
|||
|
|
@ -226,13 +226,14 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
|
|||
// ideally already happen in the importer. Called by `ReadFromBlobs`.
|
||||
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
|
||||
ThreadingContext& ctx) {
|
||||
// TODO: use 1D parallel-for helper function
|
||||
hwy::ThreadPool& pool = ctx.pools.Pool();
|
||||
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
|
||||
const size_t cluster_idx = 0;
|
||||
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
|
||||
[&](uint64_t layer, size_t /*worker*/) {
|
||||
GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
|
||||
});
|
||||
|
||||
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
|
||||
[&](uint64_t layer, size_t /*worker*/) {
|
||||
VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -147,8 +147,7 @@ PYBIND11_MODULE(configs, py_module) {
|
|||
.def_readwrite("image_size", &VitConfig::image_size)
|
||||
.def_readwrite("layer_configs", &VitConfig::layer_configs);
|
||||
|
||||
class_<InternalModelConfig>(py_module, "InternalModelConfig")
|
||||
.def(init<>());
|
||||
class_<InternalModelConfig>(py_module, "InternalModelConfig").def(init<>());
|
||||
|
||||
class_<ModelConfig>(py_module, "ModelConfig")
|
||||
.def(init<>())
|
||||
|
|
|
|||
Loading…
Reference in New Issue