diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..f6cb8ad --- /dev/null +++ b/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: Google diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 929e140..82b9152 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,22 +1,30 @@ -name: Build +name: build -# Trigger on push or via manual dispatch. -on: [push, workflow_dispatch] +# Trigger on push, pull request, or via manual dispatch. +on: [push, pull_request, workflow_dispatch] jobs: build: - runs-on: ${{matrix.os}} - name: ${{ matrix.os }} ${{ matrix.type }} + runs-on: ${{ matrix.os }} + name: ${{ matrix.os }} (${{ matrix.preset }}) ${{ matrix.build_type }} timeout-minutes: 30 strategy: fail-fast: false matrix: - type: ['Release'] - os: ['ubuntu-latest'] + os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] + build_type: ['Release'] + preset: ['make', 'windows'] + exclude: + - os: ubuntu-latest + preset: windows + - os: macos-latest + preset: windows + - os: windows-latest + preset: make concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.preset }}-${{ matrix.build_type }} cancel-in-progress: true steps: @@ -26,20 +34,23 @@ jobs: - name: ccache uses: hendrikmuhs/ccache-action@v1.2 - # Install CMake - - uses: lukka/get-cmake@latest + - name: Configure CMake + run: > + cmake --preset ${{ matrix.preset }} + -S ${{ github.workspace }} -B ${{ github.workspace }}/build + -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} + -D CMAKE_C_COMPILER_LAUNCHER=ccache + -D CMAKE_CXX_COMPILER_LAUNCHER=ccache - # Build via CMake - # Reference: https://github.com/lukka/run-cmake/blob/v3/action.yml - - name: Build via cmake - uses: lukka/run-cmake@v3 + - name: Build + run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} + + - name: Archive production artifacts + uses: 
actions/upload-artifact@v4 with: - cmakeListsOrSettingsJson: CMakeListsTxtAdvanced - cmakeAppendedArgs: > - -D CMAKE_C_COMPILER_LAUNCHER=ccache - -D CMAKE_CXX_COMPILER_LAUNCHER=ccache - buildWithCMake: true - # Explicitly list build targets here. - # Building "all" includes test executables and takes much longer. - buildWithCMakeArgs: "-- gemma" - buildDirectory: '${{ github.workspace }}/build' + name: gemma-${{ matrix.os }}-${{ matrix.preset }}-${{ matrix.build_type }} + path: | + ${{ github.workspace }}/build/${{ matrix.build_type }}/gemma.exe + ${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib + ${{ github.workspace }}/build/gemma + ${{ github.workspace }}/build/libgemma.a diff --git a/BUILD.bazel b/BUILD.bazel index 190690b..18dad30 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -114,54 +114,3 @@ cc_binary( "//:thread_pool", ], ) - -# copybara:strip_begin -cc_binary( - name = "run_csv", - srcs = [ - "run_csv.cc", - ], - deps = [ - ":app", - ":args", - ":gemma_lib", - "//compression:compress", - # copybara:import_next_line:hwy - "//:hwy", - # copybara:import_next_line:hwy - "//:nanobenchmark", - # copybara:import_next_line:hwy - "//:profiler", - # copybara:import_next_line:hwy - "//:thread_pool", - "//third_party/riegeli/bytes:file_reader", - "//third_party/riegeli/bytes:file_writer", - "//third_party/riegeli/csv:csv_reader", - "//third_party/riegeli/csv:csv_writer", - ], -) - -gensignature( - name = "gemma_sign", - srcs = [":gemma"], -) - -cc_test( - name = "benchmarks", - size = "large", - srcs = [ - "benchmarks.cc", - ], - tags = ["notap"], - deps = [ - ":app", - ":gemma_lib", - "//third_party/benchmark", - # copybara:import_next_line:hwy - "//:hwy", - # copybara:import_next_line:hwy - "//:thread_pool", - ], -) - -# copybara:strip_end diff --git a/CMakeLists.txt b/CMakeLists.txt index 3858968..308e258 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) FetchContent_Declare(highway 
GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) FetchContent_MakeAvailable(highway) -## Note: absl meeds tp be installed by sentencepiece. This will only happen if +## Note: absl needs to be installed by sentencepiece. This will only happen if ## cmake is invoked with -DSPM_ENABLE_SHARED=OFF and -DSPM_ABSL_PROVIDER=module FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) @@ -43,14 +43,13 @@ set(SOURCES util/args.h ) -add_compile_options($<$:-O2>) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() # Allowable types for WEIGHT_TYPE: # float - slow, not recommended -# hwy::bfloat16_t - bfloat16 as impemented by https://github.com/google/highway +# hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway # SfpStream - 8-bit switched floating point (recommended) # NuqStream - experimental, work-in-progress option(WEIGHT_TYPE "Set weight type" "") @@ -68,12 +67,17 @@ target_link_libraries(gemma hwy hwy_contrib sentencepiece) target_include_directories(gemma PRIVATE ./) FetchContent_GetProperties(sentencepiece) target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(gemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(gemma PRIVATE $<$:-Wno-deprecated-declarations>) ## Library Target add_library(libgemma ${SOURCES}) set_property(TARGET libgemma PROPERTY CXX_STANDARD 17) set_target_properties(libgemma PROPERTIES PREFIX "") +set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON) target_include_directories(libgemma PUBLIC ./) target_link_libraries(libgemma hwy hwy_contrib sentencepiece) target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(libgemma 
PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 0000000..5fe13c8 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,59 @@ +{ + "version": 3, + "cmakeMinimumRequired": { + "major": 3, + "minor": 11, + "patch": 0 + }, + "configurePresets": [ + { + "name": "__defaults__", + "hidden": true, + "binaryDir": "${sourceDir}/build" + }, + { + "name": "make", + "inherits": "__defaults__", + "displayName": "Make", + "description": "Unix Makefiles", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build" + }, + { + "name": "windows", + "inherits": "__defaults__", + "displayName": "Windows", + "description": "Visual Studio 2022 with Clang/LLVM frontend", + "generator": "Visual Studio 17 2022", + "toolset": "ClangCL", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + } + ], + "buildPresets": [ + { + "name": "__defaults__", + "hidden": true, + "targets": [ + "gemma", + "libgemma" + ] + }, + { + "name": "make", + "inherits": "__defaults__", + "displayName": "Unix Makefiles", + "configurePreset": "make" + }, + { + "name": "windows", + "inherits": "__defaults__", + "displayName": "Windows", + "configuration": "Release", + "configurePreset": "windows" + } + ] + } diff --git a/DEVELOPERS.md b/DEVELOPERS.md index d06b0f8..f670c49 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -70,3 +70,33 @@ The implementation code is roughly split into 4 layers, from high to low level: 4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of highway) supporting the implementations in (3). + +Besides these layers, supporting utilities are: + +- `compression/` - model compression operations. The 8-bit switched floating + point model conversion is here. +- `util/` - command line argument handling and any other utilities. 
+ +## Style and Formatting + +A `.clang-format` configuration is provided with our defaults; please run source +files through `clang-format` (or a formatter that produces equivalent behavior) +before finalizing a PR for submission. + +## Compile-Time Flags (Advanced) + +There are several compile-time flags to be aware of (note these may or may not +be exposed to the build system): + +- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as + WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream` + (default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to + enable higher-fidelity (but slower) bfloat16 support. This is defined in + `gemma.h`. +- `GEMMA_MAX_SEQLEN` : Sets maximum sequence length to preallocate for the KV + Cache. The default is 4096 tokens but can be overridden. This is not exposed + through `CMakeLists.txt` yet. + +In the medium term both of these will likely be deprecated in favor of handling +options at runtime - allowing for multiple weight compression schemes in a single +build and dynamically resizing the KV cache as needed. diff --git a/README.md b/README.md index e278833..331d96f 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,16 @@ Before starting, you should have installed: least C++17. - `tar` for extracting archives from Kaggle. +Building natively on Windows requires the Visual Studio 2022 Build Tools with the +optional Clang/LLVM C++ frontend (`clang-cl`). 
This can be installed from the +command line with +[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/): + +```sh +winget install --id Kitware.CMake +winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset" +``` + ### Step 1: Obtain model weights and tokenizer from Kaggle Visit [the Gemma model page on @@ -82,7 +92,7 @@ weights enable faster inference. In general, we recommend starting with the | `7b-pt` | 7 billion parameter pre-trained model, bfloat16 | | `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point | -> [!NOTE] +> [!NOTE] > **Important**: We strongly recommend starting off with the `2b-it-sfp` model to > get up and running. @@ -104,9 +114,14 @@ convenient directory location (e.g. the `build/` directory in this repo). The build system uses [CMake](https://cmake.org/). To build the gemma inference runtime, create a build directory and generate the build files using `cmake` -from the top-level project directory. For the 8-bit switched floating point -weights (sfp), run cmake with no options: +from the top-level project directory. Note that if you previously ran `cmake` and are +re-running with a different setting, be sure to clean out the `build/` directory +with `rm -rf build/*` (warning: this will delete any other files in the `build/` +directory). 
+For the 8-bit switched floating point weights (sfp), run cmake with no options: + +#### Unix-like Platforms ```sh cmake -B build ``` @@ -126,17 +141,18 @@ your weights, you can enter the `build/` directory and run `make` to build the `./gemma` executable: ```sh -cd build -make -j [number of parallel threads to use] gemma +# Configure `build` directory +cmake --preset make + +# Build project using make +cmake --build --preset make -j [number of parallel threads to use] ``` Replace `[number of parallel threads to use]` with a number - the number of -cores available on your system is a reasonable heuristic. - -For example, `make -j4 gemma` will build using 4 threads. If this is successful, -you should now have a `gemma` executable in the `build/` directory. If the -`nproc` command is available, you can use `make -j$(nproc) gemma` as a -reasonable default for the number of threads. +cores available on your system is a reasonable heuristic. For example, +`make -j4 gemma` will build using 4 threads. If the `nproc` command is +available, you can use `make -j$(nproc) gemma` as a reasonable default +for the number of threads. If you aren't sure of the right value for the `-j` flag, you can simply run `make gemma` instead and it should still build the `./gemma` executable. @@ -145,6 +161,20 @@ If you aren't sure of the right value for the `-j` flag, you can simply run > On Windows Subsystem for Linux (WSL) users should set the number of > parallel threads to 1. Using a larger number may result in errors. +If the build is successful, you should now have a `gemma` executable in the `build/` directory. + +#### Windows + +```sh +# Configure `build` directory +cmake --preset windows + +# Build project using Visual Studio Build Tools +cmake --build --preset windows -j [number of parallel threads to use] +``` + +If the build is successful, you should now have a `gemma.exe` executable in the `build/Release/` directory. 
+ ### Step 4: Run You can now run `gemma` from inside the `build/` directory. @@ -212,6 +242,21 @@ We're working on a python script to convert a standard model format to `.sbs`, and hope have it available in the next week or so. Follow [this issue](https://github.com/google/gemma.cpp/issues/11) for updates. +**What are some easy ways to make the model run faster?** + +1. Make sure you are using the 8-bit switched floating point `-sfp` models. +2. If you're on a laptop, make sure power mode is set to maximize performance +and power saving mode is **off**. For most laptops, the power saving modes get +activated automatically if the computer is not plugged in. +3. Close other unused cpu-intensive applications. +4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance +cores get engaged. +5. Experiment with the `--num_threads` argument value. Depending on the device, +larger numbers don't always mean better performance. + +We're also working on algorithmic and optimization approaches for faster +inference, stay tuned. + ## Usage `gemma` has different usage modes, controlled by the verbosity flag. @@ -247,7 +292,7 @@ max_tokens : 3072 max_generated_tokens : 2048 *Usage* - Enter an instruction and press enter (%Q quits). + Enter an instruction and press enter (%C resets conversation, %Q quits). *Examples* - Write an email to grandma thanking her for the cookies. @@ -385,6 +430,17 @@ make -j [number of parallel threads to use] libgemma If this is successful, you should now have a `libgemma` library file in the `build/` directory. On Unix platforms, the filename is `libgemma.a`. 
+## Independent Projects Using gemma.cpp + +Some independent projects using gemma.cpp: + +- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python) +- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma) +- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project) + +If you would like to have your project included, feel free to get in touch or +submit a PR with a `README.md` edit. + ## Acknowledgements and Contacts gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com) diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 8d6c1d0..e088fc6 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -13,14 +13,32 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Request POSIX 2008, including `pread()` and `posix_fadvise()`. +#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700 +#undef _XOPEN_SOURCE +#define _XOPEN_SOURCE 700 +#endif +#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809 +#define _POSIX_C_SOURCE 200809 +#endif + +// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c. +#undef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 + // copybara:import_next_line:gemma_cpp #include "compression/blob_store.h" -#include // open #include #include // SEEK_END - unistd isn't enough for IDE. 
#include // O_RDONLY -#include // read, close +#include // open +#if HWY_OS_WIN +#include // read, write, close +#include +#else +#include // read, write, close +#endif #include #include @@ -30,6 +48,54 @@ #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_compiler_arch.h" +namespace { +#if HWY_OS_WIN + +// pread is not supported on Windows +static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) { + HANDLE file = reinterpret_cast(_get_osfhandle(fd)); + if (file == INVALID_HANDLE_VALUE) { + return -1; + } + + OVERLAPPED overlapped = {0}; + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + DWORD bytes_read; + if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return -1; + } + } + + return bytes_read; +} + +// pwrite is not supported on Windows +static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) { + HANDLE file = reinterpret_cast(_get_osfhandle(fd)); + if (file == INVALID_HANDLE_VALUE) { + return -1; + } + + OVERLAPPED overlapped = {0}; + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + DWORD bytes_written; + if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return -1; + } + } + + return bytes_written; +} + +#endif +} // namespace + namespace gcpp { hwy::uint128_t MakeKey(const char* string) { @@ -64,19 +130,31 @@ static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data, } } + struct IO { // Returns size in bytes or 0. 
static uint64_t FileSize(const char* filename) { int fd = open(filename, O_RDONLY); - if (fd >= 0) { - const off_t size = lseek(fd, 0, SEEK_END); - HWY_ASSERT(close(fd) != -1); - if (size != static_cast(-1)) { - return static_cast(size); - } + if (fd < 0) { + return 0; } - return 0; +#if HWY_OS_WIN + const int64_t size = _lseeki64(fd, 0, SEEK_END); + HWY_ASSERT(close(fd) != -1); + if (size < 0) { + return 0; + } +#else + static_assert(sizeof(off_t) == 8, "64-bit off_t required"); + const off_t size = lseek(fd, 0, SEEK_END); + HWY_ASSERT(close(fd) != -1); + if (size == static_cast(-1)) { + return 0; + } +#endif + + return static_cast(size); } static bool Read(int fd, uint64_t offset, uint64_t size, void* to) { @@ -252,10 +330,18 @@ class BlobStore { #pragma pack(pop) BlobError BlobReader::Open(const char* filename) { +#if HWY_OS_WIN + DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN; + HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr, + OPEN_EXISTING, flags, nullptr); + if (file == INVALID_HANDLE_VALUE) return __LINE__; + fd_ = _open_osfhandle(reinterpret_cast(file), _O_RDONLY); +#else fd_ = open(filename, O_RDONLY); +#endif if (fd_ < 0) return __LINE__; -#if _POSIX_C_SOURCE >= 200112L +#if HWY_OS_LINUX // Doubles the readahead window, which seems slightly faster when cached. (void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL); #endif @@ -330,7 +416,15 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, keys_.data(), blobs_.data(), keys_.size()); // Create/replace existing file. 
+#if HWY_OS_WIN + DWORD flags = FILE_ATTRIBUTE_NORMAL; + HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, + flags, nullptr); + if (file == INVALID_HANDLE_VALUE) return __LINE__; + const int fd = _open_osfhandle(reinterpret_cast(file), _O_WRONLY); +#else const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644); +#endif if (fd < 0) return __LINE__; std::atomic_flag err = ATOMIC_FLAG_INIT; @@ -341,6 +435,7 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, err.test_and_set(); } }); + HWY_ASSERT(close(fd) != -1); if (err.test_and_set()) return __LINE__; return 0; } diff --git a/compression/distortion.h b/compression/distortion.h index 8c0742a..5fd778f 100644 --- a/compression/distortion.h +++ b/compression/distortion.h @@ -25,6 +25,8 @@ namespace gcpp { class DistortionStats { public: void Notify(float original, float distorted) { + (void)padding_; // prevent unused member warning + const double l1 = hwy::ScalarAbs(original - distorted); if (l1 > max_l1_) { diff --git a/compression/stats.cc b/compression/stats.cc index 2013422..8e66119 100644 --- a/compression/stats.cc +++ b/compression/stats.cc @@ -114,7 +114,7 @@ std::string Stats::ToString(int exclude) const { pos += ret; } - HWY_ASSERT(pos < sizeof(buf)); + HWY_ASSERT(pos < static_cast(sizeof(buf))); return buf; } diff --git a/configs.h b/configs.h index ebe6220..bf25596 100644 --- a/configs.h +++ b/configs.h @@ -18,11 +18,16 @@ #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ +// Allow changing pre-allocated kv cache size as a compiler flag +#ifndef GEMMA_MAX_SEQLEN +#define GEMMA_MAX_SEQLEN 4096 +#endif // !GEMMA_MAX_SEQLEN + #include namespace gcpp { -static constexpr size_t kSeqLen = 7168; +static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; @@ -31,8 +36,8 @@ struct ConfigGemma7B { static constexpr int kModelDim = 3072; static constexpr int kFFHiddenDim = 16 * 3072 / 2; 
// = 24576 static constexpr int kHeads = 16; - static constexpr int kKVHeads = 16; // standard MHA, no GQA or MQA - static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kKVHeads = 16; // standard MHA + static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = 1; }; @@ -43,8 +48,8 @@ struct ConfigGemma2B { static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; - static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support - static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support + static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = 1; }; diff --git a/gemma.cc b/gemma.cc index 70777ac..4775f89 100644 --- a/gemma.cc +++ b/gemma.cc @@ -633,30 +633,32 @@ void ForEachTensor(const Weights* weights, c_weights.c_final_norm_scale); char name[16]; - for (size_t layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { - Layer* layer = weights ? &weights->layers[layer_idx] : nullptr; - CompressedLayer* c_layer = c_weights.CLayer(layer_idx); + for (int layer_idx = 0; layer_idx < static_cast(TConfig::kLayers); + ++layer_idx) { + const size_t idx = static_cast(layer_idx); + Layer* layer = weights ? &weights->layers[idx] : nullptr; + CompressedLayer* c_layer = c_weights.CLayer(idx); - snprintf(name, sizeof(name), "pre_ff_ns_%lu", layer_idx); + snprintf(name, sizeof(name), "pre_ff_ns_%d", layer_idx); func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr, c_layer->c_pre_ffw_norm_scale); - snprintf(name, sizeof(name), "gating_ein_%lu", layer_idx); + snprintf(name, sizeof(name), "gating_ein_%d", layer_idx); func(name, layer ? 
layer->gating_einsum_w.data() : nullptr, c_layer->c_gating_einsum_w); - snprintf(name, sizeof(name), "linear_w_%lu", layer_idx); + snprintf(name, sizeof(name), "linear_w_%d", layer_idx); func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w); - snprintf(name, sizeof(name), "qkv_ein_%lu", layer_idx); + snprintf(name, sizeof(name), "qkv_ein_%d", layer_idx); func(name, layer ? layer->qkv_einsum_w.data() : nullptr, c_layer->c_qkv_einsum_w); - snprintf(name, sizeof(name), "att_ein_%lu", layer_idx); + snprintf(name, sizeof(name), "att_ein_%d", layer_idx); func(name, layer ? layer->attn_vec_einsum_w.data() : nullptr, c_layer->c_attn_vec_einsum_w); - snprintf(name, sizeof(name), "pre_att_ns_%lu", layer_idx); + snprintf(name, sizeof(name), "pre_att_ns_%d", layer_idx); func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr, c_layer->c_pre_attention_norm_scale); } diff --git a/gemma.h b/gemma.h index 5dc9f62..7195bc9 100644 --- a/gemma.h +++ b/gemma.h @@ -26,15 +26,19 @@ // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream +// copybara:end // copybara:import_next_line:gemma_cpp -#include "configs.h" // kSeqLen +#include "configs.h" // kSeqLen +// copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // ArgsBase +#include "util/args.h" // ArgsBase +// copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" +// copybara:end namespace gcpp { @@ -118,21 +122,22 @@ struct LoaderArgs : public ArgsBase { template void ForEach(const Visitor& visitor) { visitor(tokenizer, "tokenizer", Path(), - "Path name of tokenizer model file. 
(required)"); + "Path name of tokenizer model file.\n Required argument."); visitor( cache, "compressed_weights", Path(), "Path name of compressed weights file, regenerated from `--weights` " "file if " - "the compressed weights file does not exist. (required)"); + "the compressed weights file does not exist.\n Required argument."); visitor(model_type, "model", std::string(), - "Model type - can be 2b-it (2B parameters, instruction-tuned), " - "2b-pt (2B parameters, pretrained), 7b-it (7B parameters, " - "instruction-tuned), or 7b-pt (7B parameters, pretrained). " - "(required)"); + "Model type\n 2b-it (2B parameters, instruction-tuned)\n " + "2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters " + "instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n" + " Required argument."); visitor(model, "weights", Path(), "Path name of model weights (.sbs) file. Only required if " "compressed_weights file is not present and needs to be " - "regenerated. Otherwise, not needed"); + "regenerated. 
This parameter is only required for compressing" + "new model weight exports, otherwise it is not needed."); } }; @@ -186,10 +191,10 @@ struct InferenceArgs : public ArgsBase { visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); visitor(deterministic, "deterministic", false, "Make top-k sampling deterministic", 2); - visitor(multiturn, "multiturn", true, + visitor(multiturn, "multiturn", false, "Multiturn mode (if 0, this clears the KV cache after every " - "interaction without quitting)", - 2); + "interaction without quitting)\n Default : 0 (conversation " + "resets every turn)"); } }; diff --git a/ops.h b/ops.h index db2ae4f..7619b44 100644 --- a/ops.h +++ b/ops.h @@ -214,7 +214,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, size_t size) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - hn::Transform(D(), x, size, [](D d, hn::Vec v) { return Gelu(d, v); }); + hn::Transform(D(), x, size, + [](D d, hn::Vec v) HWY_ATTR { return Gelu(d, v); }); } // out[i] = BF(mul[i] * Gelu(gelu_in[i])) @@ -567,22 +568,41 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size, namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; - using V = hn::Vec; + const size_t N = hn::Lanes(d); - // Find max so we can subtract it below. - const V vmin = hn::Set(d, hwy::LowestValue()); - V max = vmin; - hn::Foreach(d, x, mask_pos, vmin, - [&max](D d, V v) { max = hn::Max(max, v); }); - max = hn::MaxOfLanes(d, max); // broadcast + // Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors + // cannot be lambda-captured. + // TODO(janwas): could be replaced with an hn::Accumulate algo. 
+ const hn::Vec vmin = hn::Set(d, hwy::LowestValue()); + hn::Vec vmax = vmin; + size_t idx = 0; + if (mask_pos >= N) { + for (; idx <= mask_pos - N; idx += N) { + vmax = hn::Max(vmax, LoadU(d, x + idx)); + } + } + vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx)); + vmax = hn::MaxOfLanes(d, vmax); // broadcast // Subtract max (avoid precision loss for large exponents) and exponentiate. - V sum = hn::Zero(d); - hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) { - const V out = hn::Exp(d, hn::Sub(v, max)); + // Also avoid hn::Transform because the additional `sum` output vector cannot + // be captured by a lambda. + hn::Vec sum = hn::Zero(d); + idx = 0; + if (mask_pos >= N) { + for (; idx <= mask_pos - N; idx += N) { + const hn::Vec out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax)); + sum = hn::Add(sum, out); + hn::StoreU(out, d, x + idx); + } + } + if (mask_pos > idx) { + const size_t remaining = mask_pos - idx; + const hn::Vec out = + hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax)); sum = hn::Add(sum, out); - return out; - }); + hn::StoreN(out, d, x + idx, remaining); + } // Normalize to probability distribution const float mul = 1.0f / hn::ReduceSum(d, sum); @@ -601,13 +621,12 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; - using V = hn::Vec; - const V inv_cap = hn::Set(d, 1.0f / cap); - const V vcap = hn::Set(d, cap); + const float inv_cap = 1.0f / cap; - hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec v) { - return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v))); + hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec v) HWY_ATTR { + return hn::Mul(hn::Set(d, cap), + hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); }); } diff --git a/run.cc b/run.cc index 96ba316..507979d 100644 --- a/run.cc +++ b/run.cc @@ -24,12 +24,16 @@ // copybara:import_next_line:gemma_cpp #include "compression/compress.h" +// copybara:end // 
copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma +#include "gemma.h" // Gemma +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/app.h" +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp +// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -39,20 +43,13 @@ namespace gcpp { -void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, - gcpp::AppArgs& app) { - fprintf(stderr, - "\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to " - "specify 3 required model loading arguments: --tokenizer, " - "--compressed_weights, " - "and --model.\n\nModel Loading Arguments\n\n"); - loader.Help(); - fprintf(stderr, "\nInference Arguments\n\n"); - inference.Help(); - fprintf(stderr, "\nApplication Arguments\n\n"); - app.Help(); - fprintf(stderr, "\n\n"); -} +static constexpr std::string_view kAsciiArtBanner = + " __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n" + " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n" + "| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n" + " \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n" + " __/ | | | | |\n" + " |___/ |_| |_|"; void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { loader.Print(app.verbosity); @@ -69,7 +66,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { << std::thread::hardware_concurrency() << std::endl << "Instruction set : " << hwy::TargetName(hwy::DispatchedTarget()) << " (" - << hwy::VectorBytes() * 8 << " bits)" << "\n" + << hwy::VectorBytes() * 8 << " bits)" + << "\n" << "Weight Type : " << gcpp::TypeName(gcpp::WeightT()) << "\n" << "EmbedderInput Type : " @@ -77,9 +75,31 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } } +void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, + gcpp::AppArgs& app) { + std::cerr + << kAsciiArtBanner + 
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" + "==========================================================\n\n" + "To run gemma.cpp, you need to " + "specify 3 required model loading arguments:\n --tokenizer\n " + "--compressed_weights\n" + " --model.\n"; + std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " + "--compressed_weights 2b-it-sfp.sbs --model 2b-it\n"; + std::cerr << "\n*Model Loading Arguments*\n\n"; + loader.Help(); + std::cerr << "\n*Inference Arguments*\n\n"; + inference.Help(); + std::cerr << "\n*Application Arguments*\n\n"; + app.Help(); + std::cerr << "\n"; +} + void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const InferenceArgs& args, - int verbosity, const gcpp::AcceptFunc& accept_token) { + int verbosity, const gcpp::AcceptFunc& accept_token, + std::string& eot_line) { PROFILER_ZONE("Gen.misc"); int abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn @@ -137,7 +157,18 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, if (verbosity >= 1) { std::cout << "> " << std::flush; } - std::getline(std::cin, prompt_string); + + if (eot_line.size() == 0) { + std::getline(std::cin, prompt_string); + } else { + std::string line; + while (std::getline(std::cin, line)) { + if (line == eot_line) { + break; + } + prompt_string += line + "\n"; + } + } } if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { @@ -221,7 +252,12 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { const std::string instructions = "*Usage*\n" - " Enter an instruction and press enter (%Q quits).\n\n" + " Enter an instruction and press enter (%C resets conversation, " + "%Q quits).\n" + + (inference.multiturn == 0 + ? 
std::string(" Since multiturn is set to 0, conversation will " + "automatically reset every turn.\n\n") + : "\n") + "*Examples*\n" " - Write an email to grandma thanking her for the cookies.\n" " - What are some historical attractions to visit around " @@ -230,13 +266,14 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { " - Write a standup comedy bit about GPU programming.\n"; std::cout << "\033[2J\033[1;1H" // clear screen - << banner_ascii_art << "\n\n"; + << kAsciiArtBanner << "\n\n"; ShowConfig(loader, inference, app); std::cout << "\n" << instructions << "\n"; } - ReplGemma(model, pool, inner_pool, inference, app.verbosity, - /*accept_token=*/[](int) { return true; }); + ReplGemma( + model, pool, inner_pool, inference, app.verbosity, + /*accept_token=*/[](int) { return true; }, app.eot_line); } } // namespace gcpp diff --git a/util/app.h b/util/app.h index 966fa41..7f926a5 100644 --- a/util/app.h +++ b/util/app.h @@ -18,14 +18,20 @@ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ +#if HWY_OS_LINUX #include + +#include // IDE does not recognize errno.h as providing errno. +#endif #include +#include #include // std::clamp #include // NOLINT> // copybara:import_next_line:gemma_cpp #include "util/args.h" +// copybara:end #include "hwy/base.h" // HWY_ASSERT namespace gcpp { @@ -36,7 +42,13 @@ static inline void PinThreadToCore(size_t cpu_index) { cpu_set_t cset; // bit array CPU_ZERO(&cset); // clear all CPU_SET(cpu_index, &cset); // set bit indicating which processor to run on. - HWY_ASSERT(0 == sched_setaffinity(0, sizeof(cset), &cset)); + const int err = sched_setaffinity(0, sizeof(cset), &cset); + if (err != 0) { + fprintf(stderr, + "sched_setaffinity returned %d, errno %d. 
Can happen if running in " + "a container; this warning is safe to ignore.\n", + err, errno); + } #else (void)cpu_index; #endif @@ -62,10 +74,10 @@ class AppArgs : public ArgsBase { Path log; // output int verbosity; size_t num_threads; + std::string eot_line; template void ForEach(const Visitor& visitor) { - visitor(log, "log", Path{"/tmp/log.txt"}, "Logging file", 2); visitor(verbosity, "verbosity", 1, "Show verbose developer information\n 0 = only print generation " "output\n 1 = standard user-facing terminal ui\n 2 = show " @@ -73,10 +85,16 @@ class AppArgs : public ArgsBase { 2); visitor(num_threads, "num_threads", kDefaultNumThreads, // see ChooseNumThreads - "Number of threads to use. Default value is set based on an " - "estimate of " - "how many concurrent threads are supported.", + "Number of threads to use.\n Default = Estimate of the " + "number of suupported concurrent threads.", 2); + visitor( + eot_line, "eot_line", std::string(""), + "End of turn line. " + "When you specify this, the prompt will be all lines " + "before the line where only the given string appears.\n Default = " + "When a newline is encountered, that signals the end of the turn.", + 2); } }; diff --git a/util/args.h b/util/args.h index ce03ef2..b9ab985 100644 --- a/util/args.h +++ b/util/args.h @@ -204,7 +204,7 @@ class ArgsBase { } }; -static bool HasHelp(int argc, char* argv[]) { +static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) { // TODO(austinvhuang): handle case insensitivity if (argc == 1) { // no arguments - print help