fix(filter): enforce CEL syntax semantics

Reject non-standard truthy numeric expressions in filters and document the parser as a supported subset of standard CEL syntax.

- remove legacy filter rewrites
- support standard equality in tag exists predicates
- add regression coverage for accepted and rejected expressions
This commit is contained in:
boojack 2026-03-31 08:10:49 +08:00
parent d3f6e8ee31
commit 0e89407ee9
11 changed files with 141 additions and 89 deletions

View File

@ -1,12 +1,14 @@
# Memo Filter Engine
This package houses the memo-only filter engine that turns CEL expressions into
SQL fragments. The engine follows a three phase pipeline inspired by systems
This package houses the memo-only filter engine that turns standard CEL syntax
into SQL fragments for the subset of expressions supported by the memo schema.
The engine follows a three phase pipeline inspired by systems
such as Calcite or Prisma:
1. **Parsing** — CEL expressions are parsed with `cel-go` and validated against
the memo-specific environment declared in `schema.go`. Only fields that
exist in the schema can surface in the filter.
exist in the schema can surface in the filter, and non-standard legacy
coercions are rejected.
2. **Normalization** — the raw CEL AST is converted into an intermediate
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
conditions (logical operators, comparisons, list membership, etc.). This

View File

@ -2,7 +2,6 @@ package filter
import (
"context"
"fmt"
"strings"
"sync"
@ -45,8 +44,6 @@ func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
return nil, errors.New("filter expression is empty")
}
filter = normalizeLegacyFilter(filter)
ast, issues := e.env.Compile(filter)
if issues != nil && issues.Err() != nil {
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
@ -119,73 +116,3 @@ func DefaultAttachmentEngine() (*Engine, error) {
})
return defaultAttachmentInst, defaultAttachmentErr
}
// normalizeLegacyFilter runs the numeric-operand rewrite over expr once for
// "&&" and then once for "||", returning the rewritten expression.
func normalizeLegacyFilter(expr string) string {
	rewritten := expr
	for _, logicalOp := range []string{"&&", "||"} {
		rewritten = rewriteNumericLogicalOperand(rewritten, logicalOp)
	}
	return rewritten
}
// rewriteNumericLogicalOperand rewrites a bare integer literal that appears as
// the right-hand operand of op (e.g. "&&" or "||") into an explicit boolean
// test, so `pinned && 1` becomes `pinned && (1 != 0)`.
//
// Text inside single- or double-quoted string literals is copied verbatim,
// including backslash escapes. Only spaces and tabs between the operator and
// the literal are skipped. A literal is rewritten only when it consists of an
// optional sign followed by at least one digit AND it is the complete operand
// (see endsLogicalOperand); anything else, such as `-x`, `10 < y` or `1.5`,
// is left untouched. An empty op returns expr unchanged.
func rewriteNumericLogicalOperand(expr, op string) string {
	if op == "" {
		// An empty operator would match at every position and never advance.
		return expr
	}

	var builder strings.Builder
	n := len(expr)
	i := 0
	var inQuote byte // active quote character, or 0 when outside a string
	for i < n {
		ch := expr[i]

		if inQuote != 0 {
			builder.WriteByte(ch)
			if ch == '\\' && i+1 < n {
				// Copy the escaped byte too so an escaped quote does not
				// terminate the string literal.
				builder.WriteByte(expr[i+1])
				i += 2
				continue
			}
			if ch == inQuote {
				inQuote = 0
			}
			i++
			continue
		}

		if ch == '\'' || ch == '"' {
			inQuote = ch
			builder.WriteByte(ch)
			i++
			continue
		}

		if strings.HasPrefix(expr[i:], op) {
			builder.WriteString(op)
			i += len(op)

			// Preserve whitespace following the operator.
			wsStart := i
			for i < n && (expr[i] == ' ' || expr[i] == '\t') {
				i++
			}
			builder.WriteString(expr[wsStart:i])

			// Wrap the literal only when it really is a standalone integer
			// operand; otherwise resume normal scanning at i so the text is
			// emitted unchanged.
			if end := scanIntLiteral(expr, i); end > i && endsLogicalOperand(expr[end:]) {
				fmt.Fprintf(&builder, "(%s != 0)", expr[i:end])
				i = end
			}
			continue
		}

		builder.WriteByte(ch)
		i++
	}
	return builder.String()
}

// scanIntLiteral returns the index just past an optionally signed run of
// decimal digits beginning at start, or start itself when no digit is found.
func scanIntLiteral(expr string, start int) int {
	i := start
	if i < len(expr) && (expr[i] == '+' || expr[i] == '-') {
		i++
	}
	firstDigit := i
	for i < len(expr) && expr[i] >= '0' && expr[i] <= '9' {
		i++
	}
	if i == firstDigit {
		return start
	}
	return i
}

// endsLogicalOperand reports whether rest (the text following a literal)
// starts, after optional whitespace, with something that terminates a logical
// operand: end of input, a closing bracket, a comma, a ternary token, or
// another logical operator.
func endsLogicalOperand(rest string) bool {
	rest = strings.TrimLeft(rest, " \t\r\n")
	if rest == "" {
		return true
	}
	switch rest[0] {
	case ')', ']', '}', ',', '?', ':':
		return true
	}
	return strings.HasPrefix(rest, "&&") || strings.HasPrefix(rest, "||")
}

View File

@ -0,0 +1,39 @@
package filter
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// TestCompileAcceptsStandardTagEqualityPredicate checks that an equality
// comparison inside tags.exists compiles without error.
func TestCompileAcceptsStandardTagEqualityPredicate(t *testing.T) {
	t.Parallel()

	const expr = `tags.exists(t, t == "1231")`

	eng, newErr := NewEngine(NewSchema())
	require.NoError(t, newErr)

	_, compileErr := eng.Compile(context.Background(), expr)
	require.NoError(t, compileErr)
}
// TestCompileRejectsLegacyNumericLogicalOperand checks that a numeric literal
// used as an operand of && is reported as a compile failure.
func TestCompileRejectsLegacyNumericLogicalOperand(t *testing.T) {
	t.Parallel()

	const expr = `pinned && 1`

	eng, newErr := NewEngine(NewSchema())
	require.NoError(t, newErr)

	_, compileErr := eng.Compile(context.Background(), expr)
	require.Error(t, compileErr)
	require.Contains(t, compileErr.Error(), "failed to compile filter")
}
// TestCompileRejectsNonBooleanTopLevelConstant checks that a filter consisting
// solely of an integer constant is rejected with the boolean-value error.
func TestCompileRejectsNonBooleanTopLevelConstant(t *testing.T) {
	t.Parallel()

	const expr = `1`

	eng, newErr := NewEngine(NewSchema())
	require.NoError(t, newErr)

	_, compileErr := eng.Compile(context.Background(), expr)
	require.EqualError(t, compileErr, "filter must evaluate to a boolean value")
}

View File

@ -157,3 +157,10 @@ type ContainsPredicate struct {
}
func (*ContainsPredicate) isPredicateExpr() {}
// EqualsPredicate represents t == "value": an equality comparison between a
// comprehension's iteration variable and a constant string.
type EqualsPredicate struct {
// Value is the string constant the iteration variable is compared against.
Value string
}
// isPredicateExpr is an empty marker method; it carries no behavior.
func (*EqualsPredicate) isPredicateExpr() {}

View File

@ -16,16 +16,10 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
if err != nil {
return nil, err
}
switch v := val.(type) {
case bool:
if v, ok := val.(bool); ok {
return &ConstantCondition{Value: v}, nil
case int64:
return &ConstantCondition{Value: v != 0}, nil
case float64:
return &ConstantCondition{Value: v != 0}, nil
default:
return nil, errors.New("filter must evaluate to a boolean value")
}
return nil, errors.New("filter must evaluate to a boolean value")
case *exprv1.Expr_IdentExpr:
name := v.IdentExpr.GetName()
field, ok := schema.Field(name)
@ -504,6 +498,8 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
// Handle different predicate functions
switch predicateCall.Function {
case "_==_":
return buildEqualsPredicate(predicateCall, comp.IterVar)
case "startsWith":
return buildStartsWithPredicate(predicateCall, comp.IterVar)
case "endsWith":
@ -511,10 +507,44 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
case "contains":
return buildContainsPredicate(predicateCall, comp.IterVar)
default:
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
return nil, errors.Errorf(`unsupported predicate function %q in comprehension (supported: ==, startsWith, endsWith, contains)`, predicateCall.Function)
}
}
// buildEqualsPredicate turns an equality call such as t == "value" (or
// "value" == t) into an EqualsPredicate. Exactly one side must be the
// iteration variable iterVar; the other side must be a constant string.
func buildEqualsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
	if len(call.Args) != 2 {
		return nil, errors.New("equality predicate expects exactly two arguments")
	}
	lhs, rhs := call.Args[0], call.Args[1]

	// The iteration variable may appear on either side; the opposite side is
	// the operand whose constant value we need. The left side is checked first.
	var operand *exprv1.Expr
	if isIterVarExpr(lhs, iterVar) {
		operand = rhs
	} else if isIterVarExpr(rhs, iterVar) {
		operand = lhs
	} else {
		return nil, errors.Errorf("equality predicate must compare against the iteration variable %q", iterVar)
	}

	raw, err := getConstValue(operand)
	if err != nil {
		return nil, errors.Wrap(err, "equality argument must be a constant string")
	}

	if literal, isString := raw.(string); isString {
		return &EqualsPredicate{Value: literal}, nil
	}
	return nil, errors.New("equality argument must be a string")
}
// isIterVarExpr reports whether expr is an identifier expression whose name
// equals iterVar.
func isIterVarExpr(expr *exprv1.Expr, iterVar string) bool {
	if ident := expr.GetIdentExpr(); ident != nil {
		return ident.GetName() == iterVar
	}
	return false
}
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
// Verify the target is the iteration variable

View File

@ -480,6 +480,8 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
// Render based on predicate type
switch pred := cond.Predicate.(type) {
case *EqualsPredicate:
return r.renderTagEquals(field, pred.Value, cond.Kind)
case *StartsWithPredicate:
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
case *EndsWithPredicate:
@ -491,6 +493,22 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
}
}
// renderTagEquals generates SQL for tags.exists(t, t == "value").
//
// SQLite and MySQL use a LIKE pattern that looks for the double-quoted value
// anywhere in the JSON array text; Postgres uses a containment check (@>)
// against a one-element JSON array built from a bound argument. In every
// supported dialect the match is passed through wrapWithNullCheck. The
// comprehension kind is ignored.
//
// NOTE(review): value is placed between double quotes verbatim, with no
// escaping of quotes, backslashes, or LIKE wildcards (% and _) in this
// function — confirm whether that is handled elsewhere.
func (r *renderer) renderTagEquals(field Field, value string, _ ComprehensionKind) (renderResult, error) {
	arrayExpr := jsonArrayExpr(r.dialect, field)

	var match string
	switch r.dialect {
	case DialectSQLite, DialectMySQL:
		likePattern := `%"` + value + `"%`
		match = r.buildJSONArrayLike(arrayExpr, likePattern)
	case DialectPostgres:
		jsonLiteral := `"` + value + `"`
		match = fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(jsonLiteral))
	default:
		return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
	}

	return renderResult{sql: r.wrapWithNullCheck(arrayExpr, match)}, nil
}
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)

View File

@ -48,7 +48,7 @@ For Streamable HTTP safety, requests with an `Origin` header must be same-origin
| Tool | Description | Required params | Optional params |
|---|---|---|---|
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (CEL) |
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (supported subset of standard CEL syntax) |
| `get_memo` | Get a single memo | `name` | — |
| `search_memos` | Full-text search | `query` | — |
| `create_memo` | Create a memo | `content` | `visibility` |

View File

@ -41,6 +41,10 @@ func checkMemoOwnership(memo *store.Memo, userID int32) error {
return nil
}
// hasMemoOwnership reports whether userID equals the memo's CreatorID.
// memo must be non-nil.
func hasMemoOwnership(memo *store.Memo, userID int32) bool {
	creator := memo.CreatorID
	return creator == userID
}
// applyVisibilityFilter restricts find to memos the caller may see.
func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) {
if rowStatus != nil && *rowStatus == store.Archived {

View File

@ -223,7 +223,7 @@ func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
),
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
), s.handleListMemos)
mcpSrv.AddTool(mcp.NewTool("get_memo",

View File

@ -142,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
if srcMemo == nil {
return mcp.NewToolResultError("source memo not found"), nil
}
if err := checkMemoOwnership(srcMemo, userID); err != nil {
if !hasMemoOwnership(srcMemo, userID) {
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
}
@ -199,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
if srcMemo == nil {
return mcp.NewToolResultError("source memo not found"), nil
}
if err := checkMemoOwnership(srcMemo, userID); err != nil {
if !hasMemoOwnership(srcMemo, userID) {
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
}

View File

@ -730,6 +730,31 @@ func TestMemoFilterTagsExistsContains(t *testing.T) {
require.Len(t, memos, 1, "Should find 1 non-todo memo")
}
// TestMemoFilterTagsExistsEquals checks that tags.exists with an equality
// predicate matches only a memo carrying exactly that tag, and that negating
// the filter returns the remaining memos.
func TestMemoFilterTagsExistsEquals(t *testing.T) {
	t.Parallel()

	tc := NewMemoFilterTestContext(t)
	defer tc.Close()

	// One memo with the exact tag, one whose tag merely contains it, and one
	// with an unrelated tag.
	exactTagMemo := NewMemoBuilder("memo-1231", tc.User.ID).
		Content("Memo with exact numeric tag").
		Tags("1231", "project")
	tc.CreateMemo(exactTagMemo)

	relatedTagMemo := NewMemoBuilder("memo-1231-suffix", tc.User.ID).
		Content("Memo with related tag").
		Tags("tag/1231", "other")
	tc.CreateMemo(relatedTagMemo)

	otherTagMemo := NewMemoBuilder("memo-other", tc.User.ID).
		Content("Memo with different tag").
		Tags("9999")
	tc.CreateMemo(otherTagMemo)

	matched := tc.ListWithFilter(`tags.exists(t, t == "1231")`)
	require.Len(t, matched, 1, "Should find only the memo with exact matching tag")
	require.Equal(t, "memo-1231", matched[0].UID)

	unmatched := tc.ListWithFilter(`!tags.exists(t, t == "1231")`)
	require.Len(t, unmatched, 2, "Should exclude only the memo with exact matching tag")
}
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)