Merge branch 'google:main' into main

This commit is contained in:
Sascha Ronnie Daoudia 2024-02-29 20:50:31 +01:00 committed by GitHub
commit 5d72c911fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 465 additions and 172 deletions

1
.clang-format Normal file
View File

@ -0,0 +1 @@
BasedOnStyle: Google

View File

@ -1,22 +1,30 @@
name: Build
name: build
# Trigger on push or via manual dispatch.
on: [push, workflow_dispatch]
# Trigger on push, pull request, or via manual dispatch.
on: [push, pull_request, workflow_dispatch]
jobs:
build:
runs-on: ${{matrix.os}}
name: ${{ matrix.os }} ${{ matrix.type }}
runs-on: ${{ matrix.os }}
name: ${{ matrix.os }} (${{ matrix.preset }}) ${{ matrix.build_type }}
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
type: ['Release']
os: ['ubuntu-latest']
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
build_type: ['Release']
preset: ['make', 'windows']
exclude:
- os: ubuntu-latest
preset: windows
- os: macos-latest
preset: windows
- os: windows-latest
preset: make
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.preset }}-${{ matrix.build_type }}
cancel-in-progress: true
steps:
@ -26,20 +34,23 @@ jobs:
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2
# Install CMake
- uses: lukka/get-cmake@latest
- name: Configure CMake
run: >
cmake --preset ${{ matrix.preset }}
-S ${{ github.workspace }} -B ${{ github.workspace }}/build
-D CMAKE_BUILD_TYPE=${{ matrix.build_type }}
-D CMAKE_C_COMPILER_LAUNCHER=ccache
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
# Build via CMake
# Reference: https://github.com/lukka/run-cmake/blob/v3/action.yml
- name: Build via cmake
uses: lukka/run-cmake@v3
- name: Build
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }}
- name: Archive production artifacts
uses: actions/upload-artifact@v4
with:
cmakeListsOrSettingsJson: CMakeListsTxtAdvanced
cmakeAppendedArgs: >
-D CMAKE_C_COMPILER_LAUNCHER=ccache
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
buildWithCMake: true
# Explicitly list build targets here.
# Building "all" includes test executables and takes much longer.
buildWithCMakeArgs: "-- gemma"
buildDirectory: '${{ github.workspace }}/build'
name: gemma-${{ matrix.os }}-${{ matrix.preset }}-${{ matrix.build_type }}
path: |
${{ github.workspace }}/build/${{ matrix.build_type }}/gemma.exe
${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib
${{ github.workspace }}/build/gemma
${{ github.workspace }}/build/libgemma.a

View File

@ -114,54 +114,3 @@ cc_binary(
"//:thread_pool",
],
)
# copybara:strip_begin
cc_binary(
name = "run_csv",
srcs = [
"run_csv.cc",
],
deps = [
":app",
":args",
":gemma_lib",
"//compression:compress",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"//third_party/riegeli/bytes:file_reader",
"//third_party/riegeli/bytes:file_writer",
"//third_party/riegeli/csv:csv_reader",
"//third_party/riegeli/csv:csv_writer",
],
)
gensignature(
name = "gemma_sign",
srcs = [":gemma"],
)
cc_test(
name = "benchmarks",
size = "large",
srcs = [
"benchmarks.cc",
],
tags = ["notap"],
deps = [
":app",
":gemma_lib",
"//third_party/benchmark",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
],
)
# copybara:strip_end

View File

@ -24,7 +24,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
## Note: absl meeds tp be installed by sentencepiece. This will only happen if
## Note: absl needs to be installed by sentencepiece. This will only happen if
## cmake is invoked with -DSPM_ENABLE_SHARED=OFF and -DSPM_ABSL_PROVIDER=module
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
@ -43,14 +43,13 @@ set(SOURCES
util/args.h
)
add_compile_options($<$<CONFIG:Release>:-O2>)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
# Allowable types for WEIGHT_TYPE:
# float - slow, not recommended
# hwy::bfloat16_t - bfloat16 as impemented by https://github.com/google/highway
# hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway
# SfpStream - 8-bit switched floating point (recommended)
# NuqStream - experimental, work-in-progress
option(WEIGHT_TYPE "Set weight type" "")
@ -68,12 +67,17 @@ target_link_libraries(gemma hwy hwy_contrib sentencepiece)
target_include_directories(gemma PRIVATE ./)
FetchContent_GetProperties(sentencepiece)
target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR})
target_compile_definitions(gemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(gemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
## Library Target
add_library(libgemma ${SOURCES})
set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
set_target_properties(libgemma PROPERTIES PREFIX "")
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(libgemma PUBLIC ./)
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)

59
CMakePresets.json Normal file
View File

@ -0,0 +1,59 @@
{
"version": 3,
"cmakeMinimumRequired": {
"major": 3,
"minor": 11,
"patch": 0
},
"configurePresets": [
{
"name": "__defaults__",
"hidden": true,
"binaryDir": "${sourceDir}/build"
},
{
"name": "make",
"inherits": "__defaults__",
"displayName": "Make",
"description": "Unix Makefiles",
"generator": "Unix Makefiles",
"binaryDir": "${sourceDir}/build"
},
{
"name": "windows",
"inherits": "__defaults__",
"displayName": "Windows",
"description": "Visual Studio 2022 with Clang/LLVM frontend",
"generator": "Visual Studio 17 2022",
"toolset": "ClangCL",
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Windows"
}
}
],
"buildPresets": [
{
"name": "__defaults__",
"hidden": true,
"targets": [
"gemma",
"libgemma"
]
},
{
"name": "make",
"inherits": "__defaults__",
"displayName": "Unix Makefiles",
"configurePreset": "make"
},
{
"name": "windows",
"inherits": "__defaults__",
"displayName": "Windows",
"configuration": "Release",
"configurePreset": "windows"
}
]
}

View File

@ -70,3 +70,33 @@ The implementation code is roughly split into 4 layers, from high to low level:
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
highway) supporting the implementations in (3).
Besides these layers, supporting utilities are:
- `compression/` - model compression operations. The 8-bit switched floating
point model conversion is here.
- `util/` - command line argument handling and any other utilities.
## Style and Formatting
A `.clang-format` configuration is provided with our defaults, please run source
files through `clang-format` (or a formatter that produces equivalent behavior)
before finalizing PR for submission.
## Compile-Time Flags (Advanced)
There are several compile-time flags to be aware of (note these may or may not
be exposed to the build system):
- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
(default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
enable for higher-fidelity (but slower) bfloat16 support. This is defined in
`gemma.h`.
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
Cache. The default is 4096 tokens but can be overridden. This is not exposed
through `CMakeLists.txt` yet.
In the medium term both of these will likely be deprecated in favor of handling
options at runtime - allowing for multiple weight compression schemes in a single
build and dynamically resizing the KV cache as needed.

View File

@ -55,6 +55,16 @@ Before starting, you should have installed:
least C++17.
- `tar` for extracting archives from Kaggle.
Building natively on Windows requires the Visual Studio 2022 Build Tools with the
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
command line with
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):
```sh
winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
```
### Step 1: Obtain model weights and tokenizer from Kaggle
Visit [the Gemma model page on
@ -82,7 +92,7 @@ weights enable faster inference. In general, we recommend starting with the
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
> [!NOTE]
> [!NOTE]
> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
> get up and running.
@ -104,9 +114,14 @@ convenient directory location (e.g. the `build/` directory in this repo).
The build system uses [CMake](https://cmake.org/). To build the gemma inference
runtime, create a build directory and generate the build files using `cmake`
from the top-level project directory. For the 8-bit switched floating point
weights (sfp), run cmake with no options:
from the top-level project directory. Note if you previously ran `cmake` and are
re-running with a different setting, be sure to clean out the `build/` directory
with `rm -rf build/*` (warning: this will delete any other files in the `build/`
directory).
For the 8-bit switched floating point weights (sfp), run cmake with no options:
#### Unix-like Platforms
```sh
cmake -B build
```
@ -126,17 +141,18 @@ your weights, you can enter the `build/` directory and run `make` to build the
`./gemma` executable:
```sh
cd build
make -j [number of parallel threads to use] gemma
# Configure `build` directory
cmake --preset make
# Build project using make
cmake --build --preset make -j [number of parallel threads to use]
```
Replace `[number of parallel threads to use]` with a number - the number of
cores available on your system is a reasonable heuristic.
For example, `make -j4 gemma` will build using 4 threads. If this is successful,
you should now have a `gemma` executable in the `build/` directory. If the
`nproc` command is available, you can use `make -j$(nproc) gemma` as a
reasonable default for the number of threads.
cores available on your system is a reasonable heuristic. For example,
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
available, you can use `make -j$(nproc) gemma` as a reasonable default
for the number of threads.
If you aren't sure of the right value for the `-j` flag, you can simply run
`make gemma` instead and it should still build the `./gemma` executable.
@ -145,6 +161,20 @@ If you aren't sure of the right value for the `-j` flag, you can simply run
> On Windows Subsystem for Linux (WSL) users should set the number of
> parallel threads to 1. Using a larger number may result in errors.
If the build is successful, you should now have a `gemma` executable in the `build/` directory.
#### Windows
```sh
# Configure `build` directory
cmake --preset windows
# Build project using Visual Studio Build Tools
cmake --build --preset windows -j [number of parallel threads to use]
```
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
### Step 4: Run
You can now run `gemma` from inside the `build/` directory.
@ -212,6 +242,21 @@ We're working on a python script to convert a standard model format to `.sbs`,
and hope to have it available in the next week or so. Follow [this
issue](https://github.com/google/gemma.cpp/issues/11) for updates.
**What are some easy ways to make the model run faster?**
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
2. If you're on a laptop, make sure power mode is set to maximize performance
and saving mode is **off**. For most laptops, the power saving modes get
activated automatically if the computer is not plugged in.
3. Close other unused cpu-intensive applications.
4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
cores get engaged.
5. Experiment with the `--num_threads` argument value. Depending on the device,
larger numbers don't always mean better performance.
We're also working on algorithmic and optimization approaches for faster
inference, stay tuned.
## Usage
`gemma` has different usage modes, controlled by the verbosity flag.
@ -247,7 +292,7 @@ max_tokens : 3072
max_generated_tokens : 2048
*Usage*
Enter an instruction and press enter (%Q quits).
Enter an instruction and press enter (%C reset conversation, %Q quits).
*Examples*
- Write an email to grandma thanking her for the cookies.
@ -385,6 +430,17 @@ make -j [number of parallel threads to use] libgemma
If this is successful, you should now have a `libgemma` library file in the
`build/` directory. On Unix platforms, the filename is `libgemma.a`.
## Independent Projects Using gemma.cpp
Some independent projects using gemma.cpp:
- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python)
- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma)
- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project)
If you would like to have your project included, feel free to get in touch or
submit a PR with a `README.md` edit.
## Acknowledgements and Contacts
gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)

View File

@ -13,14 +13,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
// copybara:import_next_line:gemma_cpp
#include "compression/blob_store.h"
#include <fcntl.h> // open
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, close
#include <fcntl.h> // open
#if HWY_OS_WIN
#include <io.h> // read, write, close
#include <fileapi.h>
#else
#include <unistd.h> // read, write, close
#endif
#include <atomic>
#include <vector>
@ -30,6 +48,54 @@
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
namespace {
#if HWY_OS_WIN
// pread is not supported on Windows
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_read;
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_read;
}
// pwrite is not supported on Windows
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_written;
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_written;
}
#endif
} // namespace
namespace gcpp {
hwy::uint128_t MakeKey(const char* string) {
@ -64,19 +130,31 @@ static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
}
}
struct IO {
// Returns size in bytes or 0.
static uint64_t FileSize(const char* filename) {
int fd = open(filename, O_RDONLY);
if (fd >= 0) {
const off_t size = lseek(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size != static_cast<off_t>(-1)) {
return static_cast<uint64_t>(size);
}
if (fd < 0) {
return 0;
}
return 0;
#if HWY_OS_WIN
const int64_t size = _lseeki64(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size < 0) {
return 0;
}
#else
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size == static_cast<off_t>(-1)) {
return 0;
}
#endif
return static_cast<uint64_t>(size);
}
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
@ -252,10 +330,18 @@ class BlobStore {
#pragma pack(pop)
BlobError BlobReader::Open(const char* filename) {
#if HWY_OS_WIN
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN;
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr,
OPEN_EXISTING, flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return __LINE__;
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
#else
fd_ = open(filename, O_RDONLY);
#endif
if (fd_ < 0) return __LINE__;
#if _POSIX_C_SOURCE >= 200112L
#if HWY_OS_LINUX
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
#endif
@ -330,7 +416,15 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
keys_.data(), blobs_.data(), keys_.size());
// Create/replace existing file.
#if HWY_OS_WIN
DWORD flags = FILE_ATTRIBUTE_NORMAL;
HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS,
flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return __LINE__;
const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
#else
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
#endif
if (fd < 0) return __LINE__;
std::atomic_flag err = ATOMIC_FLAG_INIT;
@ -341,6 +435,7 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
err.test_and_set();
}
});
HWY_ASSERT(close(fd) != -1);
if (err.test_and_set()) return __LINE__;
return 0;
}

View File

@ -25,6 +25,8 @@ namespace gcpp {
class DistortionStats {
public:
void Notify(float original, float distorted) {
(void)padding_; // prevent unused member warning
const double l1 = hwy::ScalarAbs(original - distorted);
if (l1 > max_l1_) {

View File

@ -114,7 +114,7 @@ std::string Stats::ToString(int exclude) const {
pos += ret;
}
HWY_ASSERT(pos < sizeof(buf));
HWY_ASSERT(pos < static_cast<int>(sizeof(buf)));
return buf;
}

View File

@ -18,11 +18,16 @@
#ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096
#endif // !GEMMA_MAX_SEQLEN
#include <stddef.h>
namespace gcpp {
static constexpr size_t kSeqLen = 7168;
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
struct ConfigGemma7B {
static constexpr int kSeqLen = gcpp::kSeqLen;
@ -31,8 +36,8 @@ struct ConfigGemma7B {
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA, no GQA or MQA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1;
};
@ -43,8 +48,8 @@ struct ConfigGemma2B {
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1;
};

View File

@ -633,30 +633,32 @@ void ForEachTensor(const Weights<TConfig>* weights,
c_weights.c_final_norm_scale);
char name[16];
for (size_t layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
Layer<TConfig>* layer = weights ? &weights->layers[layer_idx] : nullptr;
CompressedLayer<TConfig>* c_layer = c_weights.CLayer(layer_idx);
for (int layer_idx = 0; layer_idx < static_cast<int>(TConfig::kLayers);
++layer_idx) {
const size_t idx = static_cast<size_t>(layer_idx);
Layer<TConfig>* layer = weights ? &weights->layers[idx] : nullptr;
CompressedLayer<TConfig>* c_layer = c_weights.CLayer(idx);
snprintf(name, sizeof(name), "pre_ff_ns_%lu", layer_idx);
snprintf(name, sizeof(name), "pre_ff_ns_%d", layer_idx);
func(name, layer ? layer->pre_ffw_norm_scale.data() : nullptr,
c_layer->c_pre_ffw_norm_scale);
snprintf(name, sizeof(name), "gating_ein_%lu", layer_idx);
snprintf(name, sizeof(name), "gating_ein_%d", layer_idx);
func(name, layer ? layer->gating_einsum_w.data() : nullptr,
c_layer->c_gating_einsum_w);
snprintf(name, sizeof(name), "linear_w_%lu", layer_idx);
snprintf(name, sizeof(name), "linear_w_%d", layer_idx);
func(name, layer ? layer->linear_w.data() : nullptr, c_layer->c_linear_w);
snprintf(name, sizeof(name), "qkv_ein_%lu", layer_idx);
snprintf(name, sizeof(name), "qkv_ein_%d", layer_idx);
func(name, layer ? layer->qkv_einsum_w.data() : nullptr,
c_layer->c_qkv_einsum_w);
snprintf(name, sizeof(name), "att_ein_%lu", layer_idx);
snprintf(name, sizeof(name), "att_ein_%d", layer_idx);
func(name, layer ? layer->attn_vec_einsum_w.data() : nullptr,
c_layer->c_attn_vec_einsum_w);
snprintf(name, sizeof(name), "pre_att_ns_%lu", layer_idx);
snprintf(name, sizeof(name), "pre_att_ns_%d", layer_idx);
func(name, layer ? layer->pre_attention_norm_scale.data() : nullptr,
c_layer->c_pre_attention_norm_scale);
}

29
gemma.h
View File

@ -26,15 +26,19 @@
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "configs.h" // kSeqLen
#include "configs.h" // kSeqLen
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase
#include "util/args.h" // ArgsBase
// copybara:end
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// copybara:end
namespace gcpp {
@ -118,21 +122,22 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file. (required)");
"Path name of tokenizer model file.\n Required argument.");
visitor(
cache, "compressed_weights", Path(),
"Path name of compressed weights file, regenerated from `--weights` "
"file if "
"the compressed weights file does not exist. (required)");
"the compressed weights file does not exist.\n Required argument.");
visitor(model_type, "model", std::string(),
"Model type - can be 2b-it (2B parameters, instruction-tuned), "
"2b-pt (2B parameters, pretrained), 7b-it (7B parameters, "
"instruction-tuned), or 7b-pt (7B parameters, pretrained). "
"(required)");
"Model type\n 2b-it (2B parameters, instruction-tuned)\n "
"2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters "
"instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n"
" Required argument.");
visitor(model, "weights", Path(),
"Path name of model weights (.sbs) file. Only required if "
"compressed_weights file is not present and needs to be "
"regenerated. Otherwise, not needed");
"regenerated. This parameter is only required for compressing"
"new model weight exports, otherwise it is not needed.");
}
};
@ -186,10 +191,10 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", true,
visitor(multiturn, "multiturn", false,
"Multiturn mode (if 0, this clears the KV cache after every "
"interaction without quitting)",
2);
"interaction without quitting)\n Default : 0 (conversation "
"resets every turn)");
}
};

55
ops.h
View File

@ -214,7 +214,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
hn::Transform(D(), x, size, [](D d, hn::Vec<D> v) { return Gelu(d, v); });
hn::Transform(D(), x, size,
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
}
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
@ -567,22 +568,41 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size,
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
using V = hn::Vec<D>;
const size_t N = hn::Lanes(d);
// Find max so we can subtract it below.
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V max = vmin;
hn::Foreach(d, x, mask_pos, vmin,
[&max](D d, V v) { max = hn::Max(max, v); });
max = hn::MaxOfLanes(d, max); // broadcast
// Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors
// cannot be lambda-captured.
// TODO(janwas): could be replaced with an hn::Accumulate algo.
const hn::Vec<D> vmin = hn::Set(d, hwy::LowestValue<float>());
hn::Vec<D> vmax = vmin;
size_t idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
vmax = hn::Max(vmax, LoadU(d, x + idx));
}
}
vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx));
vmax = hn::MaxOfLanes(d, vmax); // broadcast
// Subtract max (avoid precision loss for large exponents) and exponentiate.
V sum = hn::Zero(d);
hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) {
const V out = hn::Exp(d, hn::Sub(v, max));
// Also avoid hn::Transform because the additional `sum` output vector cannot
// be captured by a lambda.
hn::Vec<D> sum = hn::Zero(d);
idx = 0;
if (mask_pos >= N) {
for (; idx <= mask_pos - N; idx += N) {
const hn::Vec<D> out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax));
sum = hn::Add(sum, out);
hn::StoreU(out, d, x + idx);
}
}
if (mask_pos > idx) {
const size_t remaining = mask_pos - idx;
const hn::Vec<D> out =
hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax));
sum = hn::Add(sum, out);
return out;
});
hn::StoreN(out, d, x + idx, remaining);
}
// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
@ -601,13 +621,12 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
using V = hn::Vec<D>;
const V inv_cap = hn::Set(d, 1.0f / cap);
const V vcap = hn::Set(d, cap);
const float inv_cap = 1.0f / cap;
hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec<D> v) {
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v)));
hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(hn::Set(d, cap),
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
});
}

81
run.cc
View File

@ -24,12 +24,16 @@
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
#include "gemma.h" // Gemma
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp
// copybara:end
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -39,20 +43,13 @@
namespace gcpp {
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
fprintf(stderr,
"\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
"specify 3 required model loading arguments: --tokenizer, "
"--compressed_weights, "
"and --model.\n\nModel Loading Arguments\n\n");
loader.Help();
fprintf(stderr, "\nInference Arguments\n\n");
inference.Help();
fprintf(stderr, "\nApplication Arguments\n\n");
app.Help();
fprintf(stderr, "\n\n");
}
static constexpr std::string_view kAsciiArtBanner =
" __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n"
" / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
"| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |\n"
" \\__, |\\___|_| |_| |_|_| |_| |_|\\__,_(_)___| .__/| .__/\n"
" __/ | | | | |\n"
" |___/ |_| |_|";
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
@ -69,7 +66,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
<< std::thread::hardware_concurrency() << std::endl
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
<< hwy::VectorBytes() * 8 << " bits)"
<< "\n"
<< "Weight Type : "
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
<< "EmbedderInput Type : "
@ -77,9 +75,31 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
std::cerr
<< kAsciiArtBanner
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run gemma.cpp, you need to "
"specify 3 required model loading arguments:\n --tokenizer\n "
"--compressed_weights\n"
" --model.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--compressed_weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
int verbosity, const gcpp::AcceptFunc& accept_token) {
int verbosity, const gcpp::AcceptFunc& accept_token,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
@ -137,7 +157,18 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
if (verbosity >= 1) {
std::cout << "> " << std::flush;
}
std::getline(std::cin, prompt_string);
if (eot_line.size() == 0) {
std::getline(std::cin, prompt_string);
} else {
std::string line;
while (std::getline(std::cin, line)) {
if (line == eot_line) {
break;
}
prompt_string += line + "\n";
}
}
}
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
@ -221,7 +252,12 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
const std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%Q quits).\n\n"
" Enter an instruction and press enter (%C resets conversation, "
"%Q quits).\n" +
(inference.multiturn == 0
? std::string(" Since multiturn is set to 0, conversation will "
"automatically reset every turn.\n\n")
: "\n") +
"*Examples*\n"
" - Write an email to grandma thanking her for the cookies.\n"
" - What are some historical attractions to visit around "
@ -230,13 +266,14 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
" - Write a standup comedy bit about GPU programming.\n";
std::cout << "\033[2J\033[1;1H" // clear screen
<< banner_ascii_art << "\n\n";
<< kAsciiArtBanner << "\n\n";
ShowConfig(loader, inference, app);
std::cout << "\n" << instructions << "\n";
}
ReplGemma(model, pool, inner_pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; });
ReplGemma(
model, pool, inner_pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line);
}
} // namespace gcpp

View File

@ -18,14 +18,20 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#if HWY_OS_LINUX
#include <sched.h>
#include <cerrno> // IDE does not recognize errno.h as providing errno.
#endif
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::clamp
#include <thread> // NOLINT>
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
// copybara:end
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
@ -36,7 +42,13 @@ static inline void PinThreadToCore(size_t cpu_index) {
cpu_set_t cset; // bit array
CPU_ZERO(&cset); // clear all
CPU_SET(cpu_index, &cset); // set bit indicating which processor to run on.
HWY_ASSERT(0 == sched_setaffinity(0, sizeof(cset), &cset));
const int err = sched_setaffinity(0, sizeof(cset), &cset);
if (err != 0) {
fprintf(stderr,
"sched_setaffinity returned %d, errno %d. Can happen if running in "
"a container; this warning is safe to ignore.\n",
err, errno);
}
#else
(void)cpu_index;
#endif
@ -62,10 +74,10 @@ class AppArgs : public ArgsBase<AppArgs> {
Path log; // output
int verbosity;
size_t num_threads;
std::string eot_line;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(log, "log", Path{"/tmp/log.txt"}, "Logging file", 2);
visitor(verbosity, "verbosity", 1,
"Show verbose developer information\n 0 = only print generation "
"output\n 1 = standard user-facing terminal ui\n 2 = show "
@ -73,10 +85,16 @@ class AppArgs : public ArgsBase<AppArgs> {
2);
visitor(num_threads, "num_threads",
kDefaultNumThreads, // see ChooseNumThreads
"Number of threads to use. Default value is set based on an "
"estimate of "
"how many concurrent threads are supported.",
"Number of threads to use.\n Default = Estimate of the "
"number of suupported concurrent threads.",
2);
visitor(
eot_line, "eot_line", std::string(""),
"End of turn line. "
"When you specify this, the prompt will be all lines "
"before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.",
2);
}
};

View File

@ -204,7 +204,7 @@ class ArgsBase {
}
};
static bool HasHelp(int argc, char* argv[]) {
static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) {
// TODO(austinvhuang): handle case insensitivity
if (argc == 1) {
// no arguments - print help