diff --git a/gemma/attention_test.cc b/gemma/attention_test.cc index 19e1f06..21f65e7 100644 --- a/gemma/attention_test.cc +++ b/gemma/attention_test.cc @@ -42,8 +42,9 @@ namespace gcpp { namespace HWY_NAMESPACE { void FillRandom(MatPtrT& mat, uint64_t seed) { - hwy::RandomState rng(seed); + hwy::RandomState rng0(seed); for (size_t r = 0; r < mat.Rows(); ++r) { + hwy::RandomState rng(rng0()); float* row = mat.Row(r); for (size_t c = 0; c < mat.Cols(); ++c) { row[c] = static_cast(RandomGaussian(rng)); @@ -143,7 +144,9 @@ struct TestAttentionState { double GetTolerance() { const char* target_name = hwy::TargetName(HWY_TARGET); - if (strncmp(target_name, "AVX2", 4) == 0) { + if (strcmp(target_name, "EMU128") == 0) { + return 1e-2; // Flash and Old don't agree sometimes! + } else if (strncmp(target_name, "AVX2", 4) == 0) { return 2e-2; } else if (strncmp(target_name, "AVX3", 4) == 0) { return 3e-4; @@ -262,276 +265,292 @@ const size_t kDimsToCompare = 17; // greater than AVX-512 vector of floats // Layer 0 const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = { - {{46.5, 56.5, 10.0625, 65.5, -2.109375, 135, 15.8125, 51, -100, 52.5, - 26.875, 63, 3.34375, -67.5, 31.125, -190, 125}, - {-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125, - -14.9375, -109.5, 76, 9.25, -142, 29.5, -105}}, - {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875, - 128, 27.25, -161, 19.125, -58, 97.5}, - {-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5, - 102.5, -184, 52.75, 83.5, -71, 46.5, -52}}, - {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125, - 6.90625, 150, 144, -155, -47.25, -98.5, 3.5625}, - {-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51, - -66.5, -0.84765625, -46.5, -152, -2.9375, -81}}, - {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75, - -46, -22, -19.375, -16.125, -148, 20.875}, - {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, - 33, 10.9375, -52.5, 23.25, 75}}, - {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25, - -34.75, 18, -52, 100, -186, -75.5, 50.75}, - {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5, - 39.25, 65, 47.25, -89.5, -34.25, 137}}, - {{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5, - 32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132}, - {143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5, - 13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}}, - {{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25, - -117, 26.375, 39.5, 125, 66, 48.75, 20.75}, - {137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5, - 88.5, 48.5, -46.25, -36.75, 101.5}}, - {{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5, - 125, -53.5, -14.625, 262}, - {29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172, - 81, -27.25, -3.03125, 111, -167, 59, 176}}, - {{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25, - -52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5}, - {40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5, - 58.25, 173, -53.5, 25.125, 4.8125}}, - {{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5, - -17.125, 15.25, 143, -27, 103, 101}, - {-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625, - 8.125, -99.5, 13.6875, -11.6875, 33}}, + {{-107, 32.25, 70.5, -70, -130, -41.5, 142, 98.5, -7.03125, 39.75, -51.5, + 43.25, 18.125, 152, 61, 56, 27.25}, + {-132, -53.5, -48.75, -52.5, 1.015625, -24.5, 226, -53.75, -26.75, + 1.6484375, -12.75, 68, 107, 92.5, 46.75, -36.25, 118}}, + {{-41.5, 50, 3.953125, -37.75, 158, -22.25, 35, 27, -59.5, 67, -12.5625, + -23, -9.8125, 58.25, 54, -123, -39.75}, + {55.75, -0.859375, -148, 36.5, 48.75, 33, 205, -25.375, 110.5, 63.5, 88.5, + 50.25, 106, 156, 18.125, -20.5, 70}}, + {{-21, 35, 141, -1.4609375, 14.5, -37.25, 104.5, 86, -56.5, 78, 119, 19.625, + -19.875, 227, 58, 19, 38}, + {-71.5, -59, -116.5, 13.5625, -71, -94.5, 67.5, -54.5, -28.875, 87, 161, + 66.5, 131, 86.5, 104, -79.5, 1.8984375}}, + {{-55.5, 65, -83, 17.75, 41.25, -16.125, -175, -22.875, 62, -46.75, 182, + 16.25, 45.5, 84.5, -32.75, 40.25, 108.5}, + {-15.75, -59.75, 99.5, -43.5, 35.5, 76.5, 73, 173, -37.25, 70, -31, 103.5, + -27.375, -9, 71, -62, -174}}, + {{-90.5, -8.25, 75.5, -117, 68, 11.4375, -90, 47.75, -48.25, 8.9375, 25.5, + 79.5, -39.25, 102, 66, -63.5, -44.5}, + {63, 80.5, -59, 81.5, -71, 190, 67.5, -46.75, 10.5625, 100, 123.5, 101.5, + -50.75, -24.25, 80.5, 31.75, -8.3125}}, + {{-148, 20.125, 24, 110, -148, 2.5625, -117.5, 1.609375, -67, -91.5, 105, + 151, 203, -23.25, 64.5, 21.625, -51.5}, + {-49.25, 87, -97.5, 21.625, -231, 42.5, 117.5, -70.5, 4.71875, 118, 68, + 69.5, 15.4375, 88.5, -67.5, -17.625, -13.5625}}, + {{-92, 38.5, -21.875, 165, -32.25, -108.5, -143, 36.75, 49.5, 11.1875, 70, + 17.75, -16.125, 151, -191, -22.625, 49}, + {-36.75, -24.25, 42.75, -19.125, -118, -220, 169, -97, -75.5, 19.25, + -41.25, 107, -24.75, 157, 99, 54, -129}}, + {{-35.5, -34, -34.75, -63.25, -70, -22.375, -66, 232, -74, -54, 125, -67.5, + -109.5, 119, 101, -98, 22.5}, + {-17.875, -88, 0.8671875, 55, -42, -53.25, 114, 26.125, 87.5, 27.375, + -27.75, 18.125, 75, 26.25, -35.75, 20, 193}}, + {{52, 46.25, 28.875, 66.5, 119, -7.59375, -40.5, 135, -6.4375, 57.75, 97.5, + 30.375, -153, -17.5, 2.359375, -82.5, -39.25}, + {8.75, 46, -66, 4.6875, -111, 196, -50, -106.5, -71.5, 43.5, 12.375, + -50.75, 71, 52.75, -17.625, -78.5, -172}}, + {{-96, -22.5, -4.96875, -4.21875, -77, -67.5, -28, -12, 14.3125, -44, 72.5, + -43.5, 34, 29, 67, 10.625, 40.25}, + {-110.5, -1.2734375, 101, 78.5, -116.5, -125.5, 172, 49, 1.078125, -50.25, + -33.5, -3.59375, -19.625, -13.625, -14.875, 39, 115}}, }; // Layer 0, *K*V Head 0 const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = { - {{-4.51717567, 6.93118095, 6.48003578, 9.12825584, 2.38755274, 11.8121576, - 1.65376127, 5.04456615, -7.19549274, 2.57609844, 3.55331731, -3.48494458, - -8.90498638, 9.66047478, -0.379868984, 6.37043715, -2.24351144}, - {0.152208567, 3.14520073, -8.35154343, 5.44226503, -6.74000502, - -1.43484437, -4.72092056, -9.48932, -6.12409401, -1.55352509, -3.90701318, - 2.12124252, 3.93649936, -8.09877586, -3.30277514, -0.898857355, - 1.76684189}}, - {{4.378829, 5.05565643, -7.63948059, -5.74608946, 2.90109587, 0.155819178, - 4.56115055, 1.37885749, 1.48427355, -1.07145202, 2.82399392, -1.20864201, - 3.05434561, -2.65185618, -0.0731391, -8.2279253, 7.63228416}, - {-0.702698231, 1.49563932, 6.42149782, -6.68306589, 1.85317755, - -7.70267582, 2.07357907, -7.60303402, -0.514724255, 0.308567047, - 5.99250412, -4.67359257, -3.49322176, -2.62086344, -3.18411255, - 2.04027057, -4.29057407}}, - {{-1.20844436, 4.14724302, 6.04515219, 8.7753458, -0.975198627, 0.564640105, - 5.39941597, 4.64036179, 0.366614938, 3.48258138, -0.470701456, 15.2267399, - 4.63302803, 9.12662697, -5.89148045, 2.25731587, 5.24449492}, - {4.57078934, -4.60315752, -3.3364439, 1.29875994, -3.40833569, -6.95262, - -6.39040232, -6.60212612, 6.63269806, -0.815209687, -5.0346446, - -4.13564968, 8.25674057, -6.0910182, -8.21130085, -8.91020393, - 10.6188011}}, - {{0.602011144, 2.22505236, 3.62411499, -4.07026958, 12.8036356, 3.76139069, - 6.99502087, 7.02500725, -2.51568675, 4.2489934, 0.00210827589, - -1.43267739, -2.10394144, -0.0506809056, -1.54883039, 4.3740139, - -1.61869526}, - {-6.37204599, -3.34989691, 2.10935307, 4.23634195, 5.79134035, 13.502944, - -2.19158888, -1.55771351, -1.22244942, 3.36499929, -2.11375904, - -4.5448761, 1.0611912, -2.47849369, -0.212709218, 0.363292456, - 7.91467094}}, - {{-8.85739231, -4.08585882, -0.618261, 6.52911091, 5.14922285, 7.6869874, - 0.750387549, -0.812200725, 2.7509625, 6.29693508, -1.77248931, 5.68896484, - -6.9369607, -4.61359406, 0.184977874, -1.27769828, -2.1619854}, - {-8.2555, 2.84032059, -1.03791106, 2.07648611, -4.94546843, 1.76888537, - -1.75901175, 11.2628574, 1.41086221, -3.58669901, -2.85925198, 2.29133463, - 1.55509436, -0.0553357825, -10.0363655, 1.94261, -2.95691729}}, - {{0.919141412, 1.97533965, -11.3202848, -3.3137629, -4.7161727, 5.07012081, - 1.76256621, 8.20588207, 6.05700159, -3.89765406, -1.13639557, -1.32326794, - -3.01544905, -0.585309267, 2.60637712, 2.83708405, -3.39202118}, - {9.11918, 2.11261511, -5.87290621, 11.6033278, -4.66597795, -7.13774204, - -9.10563755, -2.48294282, 3.35282946, -3.75122213, 0.404774547, - -9.11625195, 4.85711479, 1.43184578, 1.47673059, -4.75093, -3.45323014}}, - {{4.17705393, -4.95192289, -10.5068378, 3.90004015, -3.51306129, 5.38068056, - 0.901511431, 11.222868, 2.67285442, 9.18779, 5.61346769, 3.06534624, - -3.78898215, 0.767340839, 15.8207836, -4.14079094, -4.63177109}, - {3.61795235, -7.00262165, 2.08284521, -6.70515728, 1.93205631, 2.84467721, - 3.94591737, -6.18882942, -1.78465152, -9.39100933, -10.8780289, - 6.32468653, 6.53142738, -3.30765963, 2.89132166, 4.53347206, 1.89792418}}, - {{-0.361971855, -1.57735932, 5.07296801, -1.55669761, -1.44996238, - 7.29838896, 5.23075104, -0.512441278, -3.59834242, 2.38584423, 6.48518324, - -1.48220074, -2.4264791, 10.7237988, 5.64735842, 5.6251297, -7.04244423}, - {-0.795628309, 7.30230665, -1.71035647, -16.6999454, 3.05102086, - -4.9243927, 4.28508186, -0.694577456, 6.58464718, 4.40330124, 3.3250041, - 1.90579033, -6.29048729, 2.55308104, -4.9746747, -0.681708, -5.98152351}}, - {{2.57555652, -3.5651083, 0.784440041, -4.7043705, 2.37520599, -3.62385964, - -3.48913693, -7.28049421, -5.48726082, 1.95519221, 7.25192928, 3.07074118, - -11.9897156, 5.92244673, 5.07564354, 0.162699938, -6.00809956}, - {5.56260443, -5.7683115, 1.26402235, -17.507719, 4.18873024, -3.20694613, - -4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766, - -7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187, - 2.60931611}}, - {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, - -16.3434925, -4.75156116, -1.99114823, 3.99918842, -5.95400572, - 10.8700314, 1.07596064, 0.30389142, 8.39548779, -5.11913681, 5.45641088, - -5.63240337}, - {-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355, - -1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147, - -5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052, - -7.43491}}, + {{15.2907486, -9.24563789, -1.87377763, -1.6078732, -2.52019691, 3.78340316, + 1.56531, -0.419910669, 0.0457177162, 1.7699399, 0.973267794, -11.2898827, + 3.79524374, 3.8804853, 8.05621147, 1.64328313, -7.22062826}, + {-1.33305621, -1.20374441, 5.16571712, -0.245627165, 1.00112915, + -3.94195318, -1.53855979, -2.24500442, 4.81447029, -8.42467785, 2.6451962, + -5.42961216, -1.04181266, -6.57116222, -2.43039203, -9.50760841, + 3.21791911}}, + {{1.35395038, -0.375163317, 2.66030908, -3.00428605, 6.10236216, + -10.4410543, -1.12052476, 5.85763407, -0.0452268124, -2.42987514, + 6.85442591, 1.17080283, -3.25781202, 6.65555668, -5.64402437, 4.7492609, + 9.98779583}, + {-11.0549402, -10.9070759, -9.21442795, 8.93494606, -0.663663864, + -0.127197742, -0.418648839, -0.12933588, 10.0827341, 13.9710932, + -7.22307491, -2.81767416, 2.61202765, -10.5902529, -1.11884749, + -0.00246357918, 2.00061131}}, + {{-4.12993002, 3.06688476, -3.34329081, 0.188707948, 2.42000532, + -0.339237094, 5.88325405, -2.4620254, 3.93701172, -0.949787855, + -3.56888604, -4.52016211, -6.81539917, 3.83921003, -1.64406776, + -4.28217793, 4.09804487}, + {9.04821014, -6.12610292, -3.91204882, 2.46237516, 2.26863813, + -1.05252552, 0.674160719, -0.543522477, 0.315010548, -6.30216789, + -7.87714481, 2.71428013, 6.90030003, 8.48286819, -3.15425754, 5.1051693, + 2.59031558}}, + {{3.85839581, -4.56797647, -5.07595825, -0.837815881, -3.84364843, + -5.15372133, -0.232586145, 7.362432, 0.107376553, 2.64676356, 0.902205765, + -7.68729115, -1.04463434, -7.04473209, -2.12464309, -2.62663937, + 2.3179245}, + {-10.2786751, -7.18292856, -1.0349617, 5.58713627, -4.24747801, + -0.505107284, -3.58366871, -5.82409763, 1.5151974, 3.69901705, + 0.225643635, -1.91915131, -9.39223576, -2.99991035, 3.88195848, + -0.975675821, 9.08020401}}, + {{0.713129759, 0.831702948, 4.85394859, -1.3690424, 1.06993294, 1.77343011, + 4.4732461, 2.77546239, 1.76154709, -10.2734528, 4.89345741, 1.56878746, + 0.557243943, 2.686064, -0.480260491, -1.30898976, 7.84716129}, + {-0.48303628, 1.8997345, 9.41060734, -1.07365155, 16.2980633, 0.842305303, + 1.46111321, -5.46785688, 9.73378944, 1.76110291, -0.617839932, + -0.699874997, -6.00970268, -2.25671721, -4.34198618, 10.7963381, + -1.31340837}}, + {{0.839338958, -0.991259813, 2.44353271, 5.51663303, -4.78505135, + -4.73743773, -6.66635752, -12.1987858, 0.619547904, 1.12478662, + -2.90830898, 3.32718873, -5.1365242, 0.0782394409, 6.71992254, + -1.30097711, -10.1333361}, + {-4.03514862, -1.19420063, -0.467277795, -7.10551929, -2.79278111, + -5.32330513, 4.69234657, 1.59959948, -10.0435543, -0.308479786, + 2.11825275, -3.33224726, 1.42422175, 10.0299196, 3.14650702, -4.50784397, + 1.13975036}}, + {{-7.77441454, 6.60742712, -3.2969532, 4.07419205, 0.553794742, + -0.980163574, -0.80379802, 5.47732353, -2.80931783, -7.27533054, + 1.96269298, 0.103360891, 11.9011269, -1.67654371, -4.00289297, + 3.95645094, 6.72452736}, + {-2.08075809, -0.622131109, 6.95990324, -10.1613321, -6.5728159, + -1.83433318, -7.4444685, -1.17990899, 0.949428558, -7.08294106, + 6.8835268, 0.593178153, -1.11343932, 11.1121941, -3.24285984, + 5.95768023, 1.86565471}}, + {{-3.98357534, -5.07885265, 2.99530745, 2.21132183, -5.06690884, + 7.19524574, -8.69441986, -5.43023586, 3.60415602, 6.77679777, + 7.39095974, -11.7769651, -1.51282454, 10.512928, 8.33419418, + -4.89421844, 0.684614658}, + {3.33132195, -2.80186033, 7.80674505, -3.47060919, 1.73025632, + -3.24225068, -5.88360023, 5.90776682, -1.00811982, 9.21799469, + -0.796300411, -6.04880476, -2.39337349, 1.74686813, 7.84074497, + -1.17035842, -3.03220415}}, + {{-2.54733372, 7.53344202, 4.13780975, -9.24725914, -8.49006271, + -6.72345352, -1.11408019, -0.0324454904, -2.94914579, 3.31400394, + -2.5422883, 4.42092514, -2.48425007, -1.06791162, 0.47528255, + -5.99708033, -1.02899408}, + {1.68688703, 4.75695753, 5.33531904, -2.97416735, -2.4486413, + -8.94855595, -2.54400206, -0.263463914, 7.70630169, -2.4543817, + 0.341010422, -6.5072546, 6.57980537, -8.83047295, -5.90621185, + -1.36317229, -8.00853157}}, + {{-5.81304836, 7.35501003, -1.7505573, 4.28803205, -0.106060743, + 7.27207994, 3.63292217, 3.05916095, 2.7457571, 0.898360848, 6.84973812, + 0.0843296051, 6.84243679, 9.31108475, -0.37638694, -3.97468519, + 0.128682166}, + {-0.340807438, -3.57352829, 2.74731278, -8.07462502, -2.55854392, + -0.0783569366, 9.2572813, -2.07895994, -1.34830523, 0.524608493, 1.701473, + -6.40128899, -2.29863024, -0.430005044, -1.20804024, 7.26425266, + 8.14774704}}, }; // Layer 0, K*V* Head 0 const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = { - {{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423, - -1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332, - -4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026}, - {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, - 3.52626801, -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, - 1.42195463, 0.301399827, -4.40214968, -2.12298298, 9.27825642, - -0.690600872}}, - {{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635, - -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, - 9.5985508, -10.6630878, -11.9006901, 0.851743698, 0.581826329, - 5.21927929}, - {-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228, - 3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248, - -7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251, - 3.30500841}}, - {{-3.34342527, 6.03099537, 6.335958, 0.993818045, 0.905343294, 6.93058586, - 3.9635396, 10.8044815, 7.8620863, -10.1157322, -3.92666101, -0.183003783, - -5.27309418, -1.45110512, -8.96734, -2.63866425, 2.19913912}, - {16.416317, -1.62025332, 2.3161006, 3.32571959, -1.79581594, -10.2925539, - -5.86338425, -6.36642933, 9.18872166, 5.95524168, 6.38640785, 8.23832, - -6.57342291, -14.2017632, 1.10925388, 4.27255058, -2.65661311}}, - {{6.58254147, -6.96165133, -4.97437, -2.33467388, 5.83671236, -0.794236898, - -2.03117108, -3.93387103, -5.96872902, 5.83316422, 3.01795, -4.05260706, - -4.39556885, 3.24399853, 10.1573639, 4.71967888, 0.274738848}, - {7.13243389, -8.04649162, 2.53055143, 2.0771277, -0.667295456, -13.0285645, - 0.960428238, -2.11983275, 8.18105602, -6.72609901, -5.46944714, - 0.204244614, 0.0900330544, 8.86620903, 4.63697529, 3.19756651, - 2.99392676}}, - {{9.52539158, -4.3840766, -6.94514465, -2.75913763, -10.8364506, - -3.95606327, 2.43603897, -5.78482246, -0.801304817, 8.23436832, - -7.11484337, 2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822, - 16.0471916}, - {6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875, - 4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, - 2.44881392, 1.99794042, -9.19855404, -4.02383137, -3.63013959, - -5.65853405}}, - {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, - 11.8203125, 1.81672478, -1.42535269, -5.26496315, -5.31612349, - -4.19499826, 7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, - 3.5296216}, - {7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912, - 0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329, - 3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}}, - {{-2.48950672, -8.55112743, 8.04663277, -5.77116871, -0.637019753, - -7.65882111, -7.49037457, 3.8041625, -3.57038307, 9.37715435, -6.42604256, - 1.62610793, -1.54000568, 2.52110147, 5.30775261, -4.10454893, - -4.96251774}, - {-2.95554614, -5.18210888, 1.00015664, -4.03864431, -7.14954519, - 5.99929142, 5.86350155, 2.03810191, -4.23009968, 9.39885902, -5.68198299, - 2.72845244, 11.7133255, 0.838779449, -13.2235403, 2.94607735, - -2.7902379}}, - {{2.86876941, -0.836064458, -0.374509573, -0.277966499, 3.20654631, - -3.68510771, -7.76134634, 2.23905277, -8.35530376, 5.25071716, - -1.38490796, -2.93542218, 0.509032726, -3.57361269, -2.82580233, - -4.49954033, 2.91235542}, - {-4.37938213, 4.78577232, 2.03453469, 5.48564529, -1.05589461, -1.65940428, - 4.0130887, 5.26074123, 4.67537832, 0.791350365, 6.3880868, 2.50402451, - 7.6603322, -3.16343474, -2.71949649, 4.61576128, 1.3817997}}, - {{0.289200783, 7.06031752, -1.15099299, -5.29136801, -1.343642, -8.36283112, - 4.13158274, -1.93137062, 3.16199875, 2.21854591, 2.18270063, 0.77002573, - 6.90393353, -0.644045949, -5.62211609, -1.09085155, 1.07821059}, - {-3.04716778, -2.52233481, -5.99031925, 2.80152273, 0.340899587, - 0.667474508, -2.39674735, 8.83768654, -5.45613146, -1.55994594, -2.216362, - 1.49354, -4.27255821, -9.05310917, 5.90691471, -1.29772806, -8.50278}}, - {{-3.1383903, -7.71573353, 3.38072681, 6.07642221, -2.39587545, -7.84178352, - -1.60108304, -8.6121521, -5.151721, 4.17612457, -2.86532378, 1.64645958, - -0.37970829, -4.34561253, -0.454322815, 0.331385136, -5.74550819}, - {4.77026033, -5.51171303, -7.38155365, -5.38462543, 2.95842505, 5.18372536, - 0.521988213, 7.23966122, -4.90852165, 7.18465281, 2.99289083, 10.0519466, - -2.09695673, 7.34368706, -2.40495348, 3.61603308, 0.131510735}}, + {{-6.57186985, 4.65591288, -2.99893808, -11.1538782, 0.244000077, + -10.2325764, 0.103694201, 0.521099567, -3.99905825, -1.62405348, + -4.68134117, 6.1718998, -1.34258807, 0.629202843, 2.19743776, + -0.994996071, 11.087513}, + {6.30781364, -0.809881091, -2.28015828, 7.74938059, -7.73279285, + -2.66831946, -2.19984651, 0.331729531, -2.91752172, -3.65055728, + 5.1676836, -7.56936884, 1.81388354, -4.26828051, -2.01722169, + -0.324608445, 7.27558804}}, + {{9.26140213, -3.20177221, 0.539388776, 5.40602064, -0.743577957, + 0.394759417, 9.85691643, 2.08870316, 0.901947498, -1.50783658, 2.25597, + -1.95216775, 0.435392141, -0.702769041, -4.18087959, -4.37605, + 7.78122902}, + {-2.62402225, 1.53574657, -8.48229218, -6.17764902, -8.80739498, + 3.71258497, -0.00548219681, 1.16554821, 4.96417856, 12.0105095, + -9.01848125, 0.977133036, 2.64647341, 6.30225754, 1.42601275, + -9.98334408, 0.879288554}}, + {{5.97513628, -6.88194704, -1.16571558, -4.31768417, -1.14049578, + -2.82398677, -6.27558422, 2.18296051, 2.75785732, 5.18285942, -4.07532883, + -7.07251263, 1.9271419, -8.29465675, -6.54444408, 7.7866087, -2.06813526}, + {-5.63859415, 5.49219513, 1.35068834, 3.48846531, 6.94235802, + -4.82062531, 2.47111416, -3.67039084, -2.86166239, 1.72953558, + -6.94025803, 3.77951097, 3.43053484, -0.421885848, -5.10398674, + 7.37130451, 4.5244031}}, + {{-2.85907269, -7.74109554, -2.99573851, -5.80393362, -4.41116858, + -2.96661329, 0.529096365, 4.8533392, -0.586824358, 10.2085228, + -7.89174175, 11.6699429, 3.2624352, -5.73311234, -7.5428834, -1.59121943, + -1.98875427}, + {-2.01318312, 7.78195047, 4.17572403, -0.517796278, 1.73962998, + 6.3888917, -1.03050208, -4.90732288, -2.38260913, -0.94410181, + -2.34225774, -5.66976643, 0.0630166531, 5.22525358, 9.27637863, + -6.84555054, -3.40093827}}, + {{0.713163614, -5.89050484, -1.6664927, -1.1432848, 10.4444027, 8.94331741, + -9.69797707, 4.2944026, -6.69290638, 4.72696638, -3.301085, 5.89265633, + 0.634907007, -6.85523701, 7.27885437, -12.9960146, 1.07775009}, + {-2.30551481, -0.188415289, 5.51380777, -2.09227371, -6.09918642, + -2.92235994, -10.7518473, 6.63548946, 6.40411043, 0.495648265, + -0.0361406803, -2.30997944, 2.38069057, -2.46818423, 0.144803047, + 4.35358715, -1.88418579}}, + {{0.667566538, -8.06617641, -2.83349943, -8.64362812, -5.26301479, + -4.63245106, -3.16837788, -1.80521441, -4.16031981, -4.48559904, + -8.40764809, 5.11661053, -3.19849682, 2.49863052, -0.809394836, + 8.11068916, -2.00028992}, + {-3.50956917, 3.46693277, 3.71822405, 6.78018379, 0.00734519958, + 7.60286093, 9.44774818, -0.519026041, -9.28906822, 7.12584591, + 8.12778854, 2.7033093, -2.83234954, -2.78084874, 3.65403628, + 0.215320587, -1.59024906}}, + {{10.8861427, 6.59166813, 9.7520752, -2.61776686, -0.63697052, 0.804175496, + 6.64749336, 2.6748116, 11.0225735, -5.11313915, -0.60951817, 1.94157505, + 0.709332824, -1.59864545, 9.76169205, -5.00956106, -1.29182816}, + {-6.48537397, -0.724315166, 3.55528641, 10.9164925, -10.999507, + -1.26528633, -1.44988942, -5.16796589, -0.320435524, 1.20271659, + -5.23793507, -1.60932314, 0.000490188599, -2.50121546, -5.03053236, + 7.05981207, 0.729410648}}, + {{0.193905115, -10.848567, 5.42073679, -3.42887449, -1.63425016, + -7.5447526, -2.04208255, -3.28060675, -0.136736155, 1.27700531, + 0.377272248, -1.60267282, -5.29419708, -2.20173168, 11.6071215, + 2.40224266, -4.04324436}, + {-3.08099103, 7.33237839, -11.2906342, 0.958051205, -4.04783964, + -1.28419411, 4.54195166, 5.41813755, -1.85887122, -5.0294466, + -5.22293329, 6.89848137, 1.11226559, -3.14861584, -3.68246865, + 3.34404039, 2.97509623}}, + {{-0.759357333, 1.27064419, -2.41022944, -5.52269745, 2.91421509, + -0.782507896, -0.228662491, 4.27539682, 2.97740626, 12.5008287, + -9.4860878, 1.21384573, 9.5913868, 5.45113611, 0.403315663, -6.16194582, + -1.2852304}, + {-0.207204342, 3.74191999, 1.23634934, 2.39491701, 2.05387831, 8.58817196, + 3.65675569, 9.16720486, -5.8212862, 3.89707994, 13.4189224, -3.09973836, + 7.5796423, -0.365473986, -1.54334283, -5.30818748, 0.602919102}}, + {{-4.85392904, -2.36758995, -8.77992058, 3.50987387, -1.12358332, + -6.46516418, 3.44891453, -3.35269594, -6.95946836, -2.25799656, + 0.080966711, 3.76473641, -1.4134531, 3.168015, 1.69996285, -2.40649772, + -9.11525726}, + {9.77986431, 1.73628068, -9.28857327, -0.881102562, 2.03340697, + -2.93252277, -5.35455704, 1.34708834, -4.76539326, 1.6799016, 5.09027529, + -4.21229887, -2.32152724, -1.53899908, 6.4186182, -0.891803145, + 6.0681715}}, }; // Layer 0, QHead 0 const float kGoldenQ[kNumTokens][kQBatchSize][kDimsToCompare] = { - {{-0.574401975, 0.370210886, -0.426894158, -0.543187439, -0.0266762674, - -0.177960411, -0.00839618221, 0.411925405, 0.536462784, 0.528389931, - -0.499812007, -0.123897657, -0.0170236826, 0.266041577, -0.0781469196, - -0.44081074, 0.185976267}, - {0.270543516, -0.109283224, -0.58602041, -0.358663559, -0.393124342, - -0.0895933211, -0.632167816, 0.386703, 0.314152211, 0.0554139167, - 0.0241559595, -0.194484815, 0.143893063, 0.103837147, -0.384245932, - -0.00418212265, 0.385817379}}, - {{-0.0331106335, -0.100827977, 0.322449774, 0.225943685, -0.384854138, - -0.208085626, 0.0206767023, 0.287796348, -0.139513299, 0.255447835, - -0.0845065042, -0.0619940236, 0.477489054, 0.517492294, -0.0172665715, - -0.0302075297, 0.365989387}, - {-0.0266781822, -0.453293771, 0.560033202, 0.105156079, -0.35259968, - 0.711447716, -0.253611088, 0.0487165749, -0.086192511, -0.0338740349, - -0.655441046, 0.00413730741, -0.510472536, -0.0748229772, -0.29113093, - -0.0432077348, 0.09223634}}, - {{-0.321974993, -0.466039479, 0.207254037, -0.126807183, -0.192775592, - -0.0953654051, 0.209789664, 0.405356169, -0.00627984107, -0.0590961352, - 0.0907663852, -0.190793216, -0.730463982, 0.340142608, -0.295675993, - -0.165913597, -0.233714506}, - {-0.345578939, 0.394073665, 0.299743414, -0.0075177839, -0.288939595, - 0.127782941, -0.207550645, 0.0655022636, -0.705084503, -0.241842598, - 0.333820701, 0.217911497, 0.29735288, 0.0147881694, -0.152306199, - -0.589594781, -0.373093933}}, - {{0.216089666, 0.0918798149, 0.0560657382, -0.157523662, -0.00141695142, - 0.51770103, 0.596379519, -0.271057904, 0.241035417, -0.275827706, - 0.112851456, 0.026878573, -0.579843462, -0.5116328, 0.192026839, - 0.125176072, 0.34234497}, - {-0.0744233653, 0.180814236, 0.170143247, -0.337861449, -0.175804421, - 0.213403732, -0.173699334, 0.109528325, -0.385727316, 0.109683953, - 0.475667775, 0.253016889, 0.477347463, 0.111096457, 0.394625545, - 0.0172286481, -0.357992649}}, - {{-0.350524545, -0.142550975, -0.212269634, -0.0589753427, -0.434021264, - 0.384472728, 0.445421219, -0.635599554, -0.246593416, 0.120986834, - 0.623568773, -0.161932915, -0.702406883, 0.44038102, 0.268234134, - 0.480264157, 0.103595078}, - {-0.227436215, 0.357608706, -0.25339672, -0.0683218762, -0.179259315, - 0.23657614, 0.559984326, 0.165754288, -0.0402980596, -0.101906747, - -0.278261065, -0.16327399, 0.235923961, -0.428657919, -0.290629387, - 0.579215467, -0.0717103705}}, - {{-0.246389642, -0.266164362, -0.0967710763, -0.4011603, 0.242542207, - 0.0869855583, 0.20158039, 0.207793877, -0.0875666738, -0.242263764, - -0.0462955758, -0.617374003, 0.454443514, 0.207072973, -0.0235372931, - -0.0193868056, -0.660622239}, - {0.703284621, 0.0382430181, 0.43997851, -0.858277559, 0.342218578, - 0.414044619, 0.403636098, -0.579880178, -1.12243, -0.112913512, - 0.629238605, -0.0285760984, -0.152203664, -0.088969171, -0.0681343, - 0.476349175, 0.283238202}}, - {{0.138267457, 0.483219147, 0.230450034, -0.568304598, 0.204461277, - -0.286731184, -0.416590065, -0.483460307, -0.561008453, 0.395195067, - 0.104367018, -0.196090236, -0.324770749, -0.0881370157, -0.626873195, - 0.0936089084, 0.262185335}, - {0.282603383, 0.0723766163, -0.206548154, 0.561849833, 0.482716829, - 0.135281503, -0.438841999, 0.472577304, -0.346201897, -0.0211652666, - -0.0905084163, -0.168639392, -0.154975936, -0.303443581, -0.41771856, - 0.400717318, 0.426146686}}, - {{-0.0537007451, -0.227346331, -0.2871463, 0.247746795, -0.0975416005, - -0.0123391449, 0.0612513907, -0.374673814, 0.283457696, 0.40945363, - 0.137944818, -0.0119741419, 0.775918365, -0.308365196, 0.230615795, - -0.440364927, 0.218536288}, - {0.0688965544, -0.149037778, -0.246169299, 0.0599289536, -0.456733435, - 0.0808929354, 0.115154952, 0.0997388735, -0.408117741, 0.576600909, - -0.193775773, 0.0340575948, -0.29254055, 0.695465446, 0.373336494, - 0.421431482, 0.00197479129}}, - {{0.402076721, -0.118151993, 0.542394996, 0.0382412486, -0.614983976, - 0.28617692, 0.318540633, -0.299300969, -0.177486539, 0.394140214, - 0.0644133314, -0.0321308076, 0.671587527, -0.0173831787, -0.219400048, - -0.340277791, 0.5130288}, - {0.105372488, -0.145784974, 0.0695323348, -0.106080391, -0.755512118, - 0.975362539, -0.15056029, 0.58882606, -0.059625227, -0.810613, - -0.321623206, 0.193939567, 0.0340242684, -0.626081824, 0.109950632, - -0.141072854, 0.0177994221}}, - {{0.243249148, 0.0904035419, -0.472183734, -0.176162, 0.314925164, - -0.191137731, 0.492265761, -0.0120046511, 0.824757636, 0.298175, - 0.148151726, -0.0197859108, -0.64297086, 0.432318538, -0.555079758, - 0.101636633, 0.155741245}, - {0.0523641109, 0.224086404, 0.0143201668, 0.0090854, 0.304901183, - -0.391372293, 0.267655343, 0.117368169, 0.645064473, 0.336050332, - -0.282133281, -0.231817603, 0.376230389, -0.575031936, -0.628365576, - 0.484799922, 0.0824087635}}, + {{-0.374841154, -0.269048423, 0.324933857, 0.270255983, 0.192583397, + 0.0567071736, 0.250502706, 0.625115335, -0.403177321, 0.271447271, + 0.286808699, -0.0656447411, 0.276836812, 0.0164474752, 0.315540373, + 0.265531778, 0.143433452}, + {0.303192079, -0.0379101634, 0.154115498, -0.00872713327, -0.103512973, + -0.0887796879, -0.216018289, 0.607339799, 0.055648379, -0.191132426, + -0.319971651, -0.208316207, -0.264384329, -0.299360216, 0.0837299377, + -0.283533514, -0.501275897}}, + {{-0.114549503, -0.118767068, -0.456864387, 0.144393563, 0.0955479592, + -0.133590534, 0.444972277, 0.114303589, -0.0884202197, -0.0573218763, + -0.0792874247, 0.403315246, -0.278178513, -0.00494343042, -0.257657051, + 0.030698413, -0.0186916813}, + {-0.373288035, -0.215933442, -0.201702699, -0.114249617, -0.52541703, + 0.275511354, 0.335507631, 0.62828052, 0.248843148, 0.513091445, + -0.0282848328, -0.248418555, -0.522639215, -0.0390388519, 0.192302689, + -0.449831903, -0.179292724}}, + {{0.142575517, -0.237895951, 0.146644697, -0.503801346, -0.523338497, + 0.0719232783, 0.0608261451, 0.151101857, 0.02000916, -0.725266218, + 0.163600311, -0.02573248, 0.293753356, -0.450484604, 0.20146054, + 0.110477969, 0.354954362}, + {-0.239320278, 0.526096821, -0.286867231, -0.443862438, 0.735460579, + -0.245309472, 0.722944438, 0.0783652365, 0.21042797, 0.569268048, + -0.0406528264, 0.0399431735, -0.305004865, -0.137150392, -0.130049363, + 0.330584168, 0.0668990687}}, + {{-0.194874108, -0.205414161, -0.220138401, 0.0517282933, -0.161865696, + -0.233355582, -0.144200221, 0.535177469, 0.219330966, 0.217425376, + -0.13133359, 0.195236742, 0.257307261, 0.279794693, 0.384352505, + 0.174138933, 0.0952773392}, + {0.122517705, -0.532220542, 0.231840312, 0.421907842, -0.693262935, + 0.379204452, 0.904855072, -0.238233089, -0.0102335168, -0.385086507, + 0.0983751193, -0.0335776061, 0.00405130535, 0.363216281, -0.131849915, + 0.0302671418, -0.00287117064}}, + {{-0.136619762, -0.916439533, -0.250397354, -0.0263281856, -0.607887447, + -0.12422359, 0.0350730009, 0.0140353218, -0.156378835, 0.979060471, + 0.0746487826, -0.223096639, 0.0214309599, -0.226047188, 0.0714672953, + -0.405700892, -0.132313401}, + {0.439182878, 0.084455654, -0.776320815, -0.592856288, 0.365012228, + 0.185673609, -0.24275738, 0.275207847, -0.746165574, -0.256350815, + -0.481744856, 0.524834514, 0.152572945, 0.405694962, -0.279294074, + -0.619180143, 0.16503042}}, + {{0.307029665, -0.258573472, -0.497068763, 0.133658186, 0.112126596, + -0.13778466, -0.469314516, -0.144993082, 0.341157258, -0.223292619, + 0.338864386, -0.165094376, 0.317748159, 0.131249368, -0.310955763, + -0.141406, -0.618950605}, + {-0.405226409, -0.289102376, 0.0477564782, -0.149198949, -0.424721092, + -0.113134548, -0.0732265264, -0.341526538, 0.124277025, -0.260352641, + -0.0306069255, 0.385291427, -0.279991835, 0.135148734, 0.251948118, + 0.0279652774, -0.0242935997}}, + {{0.123339117, 0.112210952, -0.423181385, 0.112272829, -0.279016107, + 0.307293028, 0.613147676, 0.00073248148, 0.819842041, 0.0347603858, + -0.0396398082, 0.074497737, -0.0331122801, -0.205312088, 0.954650819, + -0.284037501, -0.17986232}, + {0.127260983, -0.184656262, -0.257579148, 0.214763999, 0.4361099, + -0.0158195253, 0.0339632668, -0.133950815, 0.204951435, -0.247553974, + 0.739190161, -0.0878294855, -0.127532601, -0.549639583, 0.254371703, + 0.0851583332, 0.307077497}}, + {{-0.0720033944, -0.230760068, 0.204314083, -0.346839815, -0.0487727225, + 0.151570067, 0.710862041, -0.4089351, 0.300317228, 0.571746171, + -0.546940625, 0.0928032696, 0.0187496543, 0.29309383, -0.322793603, + -0.186359257, -0.550192237}, + {-0.333711773, 0.250101328, -0.538163781, -0.436006278, 0.247505322, + 0.279933214, -0.259696215, 0.0872357413, 0.333090097, 0.950338364, + -0.110226423, -0.253991336, -0.194895253, 0.336680681, 0.175827622, + 0.184941083, 0.565679312}}, + {{0.492006898, 0.106031463, -0.0973178521, -0.214457124, -0.0938223451, + 0.202232271, 0.293491513, -0.319558859, 0.0366688259, -0.044666674, + -0.523907304, 0.401466191, 0.0948085636, -0.665217042, 0.0531942286, + -0.707738578, -0.155400679}, + {-0.309382081, 0.238702834, -0.154397696, 0.153635919, 0.0586032122, + -0.356307834, -0.242223755, 0.211881027, 0.686982214, 0.361260235, + -0.487024903, -0.181656718, -0.104096822, -0.0305453707, 0.331899464, + 0.0255006049, -0.826909781}}, + {{0.0855419636, -0.325473666, -0.378067434, 0.599543989, -0.115204476, + -0.479211658, -0.0426419526, 0.0785699934, -0.409276605, 0.028221447, + 0.0391969681, 0.428700686, -0.132882744, -0.173993275, 0.697183192, + 0.160488009, 0.611800015}, + {0.177823097, 0.604698062, 0.917836607, 0.250253111, -0.775083899, + 0.308443069, 0.194380283, -0.572413027, -0.286389142, -0.382753521, + 0.0876774341, 0.0594621263, -0.192462415, -0.0088978298, -0.449309558, + 0.139618352, 0.164170146}}, }; void RunAttentionTest(AttentionImpl attention_impl) {