import hashlib
import json
import logging
import os
import subprocess
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import pytest

# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------

EPS = 1e-3
REPO_ROOT = Path(__file__).resolve().parents[3]
EXE = REPO_ROOT / ("build/bin/llama-embedding.exe" if os.name == "nt" else "build/bin/llama-embedding")
DEFAULT_ENV = {**os.environ, "LLAMA_CACHE": os.environ.get("LLAMA_CACHE", "tmp")}
SEED = "42"
ALLOWED_DIMS = {384, 768, 1024, 4096}

SMALL_CTX = 16    # preflight/cache
TEST_CTX = 1024   # main tests

log = logging.getLogger(__name__)

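# Example invocation (the test-file path is hypothetical; the cache dir and
# debug flag are the env vars read above):
#
#   LLAMA_CACHE=/tmp/llama-cache EMBD_TEST_DEBUG=1 pytest -q tests/test_llama_embedding.py
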
# ---------------------------------------------------------------------------
# Shared helpers (single source of truth for command building)
# ---------------------------------------------------------------------------


def resolve_exe() -> Path:
    exe = EXE
    if not exe.exists() and os.name == "nt":
        alt = REPO_ROOT / "build/bin/Release/llama-embedding.exe"
        if alt.exists():
            exe = alt
    if not exe.exists():
        raise FileNotFoundError(f"llama-embedding not found under {REPO_ROOT}/build/bin")
    return exe


def hf_params_default():
    return {
        "hf_repo": "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF",
        "hf_file": "embeddinggemma-300M-qat-Q4_0.gguf",
    }


def build_cmd(
    *,
    exe: Path,
    params: dict,
    fmt: str,
    threads: int,
    ctx: int,
    seed: str,
    extra: Optional[List[str]] = None,  # was: list[str] | None
) -> List[str]:  # was: list[str]
    assert fmt in {"raw", "json"}, f"unsupported fmt={fmt}"
    cmd = [
        str(exe),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", str(ctx),
        "--embd-output-format", fmt,
        "--threads", str(threads),
        "--seed", seed,
    ]
    if extra:
        cmd.extend(extra)
    return cmd


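# For reference, the argv that build_cmd assembles for the defaults used in the
# tests below (executable path shown is illustrative):
#
#   build_cmd(exe=EXE, params=hf_params_default(), fmt="json",
#             threads=1, ctx=TEST_CTX, seed=SEED)
#   -> [".../llama-embedding",
#       "-hfr", "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF",
#       "-hff", "embeddinggemma-300M-qat-Q4_0.gguf",
#       "--ctx-size", "1024",
#       "--embd-output-format", "json",
#       "--threads", "1",
#       "--seed", "42"]

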
def run_cmd(cmd: List[str], text: str, timeout: int = 60) -> str:  # was: list[str]
    t0 = time.perf_counter()
    res = subprocess.run(cmd, input=text, capture_output=True, text=True,
                         env=DEFAULT_ENV, timeout=timeout)
    dur_ms = (time.perf_counter() - t0) * 1000.0
    if os.environ.get("EMBD_TEST_DEBUG") == "1":
        log.debug("embedding cmd finished in %.1f ms", dur_ms)

    if res.returncode != 0:
        raise AssertionError(f"embedding failed ({res.returncode}):\n{res.stderr[:400]}")
    out = res.stdout.strip()
    assert out, "empty stdout from llama-embedding"
    return out


# ---------------------------------------------------------------------------
# Session model preflight/cache
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
def embedding_model():
    """Download/cache model once per session with a tiny ctx + no warmup."""
    exe = resolve_exe()
    params = hf_params_default()
    cmd = build_cmd(
        exe=exe, params=params, fmt="json",
        threads=1, ctx=SMALL_CTX, seed=SEED,
        extra=["--no-warmup"],
    )
    _ = run_cmd(cmd, text="ok")
    return params


# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------


def run_embedding(
    text: str,
    *,
    fmt: str = "raw",
    threads: int = 1,
    ctx: int = TEST_CTX,
    params: Optional[dict] = None,  # was: dict | None
    timeout: int = 60,
) -> str:
    exe = resolve_exe()
    params = params or hf_params_default()
    cmd = build_cmd(exe=exe, params=params, fmt=fmt, threads=threads, ctx=ctx, seed=SEED)
    return run_cmd(cmd, text, timeout=timeout)


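# Typical call pattern used throughout the tests below: embed a short string
# with the session-cached model params, then parse with parse_vec (defined
# just after):
#
#   out = run_embedding("hello world", fmt="raw", params=embedding_model)
#   vec = parse_vec(out, "raw")

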
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def embedding_hash(vec: np.ndarray) -> str:
    """Return short deterministic signature for regression tracking."""
    return hashlib.sha256(vec[:8].tobytes()).hexdigest()[:16]


def parse_vec(out: str, fmt: str) -> np.ndarray:
    if fmt == "raw":
        arr = np.array(out.split(), dtype=np.float32)
    else:
        arr = np.array(json.loads(out)["data"][0]["embedding"], dtype=np.float32)
    return arr


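# The two output shapes parse_vec accepts (toy values, not real model output):
#
#   parse_vec("0.1 0.2 0.3", "raw")
#       -> array([0.1, 0.2, 0.3], dtype=float32)
#   parse_vec('{"data": [{"embedding": [0.1, 0.2]}]}', "json")
#       -> array([0.1, 0.2], dtype=float32)
#
# Note that embedding_hash above signs only the first 8 components, presumably
# a deliberate trade-off to keep regression signatures short.

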
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


# Suppress unknown-mark warnings for this module (note: this filters the
# warning rather than registering marks; registration belongs in pytest.ini)
pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnknownMarkWarning")


@pytest.mark.parametrize("fmt", ["raw", "json"])
|
|
@pytest.mark.parametrize("text", ["hello world", "hi 🌎", "line1\nline2\nline3"])
|
|
def test_embedding_runs_and_finite(fmt, text, embedding_model):
|
|
out = run_embedding(text, fmt=fmt, threads=1, ctx=TEST_CTX, params=embedding_model)
|
|
vec = parse_vec(out, fmt)
|
|
assert vec.dtype == np.float32
|
|
# dim & finiteness
|
|
assert len(vec) in ALLOWED_DIMS, f"unexpected dim={len(vec)}"
|
|
assert np.all(np.isfinite(vec))
|
|
assert 0.1 < np.linalg.norm(vec) < 10
|
|
|
|
|
|
def test_raw_vs_json_consistency(embedding_model):
    text = "hello world"
    raw = parse_vec(run_embedding(text, fmt="raw", params=embedding_model), "raw")
    jsn = parse_vec(run_embedding(text, fmt="json", params=embedding_model), "json")
    assert raw.shape == jsn.shape
    cos = cosine_similarity(raw, jsn)
    assert cos > 0.999, f"raw/json divergence: cos={cos:.6f}"
    assert embedding_hash(raw) == embedding_hash(jsn)


def test_empty_input_deterministic(embedding_model):
    v1 = parse_vec(run_embedding("", fmt="raw", params=embedding_model), "raw")
    v2 = parse_vec(run_embedding("", fmt="raw", params=embedding_model), "raw")
    assert np.all(np.isfinite(v1))
    assert embedding_hash(v1) == embedding_hash(v2)
    assert cosine_similarity(v1, v2) > 0.99999


def test_very_long_input_stress(embedding_model):
    """Stress test: large input at or beyond the context window."""
    text = "lorem " * 2000
    vec = parse_vec(run_embedding(text, fmt="raw", params=embedding_model), "raw")
    assert len(vec) in ALLOWED_DIMS
    assert np.isfinite(np.linalg.norm(vec))


@pytest.mark.parametrize("text", [" ", "\n\n\n", "123 456 789"])
|
|
def test_low_information_inputs_stable(text, embedding_model):
|
|
"""Whitespace/numeric inputs should yield stable embeddings."""
|
|
v1 = parse_vec(run_embedding(text, fmt="raw", params=embedding_model), "raw")
|
|
v2 = parse_vec(run_embedding(text, fmt="raw", params=embedding_model), "raw")
|
|
cos = cosine_similarity(v1, v2)
|
|
assert cos > 0.999, f"unstable embedding for {text!r}"
|
|
|
|
|
|
@pytest.mark.parametrize("flag", ["--no-such-flag", "--help"])
|
|
def test_invalid_or_help_flag(flag):
|
|
"""Invalid flags should fail; help should succeed."""
|
|
exe = resolve_exe()
|
|
res = subprocess.run([str(exe), flag], capture_output=True, text=True, env=DEFAULT_ENV)
|
|
if flag == "--no-such-flag":
|
|
assert res.returncode != 0
|
|
assert any(k in res.stderr.lower() for k in ("error", "invalid", "unknown"))
|
|
else:
|
|
assert res.returncode == 0
|
|
assert "usage" in (res.stdout.lower() + res.stderr.lower())
|
|
|
|
|
|
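# Multi-threaded reductions can reorder floating-point ops, so threads=2 is not
# guaranteed to be bit-identical to threads=1; the test below therefore compares
# by cosine similarity rather than exact equality.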
@pytest.mark.parametrize("fmt", ["raw", "json"])
|
|
def test_threads_two_similarity_vs_single(fmt, embedding_model):
|
|
text = "determinism vs threads"
|
|
single = parse_vec(run_embedding(text, fmt=fmt, threads=1, params=embedding_model), fmt)
|
|
multi = parse_vec(run_embedding(text, fmt=fmt, threads=2, params=embedding_model), fmt)
|
|
assert single.shape == multi.shape
|
|
cos = cosine_similarity(single, multi)
|
|
assert cos >= 0.999, f"threads>1 similarity too low: {cos:.6f}"
|
|
|
|
|
|
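# For reference, a minimal JSON document that satisfies the schema checks below
# (toy numbers; a real embedding has one of the ALLOWED_DIMS lengths):
#
#   {"object": "list",
#    "data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2]}],
#    "model": "...", "usage": {"prompt_tokens": 2}}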
def test_json_shape_schema_minimal(embedding_model):
    js = json.loads(run_embedding("schema check", fmt="json", params=embedding_model))
    assert isinstance(js, dict)

    # Top-level "object" (present in CLI) is optional for us
    if "object" in js:
        assert js["object"] in ("list", "embeddings", "embedding_list")

    # Required: data[0].embedding + index
    assert "data" in js and isinstance(js["data"], list) and len(js["data"]) >= 1
    item0 = js["data"][0]
    assert isinstance(item0, dict)
    if "object" in item0:
        assert item0["object"] in ("embedding",)
    assert "index" in item0 and item0["index"] == 0
    assert "embedding" in item0 and isinstance(item0["embedding"], list)
    assert len(item0["embedding"]) in ALLOWED_DIMS

    # Optional fields: tolerate absence in current CLI output
    if "model" in js:
        assert isinstance(js["model"], str)
    if "dim" in js:
        assert js["dim"] == len(item0["embedding"])
    usage = js.get("usage", {})
    if usage:
        assert isinstance(usage, dict)
        # if present, prompt_tokens should be int
        if "prompt_tokens" in usage:
            assert isinstance(usage["prompt_tokens"], int)
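

# A minimal sketch for a one-off manual run outside pytest (assumes the binary
# is built and the model is downloadable; not exercised by the suite):
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    _out = run_embedding("hello world", fmt="json")
    _vec = parse_vec(_out, "json")
    print(f"dim={len(_vec)} norm={np.linalg.norm(_vec):.4f} hash={embedding_hash(_vec)}")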