Move the backpropagation code to its own directory

This commit is contained in:
Zoltan Szabadka 2024-06-04 10:20:16 +00:00
parent 3b4fa4a0e3
commit df01700b54
19 changed files with 52 additions and 49 deletions

View File

@ -46,26 +46,26 @@ set(SOURCES
compression/sfp.h compression/sfp.h
compression/sfp-inl.h compression/sfp-inl.h
compression/test_util.h compression/test_util.h
backprop/backward.cc
backprop/backward.h
backprop/backward-inl.h
backprop/backward_scalar.h
backprop/common_scalar.cc
backprop/common_scalar.h
backprop/forward.cc
backprop/forward.h
backprop/forward-inl.h
backprop/forward_scalar.h
backprop/optimizer.cc
backprop/optimizer.h
gemma/configs.h gemma/configs.h
gemma/activations.cc gemma/activations.cc
gemma/activations.h gemma/activations.h
gemma/backward.cc
gemma/backward.h
gemma/backward-inl.h
gemma/backward_scalar.h
gemma/common.h gemma/common.h
gemma/common-inl.h gemma/common-inl.h
gemma/common_scalar.cc
gemma/common_scalar.h
gemma/forward.cc
gemma/forward.h
gemma/forward-inl.h
gemma/forward_scalar.h
gemma/gemma.cc gemma/gemma.cc
gemma/gemma.h gemma/gemma.h
gemma/ops.h gemma/ops.h
gemma/optimizer.cc
gemma/optimizer.h
gemma/weights.cc gemma/weights.cc
gemma/weights.h gemma/weights.h
util/app.h util/app.h
@ -122,11 +122,11 @@ enable_testing()
include(GoogleTest) include(GoogleTest)
set(GEMMA_TEST_FILES set(GEMMA_TEST_FILES
backprop/backward_test.cc
backprop/backward_scalar_test.cc
backprop/optimize_test.cc
gemma/ops_test.cc gemma/ops_test.cc
gemma/gemma_test.cc gemma/gemma_test.cc
gemma/backward_test.cc
gemma/backward_scalar_test.cc
gemma/optimize_test.cc
) )
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES) foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)

View File

@ -26,8 +26,8 @@
#include <array> #include <array>
#include <cmath> #include <cmath>
#include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/prompt.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "hwy/base.h" #include "hwy/base.h"

View File

@ -13,17 +13,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gemma/backward.h" #include "backprop/backward.h"
// Compiles this file for multiple architectures via "foreach_target.h", to // Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'. // which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/backward.cc" // NOLINT #define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "gemma/backward-inl.h"
#include "gemma/weights.h"
#include "hwy/highway.h" #include "hwy/highway.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "gemma/weights.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {

View File

@ -18,8 +18,8 @@
#include <vector> #include <vector>
#include "backprop/prompt.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/prompt.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {

View File

@ -23,9 +23,9 @@
#include <complex> #include <complex>
#include <vector> #include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/common_scalar.h"
#include "gemma/prompt.h"
#include "gemma/weights.h" #include "gemma/weights.h"
namespace gcpp { namespace gcpp {

View File

@ -13,15 +13,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gemma/backward_scalar.h" #include "backprop/backward_scalar.h"
#include <array> #include <array>
#include <complex> #include <complex>
#include <random> #include <random>
#include "gemma/forward_scalar.h" #include "backprop/forward_scalar.h"
#include "gemma/sampler.h" #include "backprop/sampler.h"
#include "gemma/test_util.h" #include "backprop/test_util.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
namespace gcpp { namespace gcpp {

View File

@ -25,27 +25,27 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "backprop/backward_scalar.h"
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h" #include "compression/compress.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "gemma/backward_scalar.h"
#include "gemma/forward_scalar.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/sampler.h"
#include "gemma/test_util.h"
#include "gemma/weights.h" #include "gemma/weights.h"
// clang-format off // clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/backward_test.cc" //NOLINT #define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT
// clang-format on // clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h" #include "hwy/tests/test_util-inl.h"
// After highway.h // After highway.h
#include "gemma/backward-inl.h" #include "backprop/backward-inl.h"
#include "gemma/forward-inl.h" #include "backprop/forward-inl.h"
#include "gemma/ops.h" #include "gemma/ops.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();

View File

@ -16,10 +16,11 @@
// Compiles this file for multiple architectures via "foreach_target.h", to // Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'. // which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT #define HWY_TARGET_INCLUDE "backprop/common_scalar.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" // IWYU pragma: keep #include "hwy/highway.h"
// After highway.h
#include "gemma/ops.h" #include "gemma/ops.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();

View File

@ -13,16 +13,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gemma/forward.h" #include "backprop/forward.h"
// Compiles this file for multiple architectures via "foreach_target.h", to // Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'. // which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT #define HWY_TARGET_INCLUDE "backprop/forward.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" // IWYU pragma: keep #include "hwy/highway.h"
#include "gemma/forward-inl.h" // After highway.h
#include "backprop/forward-inl.h"
#include "gemma/weights.h" #include "gemma/weights.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();

View File

@ -18,8 +18,8 @@
#include <vector> #include <vector>
#include "backprop/prompt.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/prompt.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {

View File

@ -23,9 +23,9 @@
#include <complex> #include <complex>
#include <vector> #include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/common_scalar.h"
#include "gemma/prompt.h"
#include "gemma/weights.h" #include "gemma/weights.h"
namespace gcpp { namespace gcpp {

View File

@ -16,12 +16,12 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
#include "backprop/sampler.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/backward.h"
#include "gemma/forward.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/optimizer.h"
#include "gemma/sampler.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"

View File

@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gemma/optimizer.h" #include "backprop/optimizer.h"
#include <random> #include <random>

View File

@ -18,7 +18,7 @@
#include <vector> #include <vector>
#include "gemma/prompt.h" #include "backprop/prompt.h"
namespace gcpp { namespace gcpp {