Merge branch 'master' into HEAD (commit 0086c246ee)
@@ -10,44 +10,44 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i,
                      const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
     const uint y_idx_base = i * QUANT_K + 32 * ib32;
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx_base) / 4;
         [[unroll]] for (uint l = 0; l < 4; ++l) {
             const vec4 b_val_0 = vec4(data_b_v4[base_b_idx + 2 * l]);
             const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]);

             // index for data_a
             uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;

             [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                 const float d = float(data_a[ibi].d);
                 const uint qh = data_a[ibi].qh[ib32];

                 const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
                 const uint qs = data_a[ibi].qs[4 * ib32 + l];
                 const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
                 const uint16_t grid = uint16_t(iq1s_grid[qs | (idxhi << 8)]);

                 const float delta_val = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
                 const vec4 delta_v = vec4(delta_val);
                 const vec4 fbits0 = vec4(
                     float(bitfieldExtract(grid, 0, 2)),
                     float(bitfieldExtract(grid, 2, 2)),
                     float(bitfieldExtract(grid, 4, 2)),
                     float(bitfieldExtract(grid, 6, 2))
                 );
                 const vec4 fbits1 = vec4(
                     float(bitfieldExtract(grid, 8, 2)),
                     float(bitfieldExtract(grid, 10, 2)),
                     float(bitfieldExtract(grid, 12, 2)),
                     float(bitfieldExtract(grid, 14, 2))
                 );

                 vec4 sum_v = fma(b_val_0, fbits0 + delta_v, vec4(0.0));
                 sum_v = fma(b_val_1, fbits1 + delta_v, sum_v);
                 FLOAT_TYPE sum = dot(sum_v, vec4(1.0));

                 temp[j][n] = fma(dl, sum, temp[j][n]);
                 ibi += num_blocks_per_row;
             }
         }
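The shader hunk above is the IQ1_S mat-vec kernel: each inner step rebuilds eight quantized weights from one 16-bit iq1s_grid entry (eight packed 2-bit values) plus a sign-selected IQ1S_DELTA, multiplies them against eight B values loaded as two vec4s, and accumulates the result scaled by the per-block factor dl. Below is a minimal scalar C++ sketch of one (l, row) step, assuming the IQ1_S codebook and delta constant from ggml; iq1s_partial_dot is a hypothetical helper, not code from this patch.

#include <cstdint>

// Assumptions: the IQ1_S codebook (2048 entries of packed 2-bit values) is
// provided elsewhere, and IQ1S_DELTA matches ggml's IQ1_S definition.
extern const uint16_t iq1s_grid[2048];
static const float IQ1S_DELTA = 0.125f; // assumed value

// Scalar equivalent of one inner-loop step of calc_superblock():
// b[0..7] plays the role of b_val_0/b_val_1, and the return value is what the
// shader fma's into temp[j][n].
static float iq1s_partial_dot(const float b[8], float d, uint32_t qh, uint8_t qs, int l) {
    const float    dl    = d * float(2 * ((qh >> 12) & 0x7) + 1); // sub-block scale from qh bits 12..14
    const uint32_t hi    = (qh >> (3 * l)) & 0x7;                 // 3 high index bits for this group
    const uint16_t grid  = iq1s_grid[qs | (hi << 8)];             // eight packed 2-bit values
    const float    delta = (qh & 0x8000) ? -IQ1S_DELTA : IQ1S_DELTA;

    float sum = 0.0f;
    for (int k = 0; k < 8; ++k) {
        sum += b[k] * (float((grid >> (2 * k)) & 0x3) + delta);   // dequantize and dot
    }
    return dl * sum;
}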
@@ -79,7 +79,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
         for (int i = 0; i < n_tokens; ++i) {
             const float pos = ubatch->pos[i];
             attn_scale_data[i] = std::log(
-                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
             ) * f_attn_temp_scale + 1.0;
         }

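The only functional change in this hunk is that the position is shifted by f_attn_temp_offset instead of a hard-coded 1.0f before the floor. For reference, a standalone C++ sketch of the per-token scale (attn_temp_scale is a hypothetical helper; the parameter names are the hparams used in this diff):

#include <cmath>
#include <cstdint>

// Mirrors the expression in set_input() above for a single token position.
static float attn_temp_scale(float pos, uint32_t n_attn_temp_floor_scale,
                             float f_attn_temp_scale, float f_attn_temp_offset) {
    // bucket index increases by 1 every n_attn_temp_floor_scale positions
    const float bucket = std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale);
    // log(1) == 0, so positions in the first bucket keep a scale of exactly 1.0
    return std::log(bucket + 1.0f) * f_attn_temp_scale + 1.0f;
}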
@@ -1278,7 +1278,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
 }

 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);

     auto & cur = inp->attn_scale;

@@ -133,8 +133,8 @@ public:
 // temperature tuning, used by llama4
 class llm_graph_input_attn_temp : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
     virtual ~llm_graph_input_attn_temp() = default;

     void set_input(const llama_ubatch * ubatch) override;
@@ -143,6 +143,7 @@ public:

     const uint32_t n_attn_temp_floor_scale;
     const float f_attn_temp_scale;
+    const float f_attn_temp_offset;
 };

 class llm_graph_input_pos_bucket : public llm_graph_input_i {
@@ -165,6 +165,7 @@ struct llama_hparams {
     uint32_t n_no_rope_layer_step = 4;
     uint32_t n_attn_temp_floor_scale = 0;
     float f_attn_temp_scale = 0.0f;
+    float f_attn_temp_offset = 0.0f; // offset position index

     // gemma3n altup
     uint32_t n_altup = 4; // altup_num_inputs
@@ -668,6 +668,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.n_swa = 8192;
                 hparams.n_attn_temp_floor_scale = 8192;
                 hparams.f_attn_temp_scale = 0.1f;
+                hparams.f_attn_temp_offset = 1.0f;
                 hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
             }

@@ -1646,6 +1647,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);

+                hparams.f_attn_temp_offset = 0.0f;
+
                 switch (hparams.n_layer) {
                     case 27: type = LLM_TYPE_16B; break;
                     case 60: type = LLM_TYPE_236B; break;
@@ -2276,6 +2279,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);

+                hparams.f_attn_temp_offset = 0.0f;
+
                 // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
                 if (hparams.f_attn_temp_scale != 0.0f) {
                     hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
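Taken together, the C++ hunks thread the new f_attn_temp_offset hyperparameter through the whole temperature-tuning path: load_hparams() chooses the value per architecture (1.0f in the llama4-style branch, 0.0f in the two later branches), llama_hparams stores it, build_inp_attn_scale() forwards it into llm_graph_input_attn_temp, and set_input() applies it as (pos + f_attn_temp_offset). With an offset of 1.0f the formula reduces to the previous hard-coded (pos + 1.0f), so llama4 behavior is unchanged, while the newer architectures drop the shift entirely.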