Increase scope of embedding cli tests
This commit is contained in:
parent
c1c3d99ef6
commit
2de1e6871f
|
|
@ -4,19 +4,25 @@ name: Embedding CLI
|
|||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- feature/*
|
||||
- master
|
||||
branches: [master, feature/**]
|
||||
paths:
|
||||
- '.github/workflows/embeddings.yml'
|
||||
- 'examples/embedding/**'
|
||||
- 'examples/tests/**'
|
||||
- '.github/workflows/embedding.yml'
|
||||
- 'examples/**'
|
||||
- 'src/**'
|
||||
- 'ggml/**'
|
||||
- 'include/**'
|
||||
- '**/CMakeLists.txt'
|
||||
- 'tests/e2e/embedding/**'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths:
|
||||
- '.github/workflows/embeddings.yml'
|
||||
- 'examples/embedding/**'
|
||||
- 'examples/tests/**'
|
||||
- '.github/workflows/embedding.yml'
|
||||
- 'examples/**'
|
||||
- 'src/**'
|
||||
- 'ggml/**'
|
||||
- 'include/**'
|
||||
- '**/CMakeLists.txt'
|
||||
- 'tests/e2e/embedding/**'
|
||||
|
||||
jobs:
|
||||
embedding-cli-tests:
|
||||
|
|
@ -56,4 +62,4 @@ jobs:
|
|||
|
||||
- name: Run embedding tests
|
||||
run: |
|
||||
pytest -v examples/tests
|
||||
pytest -v tests/e2e/embedding
|
||||
|
|
@ -1,14 +1,17 @@
|
|||
import os, json, subprocess, hashlib
|
||||
import json
|
||||
import hashlib
|
||||
import os
|
||||
import pytest
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
EPS = 1e-3
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
EXE = REPO_ROOT / ("build/bin/llama-embedding.exe" if os.name == "nt" else "build/bin/llama-embedding")
|
||||
DEFAULT_ENV = {**os.environ, "LLAMA_CACHE": os.environ.get("LLAMA_CACHE", "tmp")}
|
||||
SEED = "42"
|
||||
|
|
@ -96,6 +99,7 @@ def embedding_hash(vec: np.ndarray) -> str:
|
|||
# Register custom mark so pytest doesn't warn about it
|
||||
pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnknownMarkWarning")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("fmt", ["raw", "json"])
|
||||
@pytest.mark.parametrize("text", ["hello world", "hi 🌎", "line1\nline2\nline3"])
|
||||
Loading…
Reference in New Issue