fixed flake8 complaints locally
This commit is contained in:
parent
ac85cb1375
commit
4faf26c376
|
|
@ -5146,11 +5146,11 @@ class KimiLinearModel(TextModel):
|
|||
_num_kv_heads = list()
|
||||
_full_attn_layers = linear_attn_config["full_attn_layers"]
|
||||
for il in range(self.hparams["num_hidden_layers"]):
|
||||
if il+1 in _full_attn_layers:
|
||||
if il + 1 in _full_attn_layers:
|
||||
_num_kv_heads.append(self.hparams["num_key_value_heads"])
|
||||
else:
|
||||
_num_kv_heads.append(0)
|
||||
assert(len(_num_kv_heads) == self.hparams["num_hidden_layers"])
|
||||
assert len(_num_kv_heads) == self.hparams["num_hidden_layers"]
|
||||
self.gguf_writer.add_head_count_kv(_num_kv_heads)
|
||||
|
||||
ssm_d_conv = self.hparams.get("ssm_d_conv") or linear_attn_config.get("short_conv_kernel_size")
|
||||
|
|
@ -5328,6 +5328,7 @@ class KimiLinearModel(TextModel):
|
|||
logger.info(f"Returning {mapped_name}: shape after = {tuple(data_torch.shape)}")
|
||||
return [(mapped_name, data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("InternLM2ForCausalLM")
|
||||
class InternLM2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.INTERNLM2
|
||||
|
|
|
|||
Loading…
Reference in New Issue