// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/kv_cache.h"

#include <algorithm>  // std::copy

#include "gemma/configs.h"
#include "util/mat.h"  // ZeroInit
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"  // ZeroBytes

namespace gcpp {

void KVCache::ZeroGriffinCache() {
  if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
  if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
}

// Returns the number of columns per layer for the Griffin conv1d cache: the
// maximum conv1d_width across all layers minus one, or zero if no layer uses
// a conv1d block.
static size_t GriffinConv1dCols(const ModelConfig& config) {
  size_t conv1d_width = 0;
  for (const auto& layer_config : config.layer_configs) {
    conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
  }
  return conv1d_width == 0 ? 0 : conv1d_width - 1;
}

// prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time.
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
    : griffin_layers(
          config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
      griffin_conv1d_cols(GriffinConv1dCols(config)),
      // TODO(patrickms): Add query batching support for Griffin.
      conv1d_cache(
          "conv1d_cache",
          Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
          MatPadding::kOdd),
      rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
                  MatPadding::kOdd) {
  // TODO: move to MatStorageT.
  const size_t size_cache_pos = config.CachePosSize();
  if (size_cache_pos != 0) {
    // Allocate more so that prefill can always access one batch, even if
    // near the end of the sequence.
    seq_len = config.seq_len + prefill_tbatch_size;
    kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
  }
}

KVCache KVCache::Copy(const ModelConfig& weights_config,
                      size_t prefill_tbatch_size) {
  KVCache copy(weights_config, prefill_tbatch_size);

  const size_t size_cache_pos = weights_config.CachePosSize();
  if (size_cache_pos != 0) {
    std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
              copy.kv_cache.get());
  }

  if (conv1d_cache.HasPtr()) {
    CopyMat(conv1d_cache, copy.conv1d_cache);
  }
  if (rglru_cache.HasPtr()) {
    CopyMat(rglru_cache, copy.rglru_cache);
  }

  return copy;
}

}  // namespace gcpp