From 0e89407ee91deda87e0df7464a89d26d1e9a88b3 Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 31 Mar 2026 08:10:49 +0800 Subject: [PATCH] fix(filter): enforce CEL syntax semantics Reject non-standard truthy numeric expressions in filters and document the parser as a supported subset of standard CEL syntax. - remove legacy filter rewrites - support standard equality in tag exists predicates - add regression coverage for accepted and rejected expressions --- plugin/filter/README.md | 8 ++-- plugin/filter/engine.go | 73 ----------------------------- plugin/filter/engine_test.go | 39 +++++++++++++++ plugin/filter/ir.go | 7 +++ plugin/filter/parser.go | 48 +++++++++++++++---- plugin/filter/render.go | 18 +++++++ server/router/mcp/README.md | 2 +- server/router/mcp/access.go | 4 ++ server/router/mcp/tools_memo.go | 2 +- server/router/mcp/tools_relation.go | 4 +- store/test/memo_filter_test.go | 25 ++++++++++ 11 files changed, 141 insertions(+), 89 deletions(-) create mode 100644 plugin/filter/engine_test.go diff --git a/plugin/filter/README.md b/plugin/filter/README.md index 35961615f..ac1aec4b6 100644 --- a/plugin/filter/README.md +++ b/plugin/filter/README.md @@ -1,12 +1,14 @@ # Memo Filter Engine -This package houses the memo-only filter engine that turns CEL expressions into -SQL fragments. The engine follows a three phase pipeline inspired by systems +This package houses the memo-only filter engine that turns standard CEL syntax +into SQL fragments for the subset of expressions supported by the memo schema. +The engine follows a three phase pipeline inspired by systems such as Calcite or Prisma: 1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against the memo-specific environment declared in `schema.go`. Only fields that - exist in the schema can surface in the filter. + exist in the schema can surface in the filter, and non-standard legacy + coercions are rejected. 2. **Normalization** – the raw CEL AST is converted into an intermediate representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of conditions (logical operators, comparisons, list membership, etc.). This diff --git a/plugin/filter/engine.go b/plugin/filter/engine.go index c9fcfba7f..9dab7a0ba 100644 --- a/plugin/filter/engine.go +++ b/plugin/filter/engine.go @@ -2,7 +2,6 @@ package filter import ( "context" - "fmt" "strings" "sync" @@ -45,8 +44,6 @@ func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) { return nil, errors.New("filter expression is empty") } - filter = normalizeLegacyFilter(filter) - ast, issues := e.env.Compile(filter) if issues != nil && issues.Err() != nil { return nil, errors.Wrap(issues.Err(), "failed to compile filter") @@ -119,73 +116,3 @@ func DefaultAttachmentEngine() (*Engine, error) { }) return defaultAttachmentInst, defaultAttachmentErr } - -func normalizeLegacyFilter(expr string) string { - expr = rewriteNumericLogicalOperand(expr, "&&") - expr = rewriteNumericLogicalOperand(expr, "||") - return expr -} - -func rewriteNumericLogicalOperand(expr, op string) string { - var builder strings.Builder - n := len(expr) - i := 0 - var inQuote rune - - for i < n { - ch := expr[i] - - if inQuote != 0 { - builder.WriteByte(ch) - if ch == '\\' && i+1 < n { - builder.WriteByte(expr[i+1]) - i += 2 - continue - } - if ch == byte(inQuote) { - inQuote = 0 - } - i++ - continue - } - - if ch == '\'' || ch == '"' { - inQuote = rune(ch) - builder.WriteByte(ch) - i++ - continue - } - - if strings.HasPrefix(expr[i:], op) { - builder.WriteString(op) - i += len(op) - - // Preserve whitespace following the operator. - wsStart := i - for i < n && (expr[i] == ' ' || expr[i] == '\t') { - i++ - } - builder.WriteString(expr[wsStart:i]) - - signStart := i - if i < n && (expr[i] == '+' || expr[i] == '-') { - i++ - } - for i < n && expr[i] >= '0' && expr[i] <= '9' { - i++ - } - if i > signStart { - numLiteral := expr[signStart:i] - fmt.Fprintf(&builder, "(%s != 0)", numLiteral) - } else { - builder.WriteString(expr[signStart:i]) - } - continue - } - - builder.WriteByte(ch) - i++ - } - - return builder.String() -} diff --git a/plugin/filter/engine_test.go b/plugin/filter/engine_test.go new file mode 100644 index 000000000..f9e72c224 --- /dev/null +++ b/plugin/filter/engine_test.go @@ -0,0 +1,39 @@ +package filter + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCompileAcceptsStandardTagEqualityPredicate(t *testing.T) { + t.Parallel() + + engine, err := NewEngine(NewSchema()) + require.NoError(t, err) + + _, err = engine.Compile(context.Background(), `tags.exists(t, t == "1231")`) + require.NoError(t, err) +} + +func TestCompileRejectsLegacyNumericLogicalOperand(t *testing.T) { + t.Parallel() + + engine, err := NewEngine(NewSchema()) + require.NoError(t, err) + + _, err = engine.Compile(context.Background(), `pinned && 1`) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to compile filter") +} + +func TestCompileRejectsNonBooleanTopLevelConstant(t *testing.T) { + t.Parallel() + + engine, err := NewEngine(NewSchema()) + require.NoError(t, err) + + _, err = engine.Compile(context.Background(), `1`) + require.EqualError(t, err, "filter must evaluate to a boolean value") +} diff --git a/plugin/filter/ir.go b/plugin/filter/ir.go index 10cb13df1..b5a995dda 100644 --- a/plugin/filter/ir.go +++ b/plugin/filter/ir.go @@ -157,3 +157,10 @@ type ContainsPredicate struct { } func (*ContainsPredicate) isPredicateExpr() {} + +// EqualsPredicate represents t == "value". +type EqualsPredicate struct { + Value string +} + +func (*EqualsPredicate) isPredicateExpr() {} diff --git a/plugin/filter/parser.go b/plugin/filter/parser.go index 36e52d1db..2aff1074e 100644 --- a/plugin/filter/parser.go +++ b/plugin/filter/parser.go @@ -16,16 +16,10 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) { if err != nil { return nil, err } - switch v := val.(type) { - case bool: + if v, ok := val.(bool); ok { return &ConstantCondition{Value: v}, nil - case int64: - return &ConstantCondition{Value: v != 0}, nil - case float64: - return &ConstantCondition{Value: v != 0}, nil - default: - return nil, errors.New("filter must evaluate to a boolean value") } + return nil, errors.New("filter must evaluate to a boolean value") case *exprv1.Expr_IdentExpr: name := v.IdentExpr.GetName() field, ok := schema.Field(name) @@ -504,6 +498,8 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, // Handle different predicate functions switch predicateCall.Function { + case "_==_": + return buildEqualsPredicate(predicateCall, comp.IterVar) case "startsWith": return buildStartsWithPredicate(predicateCall, comp.IterVar) case "endsWith": @@ -511,10 +507,44 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, case "contains": return buildContainsPredicate(predicateCall, comp.IterVar) default: - return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function) + return nil, errors.Errorf(`unsupported predicate function %q in comprehension (supported: ==, startsWith, endsWith, contains)`, predicateCall.Function) } } +// buildEqualsPredicate extracts the value from t == "value". +func buildEqualsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) { + if len(call.Args) != 2 { + return nil, errors.New("equality predicate expects exactly two arguments") + } + + var constExpr *exprv1.Expr + switch { + case isIterVarExpr(call.Args[0], iterVar): + constExpr = call.Args[1] + case isIterVarExpr(call.Args[1], iterVar): + constExpr = call.Args[0] + default: + return nil, errors.Errorf("equality predicate must compare against the iteration variable %q", iterVar) + } + + value, err := getConstValue(constExpr) + if err != nil { + return nil, errors.Wrap(err, "equality argument must be a constant string") + } + + valueStr, ok := value.(string) + if !ok { + return nil, errors.New("equality argument must be a string") + } + + return &EqualsPredicate{Value: valueStr}, nil +} + +func isIterVarExpr(expr *exprv1.Expr, iterVar string) bool { + target := expr.GetIdentExpr() + return target != nil && target.GetName() == iterVar +} + // buildStartsWithPredicate extracts the pattern from t.startsWith("prefix"). func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) { // Verify the target is the iteration variable diff --git a/plugin/filter/render.go b/plugin/filter/render.go index c91096a7b..39eaaec01 100644 --- a/plugin/filter/render.go +++ b/plugin/filter/render.go @@ -480,6 +480,8 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re // Render based on predicate type switch pred := cond.Predicate.(type) { + case *EqualsPredicate: + return r.renderTagEquals(field, pred.Value, cond.Kind) case *StartsWithPredicate: return r.renderTagStartsWith(field, pred.Prefix, cond.Kind) case *EndsWithPredicate: @@ -491,6 +493,22 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re } } +// renderTagEquals generates SQL for tags.exists(t, t == "value"). +func (r *renderer) renderTagEquals(field Field, value string, _ ComprehensionKind) (renderResult, error) { + arrayExpr := jsonArrayExpr(r.dialect, field) + + switch r.dialect { + case DialectSQLite, DialectMySQL: + exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, value)) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil + case DialectPostgres: + exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, value))) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil + default: + return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect) + } +} + // renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")). func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) { arrayExpr := jsonArrayExpr(r.dialect, field) diff --git a/server/router/mcp/README.md b/server/router/mcp/README.md index 78feb5732..0af991fc0 100644 --- a/server/router/mcp/README.md +++ b/server/router/mcp/README.md @@ -48,7 +48,7 @@ For Streamable HTTP safety, requests with an `Origin` header must be same-origin | Tool | Description | Required params | Optional params | |---|---|---|---| -| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (CEL) | +| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (supported subset of standard CEL syntax) | | `get_memo` | Get a single memo | `name` | — | | `search_memos` | Full-text search | `query` | — | | `create_memo` | Create a memo | `content` | `visibility` | diff --git a/server/router/mcp/access.go b/server/router/mcp/access.go index 0e950b228..eadc11f83 100644 --- a/server/router/mcp/access.go +++ b/server/router/mcp/access.go @@ -41,6 +41,10 @@ func checkMemoOwnership(memo *store.Memo, userID int32) error { return nil } +func hasMemoOwnership(memo *store.Memo, userID int32) bool { + return memo.CreatorID == userID +} + // applyVisibilityFilter restricts find to memos the caller may see. func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) { if rowStatus != nil && *rowStatus == store.Archived { diff --git a/server/router/mcp/tools_memo.go b/server/router/mcp/tools_memo.go index 2e106c805..47e8a2298 100644 --- a/server/router/mcp/tools_memo.go +++ b/server/router/mcp/tools_memo.go @@ -223,7 +223,7 @@ func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) { mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"), ), mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")), - mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), + mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), ), s.handleListMemos) mcpSrv.AddTool(mcp.NewTool("get_memo", diff --git a/server/router/mcp/tools_relation.go b/server/router/mcp/tools_relation.go index 6a7e886ab..127bb16fe 100644 --- a/server/router/mcp/tools_relation.go +++ b/server/router/mcp/tools_relation.go @@ -142,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if err := checkMemoOwnership(srcMemo, userID); err != nil { + if !hasMemoOwnership(srcMemo, userID) { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } @@ -199,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if err := checkMemoOwnership(srcMemo, userID); err != nil { + if !hasMemoOwnership(srcMemo, userID) { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } diff --git a/store/test/memo_filter_test.go b/store/test/memo_filter_test.go index aaa25488d..09f49854c 100644 --- a/store/test/memo_filter_test.go +++ b/store/test/memo_filter_test.go @@ -730,6 +730,31 @@ func TestMemoFilterTagsExistsContains(t *testing.T) { require.Len(t, memos, 1, "Should find 1 non-todo memo") } +func TestMemoFilterTagsExistsEquals(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + tc.CreateMemo(NewMemoBuilder("memo-1231", tc.User.ID). + Content("Memo with exact numeric tag"). + Tags("1231", "project")) + + tc.CreateMemo(NewMemoBuilder("memo-1231-suffix", tc.User.ID). + Content("Memo with related tag"). + Tags("tag/1231", "other")) + + tc.CreateMemo(NewMemoBuilder("memo-other", tc.User.ID). + Content("Memo with different tag"). + Tags("9999")) + + memos := tc.ListWithFilter(`tags.exists(t, t == "1231")`) + require.Len(t, memos, 1, "Should find only the memo with exact matching tag") + require.Equal(t, "memo-1231", memos[0].UID) + + memos = tc.ListWithFilter(`!tags.exists(t, t == "1231")`) + require.Len(t, memos, 2, "Should exclude only the memo with exact matching tag") +} + func TestMemoFilterTagsExistsEndsWith(t *testing.T) { t.Parallel() tc := NewMemoFilterTestContext(t)