#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ #include #include #include #include #include "compression/shared.h" #include "gemma/configs.h" namespace gcpp { // Universal tensor information. Holds enough information to construct a // tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the // tensor from the python model with necessary transpose/reshape info. struct TensorInfo { // The name of the tensor in the sbs file std::string name; // Strings to match to the end of the name of the tensor in the python model. std::vector source_names; // Initial reshape shape. Use only as a last resort when input may have // dimensions combined that need to be split before the transpose, as it // defeats the post-transpose shape checking. Normally empty. std::vector preshape; // Transpose axes arg. If the input tensor has more dimensions than axes, // then leading dimensions are collapsed until the number of axes matches. std::vector axes; // Expected final shape of the tensor after reshape/transpose. // Note that this is the shape of the tensor during export, // not the shape of the tensor in the sbs file, as the sbs file // is restricted to 2D tensors. With few exceptions, the sbs file // tensor rows gather all the excess dimensions. See cols_take_extra_dims. std::vector shape; // List of names to concatenate with, used only if multiple tensors are // concatenated into one. The first tensor in the concatenation should have // concat names thus: The first name is the name of the result, and the // tensors with the remaining names are concatenated after this. // The remaining tensors to be concatenated should have just a single // empty string in concat_names to indicate that they have been consumed. std::vector concat_names; // Axis at which to concatenate. size_t concat_axis = 0; // The minimum compression weight type for this tensor. The default is // kNUQ, which provides maximum compression. Other values such as kBF16 // or kF32 can be used to limit the compression to a specific type. Type min_size = Type::kNUQ; // Whether to apply scaled softplus to the data. bool scaled_softplus = false; // Whether the columns or the rows take any extra dimensions. // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30]. // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30]. bool cols_take_extra_dims = false; }; // Universal index of tensor information, which can be built for a specific // layer_idx. class TensorIndex { public: // Builds a list of TensorInfo for the given layer_idx. // If reshape_att is true, the attn_vec_einsum tensor is reshaped. TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx, bool reshape_att); ~TensorIndex() = default; // Returns the TensorInfo whose source_name matches the end of the given path, // or an empty TensorInfo if not found. // NOTE: that the returned TensorInfo is a copy, so that the source // TensorIndex can be destroyed without affecting the returned TensorInfo. TensorInfo TensorInfoFromSourcePath(const std::string& path) const; // Returns the TensorInfo whose name matches the given name, // or an empty TensorInfo if not found. // NOTE: that the returned TensorInfo is a copy, so that the source // TensorIndex can be destroyed without affecting the returned TensorInfo. TensorInfo TensorInfoFromName(const std::string& name) const { const TensorInfo* info = FindName(name); if (info == nullptr) return TensorInfo(); return *info; } // Returns the TensorInfo for the given tensor name, for concise construction // of ModelWeightsPtrs/LayerWeightsPtrs. const TensorInfo* FindName(const std::string& name) const; private: // Config that was used to build the tensor index. const ModelConfig& config_; // Layer that this tensor index is for - either LLM or image. int llm_layer_idx_; int img_layer_idx_; // List of tensor information for this layer. std::vector tensors_; // Map from tensor name to index in tensors_. std::unordered_map name_map_; }; } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_