From 9cdc9223bce51a88de74022f33666309556f14c6 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 27 Feb 2024 14:22:02 -0500 Subject: [PATCH] clean up formatting after 129e66ada2b4e461bdf28b88b70cd2465cb213e4, add .clang-format defaults, minor updates to DEVELOPERS doc --- .clang-format | 235 ++++++++++++++++++++++++++++++++++++++++++++++++++ DEVELOPERS.md | 12 +++ configs.h | 18 ++-- gemma.h | 51 +++++------ 4 files changed, 282 insertions(+), 34 deletions(-) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..c8f8dba --- /dev/null +++ b/.clang-format @@ -0,0 +1,235 @@ +--- +Language: Cpp +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignArrayOfStructures: None +AlignConsecutiveAssignments: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: true +AlignConsecutiveBitFields: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveDeclarations: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveMacros: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveShortCaseStatements: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCaseColons: false +AlignEscapedNewlines: Right +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 0 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortEnumsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: All +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: MultiLine +AttributeMacros: + - __capability +BinPackArguments: true +BinPackParameters: true +BitFieldColonSpacing: Both +BraceWrapping: + AfterCaseLabel: false + AfterClass: false + AfterControlStatement: Never + AfterEnum: false + AfterExternBlock: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakAfterAttributes: Never +BreakAfterJavaFieldAnnotations: false +BreakArrays: true +BreakBeforeBinaryOperators: None +BreakBeforeConceptDeclarations: Always +BreakBeforeBraces: Attach +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakInheritanceList: BeforeColon +BreakStringLiterals: true +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IfMacros: + - KJ_IF_MAYBE +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + SortPriority: 0 + CaseSensitive: false + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + SortPriority: 0 + CaseSensitive: false + - Regex: '.*' + Priority: 1 + SortPriority: 0 + CaseSensitive: false +IncludeIsMainRegex: '(Test)?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentCaseBlocks: false +IndentCaseLabels: false +IndentExternBlock: AfterExternBlock +IndentGotoLabels: true +IndentPPDirectives: None +IndentRequiresClause: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +InsertBraces: false +InsertNewlineAtEOF: false +InsertTrailingCommas: None +IntegerLiteralSeparator: + Binary: 0 + BinaryMinDigits: 0 + Decimal: 0 + DecimalMinDigits: 0 + Hex: 0 + HexMinDigits: 0 +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +KeepEmptyLinesAtEOF: false +LambdaBodyIndentation: Signature +LineEnding: DeriveLF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCBreakBeforeNestedBlockParam: true +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PackConstructorInitializers: BinPack +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakOpenParenthesis: 0 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyIndentedWhitespace: 0 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Left +PPIndentWidth: -1 +QualifierAlignment: Leave +ReferenceAlignment: Left +ReflowComments: true +RemoveBracesLLVM: false +RemoveParentheses: Leave +RemoveSemicolon: false +RequiresClausePosition: OwnLine +RequiresExpressionIndentation: OuterScope +SeparateDefinitionBlocks: Leave +ShortNamespaceLines: 1 +SortIncludes: CaseSensitive +SortJavaStaticImport: Before +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceAroundPointerQualifiers: Default +SpaceBeforeAssignmentOperators: true +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeJsonColon: false +SpaceBeforeParens: ControlStatements +SpaceBeforeParensOptions: + AfterControlStatements: true + AfterForeachMacros: true + AfterFunctionDefinitionName: false + AfterFunctionDeclarationName: false + AfterIfMacros: true + AfterOverloadedOperator: false + AfterRequiresInClause: false + AfterRequiresInExpression: false + BeforeNonEmptyParentheses: false +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false +SpaceInEmptyBlock: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: Never +SpacesInContainerLiterals: true +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParens: Never +SpacesInParensOptions: + InCStyleCasts: false + InConditionalStatements: false + InEmptyParentheses: false + Other: false +SpacesInSquareBrackets: false +Standard: Latest +StatementAttributeLikeMacros: + - Q_EMIT +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 8 +UseTab: Never +VerilogBreakBetweenInstancePorts: true +WhitespaceSensitiveMacros: + - BOOST_PP_STRINGIZE + - CF_SWIFT_NAME + - NS_SWIFT_NAME + - PP_STRINGIZE + - STRINGIZE +... + diff --git a/DEVELOPERS.md b/DEVELOPERS.md index bdc02c0..7aad9d8 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -71,6 +71,18 @@ The implementation code is roughly split into 4 layers, from high to low level: 4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of highway) supporting the implementations in (3). +Besides these layers, supporting utilities are: + +- `compression/` - model compression operations. the 8-bit switched floating + point model conversion is here. +- `util/` - command line argument handling and any other utilities. + +## Style and Formatting + +A `.clang-format` configuration is provided with our defaults, please run source +files through `clang-format` (or a formatter that produces equivalent behavior) +before finalizing PR for submission. + ## Compile-Time Flags (Advanced) There are several compile-time flags to be aware of (note these may or may not diff --git a/configs.h b/configs.h index 4be5f75..bf25596 100644 --- a/configs.h +++ b/configs.h @@ -21,7 +21,7 @@ // Allow changing pre-allocated kv cache size as a compiler flag #ifndef GEMMA_MAX_SEQLEN #define GEMMA_MAX_SEQLEN 4096 -#endif // !GEMMA_MAX_SEQLEN +#endif // !GEMMA_MAX_SEQLEN #include @@ -34,10 +34,10 @@ struct ConfigGemma7B { static constexpr int kVocabSize = 256128; static constexpr int kLayers = 28; static constexpr int kModelDim = 3072; - static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 + static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kHeads = 16; - static constexpr int kKVHeads = 16; // standard MHA - static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kKVHeads = 16; // standard MHA + static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = 1; }; @@ -46,13 +46,13 @@ struct ConfigGemma2B { static constexpr int kVocabSize = 256128; static constexpr int kLayers = 18; static constexpr int kModelDim = 2048; - static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 + static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; - static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support - static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support + static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = 1; }; -} // namespace gcpp +} // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ diff --git a/gemma.h b/gemma.h index 1e76a37..12c2a77 100644 --- a/gemma.h +++ b/gemma.h @@ -27,12 +27,12 @@ // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream // copybara:import_next_line:gemma_cpp -#include "configs.h" // kSeqLen +#include "configs.h" // kSeqLen // copybara:import_next_line:gemma_cpp -#include "util/args.h" // ArgsBase #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" +#include "util/args.h" // ArgsBase // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" @@ -42,7 +42,7 @@ namespace gcpp { // float, hwy::bfloat16_t, SfpStream, NuqStream #ifndef GEMMA_WEIGHT_T #define GEMMA_WEIGHT_T SfpStream -#endif // !GEMMA_WEIGHT_T +#endif // !GEMMA_WEIGHT_T using WeightT = GEMMA_WEIGHT_T; using EmbedderInputT = hwy::bfloat16_t; @@ -51,9 +51,9 @@ constexpr bool kSystemPrompt = false; struct KVCache { hwy::AlignedFreeUniquePtr - key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim + key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim hwy::AlignedFreeUniquePtr - value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim + value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim }; // Model variants: see configs.h for details. @@ -61,9 +61,9 @@ enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; struct LoaderArgs : public ArgsBase { - LoaderArgs(int argc, char *argv[]) { InitAndParse(argc, argv); } + LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - static std::string ToLower(const std::string &text) { + static std::string ToLower(const std::string& text) { std::string result = text; std::transform(begin(result), end(result), begin(result), [](unsigned char c) { return std::tolower(c); }); @@ -89,7 +89,7 @@ struct LoaderArgs : public ArgsBase { } // Returns error string or nullptr if OK. - const char *Validate() const { + const char* Validate() const { const std::string model_type_lc = ToLower(model_type); if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && model_type_lc != "2b-it" && model_type_lc != "7b-it") { @@ -111,11 +111,11 @@ struct LoaderArgs : public ArgsBase { } Path tokenizer; - Path model; // uncompressed weights OR - Path cache; // compressed weights + Path model; // uncompressed weights OR + Path cache; // compressed weights std::string model_type; - template void ForEach(const Visitor &visitor) { + template void ForEach(const Visitor& visitor) { visitor(tokenizer, "tokenizer", Path(), "Path name of tokenizer model file. (required)"); visitor( @@ -138,10 +138,10 @@ struct LoaderArgs : public ArgsBase { struct GemmaInterface; struct Gemma { - Gemma(const LoaderArgs &args, hwy::ThreadPool &pool); - ~Gemma(); // must be defined after GemmaInterface's dtor is defined. + Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); + ~Gemma(); // must be defined after GemmaInterface's dtor is defined. - const sentencepiece::SentencePieceProcessor &Tokenizer() const; + const sentencepiece::SentencePieceProcessor& Tokenizer() const; std::unique_ptr impl_; gcpp::ModelTraining model_training; @@ -153,7 +153,7 @@ using StreamFunc = std::function; using AcceptFunc = std::function; struct InferenceArgs : public ArgsBase { - InferenceArgs(int argc, char *argv[]) { InitAndParse(argc, argv); } + InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } size_t max_tokens; size_t max_generated_tokens; @@ -163,7 +163,7 @@ struct InferenceArgs : public ArgsBase { bool multiturn; // Returns error string or nullptr if OK. - const char *Validate() const { + const char* Validate() const { if (max_tokens > gcpp::kSeqLen) { return "max_tokens is larger than the maximum sequence length (see " "configs.h)."; @@ -175,7 +175,7 @@ struct InferenceArgs : public ArgsBase { return nullptr; } - template void ForEach(const Visitor &visitor) { + template void ForEach(const Visitor& visitor) { visitor(max_tokens, "max_tokens", size_t{3072}, "Maximum number of tokens in prompt + generation."); visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, @@ -186,19 +186,20 @@ struct InferenceArgs : public ArgsBase { "Make top-k sampling deterministic", 2); visitor(multiturn, "multiturn", false, "Multiturn mode (if 0, this clears the KV cache after every " - "interaction without quitting)\n Default = 0 (conversation resets every turn)"); + "interaction without quitting)\n Default = 0 (conversation " + "resets every turn)"); } }; -void GenerateGemma(Gemma &gemma, const InferenceArgs &args, - const std::vector &prompt, size_t start_pos, - hwy::ThreadPool &pool, hwy::ThreadPool &inner_pool, - const StreamFunc &stream_token, - const AcceptFunc &accept_token, std::mt19937 &g, +void GenerateGemma(Gemma& gemma, const InferenceArgs& args, + const std::vector& prompt, size_t start_pos, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& g, int verbosity); constexpr int EOS_ID = 1; -} // namespace gcpp +} // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_