mirror of https://github.com/usememos/memos.git
fix(filter): enforce CEL syntax semantics
Reject non-standard truthy numeric expressions in filters and document the parser as a supported subset of standard CEL syntax. - remove legacy filter rewrites - support standard equality in tag exists predicates - add regression coverage for accepted and rejected expressions
This commit is contained in:
parent
d3f6e8ee31
commit
0e89407ee9
|
|
@ -1,12 +1,14 @@
|
|||
# Memo Filter Engine
|
||||
|
||||
This package houses the memo-only filter engine that turns CEL expressions into
|
||||
SQL fragments. The engine follows a three phase pipeline inspired by systems
|
||||
This package houses the memo-only filter engine that turns standard CEL syntax
|
||||
into SQL fragments for the subset of expressions supported by the memo schema.
|
||||
The engine follows a three phase pipeline inspired by systems
|
||||
such as Calcite or Prisma:
|
||||
|
||||
1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against
|
||||
the memo-specific environment declared in `schema.go`. Only fields that
|
||||
exist in the schema can surface in the filter.
|
||||
exist in the schema can surface in the filter, and non-standard legacy
|
||||
coercions are rejected.
|
||||
2. **Normalization** – the raw CEL AST is converted into an intermediate
|
||||
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
|
||||
conditions (logical operators, comparisons, list membership, etc.). This
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package filter
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -45,8 +44,6 @@ func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
|
|||
return nil, errors.New("filter expression is empty")
|
||||
}
|
||||
|
||||
filter = normalizeLegacyFilter(filter)
|
||||
|
||||
ast, issues := e.env.Compile(filter)
|
||||
if issues != nil && issues.Err() != nil {
|
||||
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
|
||||
|
|
@ -119,73 +116,3 @@ func DefaultAttachmentEngine() (*Engine, error) {
|
|||
})
|
||||
return defaultAttachmentInst, defaultAttachmentErr
|
||||
}
|
||||
|
||||
func normalizeLegacyFilter(expr string) string {
|
||||
expr = rewriteNumericLogicalOperand(expr, "&&")
|
||||
expr = rewriteNumericLogicalOperand(expr, "||")
|
||||
return expr
|
||||
}
|
||||
|
||||
func rewriteNumericLogicalOperand(expr, op string) string {
|
||||
var builder strings.Builder
|
||||
n := len(expr)
|
||||
i := 0
|
||||
var inQuote rune
|
||||
|
||||
for i < n {
|
||||
ch := expr[i]
|
||||
|
||||
if inQuote != 0 {
|
||||
builder.WriteByte(ch)
|
||||
if ch == '\\' && i+1 < n {
|
||||
builder.WriteByte(expr[i+1])
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if ch == byte(inQuote) {
|
||||
inQuote = 0
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\'' || ch == '"' {
|
||||
inQuote = rune(ch)
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr[i:], op) {
|
||||
builder.WriteString(op)
|
||||
i += len(op)
|
||||
|
||||
// Preserve whitespace following the operator.
|
||||
wsStart := i
|
||||
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
|
||||
i++
|
||||
}
|
||||
builder.WriteString(expr[wsStart:i])
|
||||
|
||||
signStart := i
|
||||
if i < n && (expr[i] == '+' || expr[i] == '-') {
|
||||
i++
|
||||
}
|
||||
for i < n && expr[i] >= '0' && expr[i] <= '9' {
|
||||
i++
|
||||
}
|
||||
if i > signStart {
|
||||
numLiteral := expr[signStart:i]
|
||||
fmt.Fprintf(&builder, "(%s != 0)", numLiteral)
|
||||
} else {
|
||||
builder.WriteString(expr[signStart:i])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompileAcceptsStandardTagEqualityPredicate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `tags.exists(t, t == "1231")`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCompileRejectsLegacyNumericLogicalOperand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `pinned && 1`)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to compile filter")
|
||||
}
|
||||
|
||||
func TestCompileRejectsNonBooleanTopLevelConstant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `1`)
|
||||
require.EqualError(t, err, "filter must evaluate to a boolean value")
|
||||
}
|
||||
|
|
@ -157,3 +157,10 @@ type ContainsPredicate struct {
|
|||
}
|
||||
|
||||
func (*ContainsPredicate) isPredicateExpr() {}
|
||||
|
||||
// EqualsPredicate represents t == "value".
|
||||
type EqualsPredicate struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*EqualsPredicate) isPredicateExpr() {}
|
||||
|
|
|
|||
|
|
@ -16,16 +16,10 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v, ok := val.(bool); ok {
|
||||
return &ConstantCondition{Value: v}, nil
|
||||
case int64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
case float64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
default:
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
}
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
case *exprv1.Expr_IdentExpr:
|
||||
name := v.IdentExpr.GetName()
|
||||
field, ok := schema.Field(name)
|
||||
|
|
@ -504,6 +498,8 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
|
|||
|
||||
// Handle different predicate functions
|
||||
switch predicateCall.Function {
|
||||
case "_==_":
|
||||
return buildEqualsPredicate(predicateCall, comp.IterVar)
|
||||
case "startsWith":
|
||||
return buildStartsWithPredicate(predicateCall, comp.IterVar)
|
||||
case "endsWith":
|
||||
|
|
@ -511,10 +507,44 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
|
|||
case "contains":
|
||||
return buildContainsPredicate(predicateCall, comp.IterVar)
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
|
||||
return nil, errors.Errorf(`unsupported predicate function %q in comprehension (supported: ==, startsWith, endsWith, contains)`, predicateCall.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEqualsPredicate extracts the value from t == "value".
|
||||
func buildEqualsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("equality predicate expects exactly two arguments")
|
||||
}
|
||||
|
||||
var constExpr *exprv1.Expr
|
||||
switch {
|
||||
case isIterVarExpr(call.Args[0], iterVar):
|
||||
constExpr = call.Args[1]
|
||||
case isIterVarExpr(call.Args[1], iterVar):
|
||||
constExpr = call.Args[0]
|
||||
default:
|
||||
return nil, errors.Errorf("equality predicate must compare against the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
value, err := getConstValue(constExpr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "equality argument must be a constant string")
|
||||
}
|
||||
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("equality argument must be a string")
|
||||
}
|
||||
|
||||
return &EqualsPredicate{Value: valueStr}, nil
|
||||
}
|
||||
|
||||
func isIterVarExpr(expr *exprv1.Expr, iterVar string) bool {
|
||||
target := expr.GetIdentExpr()
|
||||
return target != nil && target.GetName() == iterVar
|
||||
}
|
||||
|
||||
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
|
||||
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
// Verify the target is the iteration variable
|
||||
|
|
|
|||
|
|
@ -480,6 +480,8 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
|
|||
|
||||
// Render based on predicate type
|
||||
switch pred := cond.Predicate.(type) {
|
||||
case *EqualsPredicate:
|
||||
return r.renderTagEquals(field, pred.Value, cond.Kind)
|
||||
case *StartsWithPredicate:
|
||||
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
|
||||
case *EndsWithPredicate:
|
||||
|
|
@ -491,6 +493,22 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
|
|||
}
|
||||
}
|
||||
|
||||
// renderTagEquals generates SQL for tags.exists(t, t == "value").
|
||||
func (r *renderer) renderTagEquals(field Field, value string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, value))
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
|
||||
case DialectPostgres:
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, value)))
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
|
||||
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ For Streamable HTTP safety, requests with an `Origin` header must be same-origin
|
|||
|
||||
| Tool | Description | Required params | Optional params |
|
||||
|---|---|---|---|
|
||||
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (CEL) |
|
||||
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (supported subset of standard CEL syntax) |
|
||||
| `get_memo` | Get a single memo | `name` | — |
|
||||
| `search_memos` | Full-text search | `query` | — |
|
||||
| `create_memo` | Create a memo | `content` | `visibility` |
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ func checkMemoOwnership(memo *store.Memo, userID int32) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func hasMemoOwnership(memo *store.Memo, userID int32) bool {
|
||||
return memo.CreatorID == userID
|
||||
}
|
||||
|
||||
// applyVisibilityFilter restricts find to memos the caller may see.
|
||||
func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) {
|
||||
if rowStatus != nil && *rowStatus == store.Archived {
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
|
|||
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
|
||||
),
|
||||
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
|
||||
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
|
||||
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
|
||||
), s.handleListMemos)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("get_memo",
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if srcMemo == nil {
|
||||
return mcp.NewToolResultError("source memo not found"), nil
|
||||
}
|
||||
if err := checkMemoOwnership(srcMemo, userID); err != nil {
|
||||
if !hasMemoOwnership(srcMemo, userID) {
|
||||
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
|
||||
}
|
||||
|
||||
|
|
@ -199,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if srcMemo == nil {
|
||||
return mcp.NewToolResultError("source memo not found"), nil
|
||||
}
|
||||
if err := checkMemoOwnership(srcMemo, userID); err != nil {
|
||||
if !hasMemoOwnership(srcMemo, userID) {
|
||||
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -730,6 +730,31 @@ func TestMemoFilterTagsExistsContains(t *testing.T) {
|
|||
require.Len(t, memos, 1, "Should find 1 non-todo memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-1231", tc.User.ID).
|
||||
Content("Memo with exact numeric tag").
|
||||
Tags("1231", "project"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-1231-suffix", tc.User.ID).
|
||||
Content("Memo with related tag").
|
||||
Tags("tag/1231", "other"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-other", tc.User.ID).
|
||||
Content("Memo with different tag").
|
||||
Tags("9999"))
|
||||
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t == "1231")`)
|
||||
require.Len(t, memos, 1, "Should find only the memo with exact matching tag")
|
||||
require.Equal(t, "memo-1231", memos[0].UID)
|
||||
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t == "1231")`)
|
||||
require.Len(t, memos, 2, "Should exclude only the memo with exact matching tag")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
|
|
|
|||
Loading…
Reference in New Issue