diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 7ff14cb88c..b76a0ab596 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -122,11 +122,11 @@ class GGUFWriter: self.path = path if self.path is not None: - self.print_plan() - self.fout = [open(filename, "wb") for filename in self.format_shard_names(self.path)] + filenames = self.print_plan() + self.fout = [open(filename, "wb") for filename in filenames] self.state = WriterState.EMPTY - def print_plan(self) -> None: + def print_plan(self) -> list[Path]: logger.info("Writing the following files:") assert self.path is not None filenames = self.format_shard_names(self.path) @@ -138,6 +138,8 @@ class GGUFWriter: logger.info("Dry run, not writing files") exit() + return filenames + def add_shard_kv_data(self) -> None: if len(self.tensors) == 1: return @@ -152,7 +154,7 @@ class GGUFWriter: self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32) def write_header_to_file(self, path: Path | None = None) -> None: - if len(self.tensors) == 1: + if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0): logger.warning("Model fails split requirements, not splitting") self.open_output_file(path) @@ -298,13 +300,15 @@ class GGUFWriter: tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) # make sure there is at least one tensor before splitting - if (len(self.tensors[-1]) > 0 - # split when over tensor limit - and (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors) - # or split when over size limit - or (self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size)): - - self.tensors.append(dict()) + if len(self.tensors[-1]) > 0: + if ( # split when over tensor limit + self.split_max_tensors != 0 + and len(self.tensors[-1]) >= self.split_max_tensors + ) or ( # split when over size limit + self.split_max_size != 0 + and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size + ): + self.tensors.append({}) self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) @@ -367,12 +371,12 @@ class GGUFWriter: total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values()) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True) + if len(self.fout) > 1: + shard_bar = tqdm(desc=f"Shard (1/{len(self.fout)})", total=total_bytes, unit="byte", unit_scale=True) for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)): - if bar and len(self.fout) > 1: - bar.desc = f"Writing ({i + 1}/{len(self.fout)})" if shard_bar and len(self.fout) > 1: + shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})") total = sum(ti.nbytes for ti in tensors.values()) # bar behaves weirdly when total is 0 if total > 0: @@ -686,9 +690,6 @@ class GGUFWriter: return kv_data - def _write_packed(self, fout: BufferedWriter, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: - fout.write(self._pack(fmt, value, skip_pack_prefix)) - @staticmethod def format_n_bytes_to_str(num: int) -> str: if num == 0: