diff --git a/plugin/filter/MAINTENANCE.md b/plugin/filter/MAINTENANCE.md new file mode 100644 index 000000000..f37a9f984 --- /dev/null +++ b/plugin/filter/MAINTENANCE.md @@ -0,0 +1,50 @@ +# Maintaining the Memo Filter Engine + +The engine is memo-specific; any future field or behavior changes must stay +consistent with the memo schema and store implementations. Use this guide when +extending or debugging the package. + +## Adding a New Memo Field + +1. **Update the schema** + - Add the field entry in `schema.go`. + - Define the backing column (`Column`), JSON path (if applicable), type, and + allowed operators. + - Include the CEL variable in `EnvOptions`. +2. **Adjust parser or renderer (if needed)** + - For non-scalar fields (JSON booleans, lists), add handling in + `parser.go` or extend the renderer helpers. + - Keep validation in the parser (e.g., reject unsupported operators). +3. **Write a golden test** + - Extend the dialect-specific memo filter tests under + `store/db/{sqlite,mysql,postgres}/memo_filter_test.go` with a case that + exercises the new field. +4. **Run `go test ./...`** to ensure the SQL output matches expectations across + all dialects. + +## Supporting Dialect Nuances + +- Centralize differences inside `render.go`. If a new dialect-specific behavior + emerges (e.g., JSON operators), add the logic there rather than leaking it + into store code. +- Use the renderer helpers (`jsonExtractExpr`, `jsonArrayExpr`, etc.) rather than + sprinkling ad-hoc SQL strings. +- When placeholders change, adjust `addArg` so that argument numbering stays in + sync with store queries. + +## Debugging Tips + +- **Parser errors** – Most originate in `buildCondition` or schema validation. + Enable logging around `parser.go` when diagnosing unknown identifier/operator + messages. +- **Renderer output** – Temporary printf/log statements in `renderCondition` help + identify which IR node produced unexpected SQL. +- **Store integration** – Ensure drivers call `filter.DefaultEngine()` exactly once + per process; the singleton caches the parsed CEL environment. + +## Testing Checklist + +- `go test ./store/...` ensures all dialect tests consume the engine correctly. +- Add targeted unit tests whenever new IR nodes or renderer paths are introduced. +- When changing boolean or JSON handling, verify all three dialect test suites + (SQLite, MySQL, Postgres) to avoid regression. diff --git a/plugin/filter/README.md b/plugin/filter/README.md new file mode 100644 index 000000000..35961615f --- /dev/null +++ b/plugin/filter/README.md @@ -0,0 +1,63 @@ +# Memo Filter Engine + +This package houses the memo-only filter engine that turns CEL expressions into +SQL fragments. The engine follows a three phase pipeline inspired by systems +such as Calcite or Prisma: + +1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against + the memo-specific environment declared in `schema.go`. Only fields that + exist in the schema can surface in the filter. +2. **Normalization** – the raw CEL AST is converted into an intermediate + representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of + conditions (logical operators, comparisons, list membership, etc.). This + step enforces schema rules (e.g. operator compatibility, type checks). +3. **Rendering** – the renderer in `render.go` walks the IR and produces a SQL + fragment plus placeholder arguments tailored to a target dialect + (`sqlite`, `mysql`, or `postgres`). Dialect differences such as JSON access, + boolean semantics, placeholders, and `LIKE` vs `ILIKE` are encapsulated in + renderer helpers. + +The entry point is `filter.DefaultEngine()` from `engine.go`. It lazily constructs +an `Engine` configured with the memo schema and exposes: + +```go +engine, _ := filter.DefaultEngine() +stmt, _ := engine.CompileToStatement(ctx, `has_task_list && visibility == "PUBLIC"`, filter.RenderOptions{ + Dialect: filter.DialectPostgres, +}) +// stmt.SQL -> "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.visibility = $1)" +// stmt.Args -> ["PUBLIC"] +``` + +## Core Files + +| File | Responsibility | +| ------------- | ------------------------------------------------------------------------------- | +| `schema.go` | Declares memo fields, their types, backing columns, CEL environment options | +| `ir.go` | IR node definitions used across the pipeline | +| `parser.go` | Converts CEL `Expr` into IR while applying schema validation | +| `render.go` | Translates IR into SQL, handling dialect-specific behavior | +| `engine.go` | Glue between the phases; exposes `Compile`, `CompileToStatement`, and `DefaultEngine` | +| `helpers.go` | Convenience helpers for store integration (appending conditions) | + +## SQL Generation Notes + +- **Placeholders** — `?` is used for SQLite/MySQL, `$n` for Postgres. The renderer + tracks offsets to compose queries with pre-existing arguments. +- **JSON Fields** — Memo metadata lives in `memo.payload`. The renderer handles + `JSON_EXTRACT`/`json_extract`/`->`/`->>` variations and boolean coercion. +- **Tag Operations** — `tag in [...]` and `"tag" in tags` become JSON array + predicates. SQLite uses `LIKE` patterns, MySQL uses `JSON_CONTAINS`, and + Postgres uses `@>`. +- **Boolean Flags** — Fields such as `has_task_list` render as `IS TRUE` equality + checks, or comparisons against `CAST('true' AS JSON)` depending on the dialect. + +## Typical Integration + +1. Fetch the engine with `filter.DefaultEngine()`. +2. Call `CompileToStatement` using the appropriate dialect enum. +3. Append the emitted SQL fragment/args to the existing `WHERE` clause. +4. Execute the resulting query through the store driver. + +The `helpers.AppendConditions` helper encapsulates steps 2–3 when a driver needs +to process an array of filters. diff --git a/plugin/filter/common_converter.go b/plugin/filter/common_converter.go deleted file mode 100644 index aa3942929..000000000 --- a/plugin/filter/common_converter.go +++ /dev/null @@ -1,746 +0,0 @@ -package filter - -import ( - "fmt" - "slices" - "strings" - - "github.com/pkg/errors" - exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" -) - -// CommonSQLConverter handles the common CEL to SQL conversion logic. -type CommonSQLConverter struct { - dialect SQLDialect - paramIndex int - allowedFields []string - entityType string -} - -// NewCommonSQLConverter creates a new converter with the specified dialect for memo filters. -func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter { - return &CommonSQLConverter{ - dialect: dialect, - paramIndex: 1, - allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, - entityType: "memo", - } -} - -// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset for memo filters. -func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter { - return &CommonSQLConverter{ - dialect: dialect, - paramIndex: offset + 1, - allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, - entityType: "memo", - } -} - -// NewUserSQLConverter creates a new converter for user filters. -func NewUserSQLConverter(dialect SQLDialect) *CommonSQLConverter { - return &CommonSQLConverter{ - dialect: dialect, - paramIndex: 1, - allowedFields: []string{"username"}, - entityType: "user", - } -} - -// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect. -func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error { - if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { - switch v.CallExpr.Function { - case "_||_", "_&&_": - return c.handleLogicalOperator(ctx, v.CallExpr) - case "!_": - return c.handleNotOperator(ctx, v.CallExpr) - case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": - return c.handleComparisonOperator(ctx, v.CallExpr) - case "@in": - return c.handleInOperator(ctx, v.CallExpr) - case "contains": - return c.handleContainsOperator(ctx, v.CallExpr) - default: - return errors.Errorf("unsupported call expression function: %s", v.CallExpr.Function) - } - } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { - return c.handleIdentifier(ctx, v.IdentExpr) - } - return nil -} - -func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { - if len(callExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", callExpr.Function) - } - - if _, err := ctx.Buffer.WriteString("("); err != nil { - return err - } - - if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil { - return err - } - - operator := "AND" - if callExpr.Function == "_||_" { - operator = "OR" - } - - if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { - return err - } - - if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil { - return err - } - - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err - } - - return nil -} - -func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { - if len(callExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", callExpr.Function) - } - - if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { - return err - } - - if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil { - return err - } - - if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err - } - - return nil -} - -func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { - if len(callExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", callExpr.Function) - } - - // Check if the left side is a function call like size(tags) - if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { - if leftCallExpr.CallExpr.Function == "size" { - return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr) - } - } - - identifier, err := GetIdentExprName(callExpr.Args[0]) - if err != nil { - return err - } - - if !slices.Contains(c.allowedFields, identifier) { - return errors.Errorf("invalid identifier for %s", callExpr.Function) - } - - value, err := GetExprValue(callExpr.Args[1]) - if err != nil { - return err - } - - operator := c.getComparisonOperator(callExpr.Function) - - // Handle memo fields - if c.entityType == "memo" { - switch identifier { - case "created_ts", "updated_ts": - return c.handleTimestampComparison(ctx, identifier, operator, value) - case "visibility", "content": - return c.handleStringComparison(ctx, identifier, operator, value) - case "creator_id": - return c.handleIntComparison(ctx, identifier, operator, value) - case "pinned": - return c.handlePinnedComparison(ctx, operator, value) - case "has_task_list", "has_link", "has_code", "has_incomplete_tasks": - return c.handleBooleanComparison(ctx, identifier, operator, value) - default: - return errors.Errorf("unsupported identifier in comparison: %s", identifier) - } - } - - // Handle user fields - if c.entityType == "user" { - switch identifier { - case "username": - return c.handleUserStringComparison(ctx, identifier, operator, value) - default: - return errors.Errorf("unsupported user identifier in comparison: %s", identifier) - } - } - - return errors.Errorf("unsupported entity type: %s", c.entityType) -} - -func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error { - if len(sizeCall.Args) != 1 { - return errors.New("size function requires exactly one argument") - } - - identifier, err := GetIdentExprName(sizeCall.Args[0]) - if err != nil { - return err - } - - if identifier != "tags" { - return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) - } - - value, err := GetExprValue(callExpr.Args[1]) - if err != nil { - return err - } - - valueInt, ok := value.(int64) - if !ok { - return errors.New("size comparison value must be an integer") - } - - operator := c.getComparisonOperator(callExpr.Function) - - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", - c.dialect.GetJSONArrayLength("$.tags"), - operator, - c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - - ctx.Args = append(ctx.Args, valueInt) - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { - if len(callExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", callExpr.Function) - } - - // Check if this is "element in collection" syntax - if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil { - if identifier == "tags" { - return c.handleElementInTags(ctx, callExpr.Args[0]) - } - return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier) - } - - // Original logic for "identifier in [list]" syntax - identifier, err := GetIdentExprName(callExpr.Args[0]) - if err != nil { - return err - } - - if !slices.Contains([]string{"tag", "visibility", "content_id", "memo_id"}, identifier) { - return errors.Errorf("invalid identifier for %s", callExpr.Function) - } - - values := []any{} - for _, element := range callExpr.Args[1].GetListExpr().Elements { - value, err := GetConstValue(element) - if err != nil { - return err - } - values = append(values, value) - } - - if identifier == "tag" { - return c.handleTagInList(ctx, values) - } else if identifier == "visibility" { - return c.handleVisibilityInList(ctx, values) - } else if identifier == "content_id" { - return c.handleContentIDInList(ctx, values) - } else if identifier == "memo_id" { - return c.handleMemoIDInList(ctx, values) - } - - return nil -} - -func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error { - element, err := GetConstValue(elementExpr) - if err != nil { - return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) - } - - // Use dialect-specific JSON contains logic - template := c.dialect.GetJSONContains("$.tags", "element") - sqlExpr := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1) - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err - } - - // Handle args based on dialect - if _, ok := c.dialect.(*SQLiteDialect); ok { - // SQLite uses LIKE with pattern - ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element)) - } else { - // MySQL and PostgreSQL expect plain values - ctx.Args = append(ctx.Args, element) - } - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error { - subconditions := []string{} - args := []any{} - - for _, v := range values { - if _, ok := c.dialect.(*SQLiteDialect); ok { - subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern")) - args = append(args, fmt.Sprintf(`%%"%s"%%`, v)) - } else { - // Replace ? with proper placeholder for each dialect - template := c.dialect.GetJSONContains("$.tags", "element") - sql := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1) - subconditions = append(subconditions, sql) - args = append(args, fmt.Sprintf(`"%s"`, v)) - } - c.paramIndex++ - } - - if len(subconditions) == 1 { - if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, args...) - return nil -} - -func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error { - placeholders := []string{} - for range values { - placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex)) - c.paramIndex++ - } - - tablePrefix := c.dialect.GetTablePrefix("memo") - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, values...) - return nil -} - -func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values []any) error { - placeholders := []string{} - for range values { - placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex)) - c.paramIndex++ - } - - tablePrefix := c.dialect.GetTablePrefix("reaction") - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, values...) - return nil -} - -func (c *CommonSQLConverter) handleMemoIDInList(ctx *ConvertContext, values []any) error { - placeholders := []string{} - for range values { - placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex)) - c.paramIndex++ - } - - tablePrefix := c.dialect.GetTablePrefix("resource") - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.memo_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`memo_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, values...) - return nil -} - -func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { - if len(callExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", callExpr.Function) - } - - identifier, err := GetIdentExprName(callExpr.Target) - if err != nil { - return err - } - - if identifier != "content" { - return errors.Errorf("invalid identifier for %s", callExpr.Function) - } - - arg, err := GetConstValue(callExpr.Args[0]) - if err != nil { - return err - } - - tablePrefix := c.dialect.GetTablePrefix("memo") - - // PostgreSQL uses ILIKE and no backticks - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content ILIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error { - identifier := identExpr.GetName() - - // Only memo entity has boolean identifiers that can be used standalone - if c.entityType != "memo" { - return errors.Errorf("invalid identifier %s for entity type %s", identifier, c.entityType) - } - - if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { - return errors.Errorf("invalid identifier %s", identifier) - } - - if identifier == "pinned" { - tablePrefix := c.dialect.GetTablePrefix("memo") - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil { - return err - } - } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil { - return err - } - } - } else if identifier == "has_task_list" { - if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil { - return err - } - } else if identifier == "has_link" { - if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasLink")); err != nil { - return err - } - } else if identifier == "has_code" { - if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasCode")); err != nil { - return err - } - } else if identifier == "has_incomplete_tasks" { - if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasIncompleteTasks")); err != nil { - return err - } - } - - return nil -} - -func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error { - valueInt, ok := value.(int64) - if !ok { - return errors.New("invalid integer timestamp value") - } - - timestampField := c.dialect.GetTimestampComparison(field) - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - - ctx.Args = append(ctx.Args, valueInt) - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", field) - } - - valueStr, ok := value.(string) - if !ok { - return errors.New("invalid string value") - } - - tablePrefix := c.dialect.GetTablePrefix("memo") - - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - // PostgreSQL doesn't use backticks - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } else { - // MySQL and SQLite use backticks - fieldName := field - if field == "visibility" { - fieldName = "`visibility`" - } else if field == "content" { - fieldName = "`content`" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, valueStr) - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handleUserStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", field) - } - - valueStr, ok := value.(string) - if !ok { - return errors.New("invalid string value") - } - - tablePrefix := c.dialect.GetTablePrefix("user") - - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - // PostgreSQL doesn't use backticks - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } else { - // MySQL and SQLite use backticks - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, valueStr) - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", field) - } - - valueInt, ok := value.(int64) - if !ok { - return errors.New("invalid int value") - } - - tablePrefix := c.dialect.GetTablePrefix("memo") - - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - // PostgreSQL doesn't use backticks - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } else { - // MySQL and SQLite use backticks - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { - return err - } - } - - ctx.Args = append(ctx.Args, valueInt) - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handlePinnedComparison(ctx *ConvertContext, operator string, value interface{}) error { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for pinned field") - } - - valueBool, ok := value.(bool) - if !ok { - return errors.New("invalid boolean value for pinned field") - } - - tablePrefix := c.dialect.GetTablePrefix("memo") - - var sqlExpr string - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - sqlExpr = fmt.Sprintf("%s.pinned %s %s", tablePrefix, operator, c.dialect.GetParameterPlaceholder(c.paramIndex)) - } else { - sqlExpr = fmt.Sprintf("%s.`pinned` %s %s", tablePrefix, operator, c.dialect.GetParameterPlaceholder(c.paramIndex)) - } - - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err - } - - ctx.Args = append(ctx.Args, c.dialect.GetBooleanValue(valueBool)) - c.paramIndex++ - - return nil -} - -func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error { - if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", field) - } - - valueBool, ok := value.(bool) - if !ok { - return errors.Errorf("invalid boolean value for %s", field) - } - - // Map field name to JSON path - var jsonPath string - switch field { - case "has_task_list": - jsonPath = "$.property.hasTaskList" - case "has_link": - jsonPath = "$.property.hasLink" - case "has_code": - jsonPath = "$.property.hasCode" - case "has_incomplete_tasks": - jsonPath = "$.property.hasIncompleteTasks" - default: - return errors.Errorf("unsupported boolean field: %s", field) - } - - // Special handling for SQLite based on field - if _, ok := c.dialect.(*SQLiteDialect); ok { - if field == "has_task_list" { - // has_task_list uses = 1 / = 0 / != 1 / != 0 - var sqlExpr string - if operator == "=" { - if valueBool { - sqlExpr = fmt.Sprintf("%s = 1", c.dialect.GetJSONExtract(jsonPath)) - } else { - sqlExpr = fmt.Sprintf("%s = 0", c.dialect.GetJSONExtract(jsonPath)) - } - } else { // operator == "!=" - if valueBool { - sqlExpr = fmt.Sprintf("%s != 1", c.dialect.GetJSONExtract(jsonPath)) - } else { - sqlExpr = fmt.Sprintf("%s != 0", c.dialect.GetJSONExtract(jsonPath)) - } - } - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err - } - return nil - } - // Other fields use IS TRUE / NOT(... IS TRUE) - var sqlExpr string - if operator == "=" { - if valueBool { - sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath)) - } else { - sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath)) - } - } else { // operator == "!=" - if valueBool { - sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath)) - } else { - sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath)) - } - } - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err - } - return nil - } - - // Special handling for MySQL - use raw operator with CAST - if _, ok := c.dialect.(*MySQLDialect); ok { - var sqlExpr string - boolStr := "false" - if valueBool { - boolStr = "true" - } - sqlExpr = fmt.Sprintf("%s %s CAST('%s' AS JSON)", c.dialect.GetJSONExtract(jsonPath), operator, boolStr) - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err - } - return nil - } - - // Handle PostgreSQL differently - it uses the raw operator - if _, ok := c.dialect.(*PostgreSQLDialect); ok { - jsonExtract := c.dialect.GetJSONExtract(jsonPath) - - sqlExpr := fmt.Sprintf("(%s)::boolean %s %s", - jsonExtract, - operator, - c.dialect.GetParameterPlaceholder(c.paramIndex)) - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err - } - ctx.Args = append(ctx.Args, valueBool) - c.paramIndex++ - return nil - } - - // Handle other dialects - if operator == "!=" { - valueBool = !valueBool - } - - sqlExpr := c.dialect.GetBooleanComparison(jsonPath, valueBool) - if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { - return err - } - - return nil -} - -func (*CommonSQLConverter) getComparisonOperator(function string) string { - switch function { - case "_==_": - return "=" - case "_!=_": - return "!=" - case "_<_": - return "<" - case "_>_": - return ">" - case "_<=_": - return "<=" - case "_>=_": - return ">=" - default: - return "=" - } -} diff --git a/plugin/filter/converter.go b/plugin/filter/converter.go deleted file mode 100644 index c55a395bb..000000000 --- a/plugin/filter/converter.go +++ /dev/null @@ -1,20 +0,0 @@ -package filter - -import ( - "strings" -) - -type ConvertContext struct { - Buffer strings.Builder - Args []any - // The offset of the next argument in the condition string. - // Mainly using for PostgreSQL. - ArgsOffset int -} - -func NewConvertContext() *ConvertContext { - return &ConvertContext{ - Buffer: strings.Builder{}, - Args: []any{}, - } -} diff --git a/plugin/filter/dialect.go b/plugin/filter/dialect.go deleted file mode 100644 index 293d7d078..000000000 --- a/plugin/filter/dialect.go +++ /dev/null @@ -1,215 +0,0 @@ -package filter - -import ( - "fmt" - "strings" -) - -// SQLDialect defines database-specific SQL generation methods. -type SQLDialect interface { - // Basic field access - GetTablePrefix(entityName string) string - GetParameterPlaceholder(index int) string - - // JSON operations - GetJSONExtract(path string) string - GetJSONArrayLength(path string) string - GetJSONContains(path, element string) string - GetJSONLike(path, pattern string) string - - // Boolean operations - GetBooleanValue(value bool) interface{} - GetBooleanComparison(path string, value bool) string - GetBooleanCheck(path string) string - - // Timestamp operations - GetTimestampComparison(field string) string - GetCurrentTimestamp() string -} - -// DatabaseType represents the type of database. -type DatabaseType string - -const ( - SQLite DatabaseType = "sqlite" - MySQL DatabaseType = "mysql" - PostgreSQL DatabaseType = "postgres" -) - -// GetDialect returns the appropriate dialect for the database type. -func GetDialect(dbType DatabaseType) SQLDialect { - switch dbType { - case SQLite: - return &SQLiteDialect{} - case MySQL: - return &MySQLDialect{} - case PostgreSQL: - return &PostgreSQLDialect{} - default: - return &SQLiteDialect{} // default fallback - } -} - -// SQLiteDialect implements SQLDialect for SQLite. -type SQLiteDialect struct{} - -func (*SQLiteDialect) GetTablePrefix(entityName string) string { - return fmt.Sprintf("`%s`", entityName) -} - -func (*SQLiteDialect) GetParameterPlaceholder(_ int) string { - return "?" -} - -func (d *SQLiteDialect) GetJSONExtract(path string) string { - return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path) -} - -func (d *SQLiteDialect) GetJSONArrayLength(path string) string { - return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path)) -} - -func (d *SQLiteDialect) GetJSONContains(path, _ string) string { - return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path)) -} - -func (d *SQLiteDialect) GetJSONLike(path, _ string) string { - return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path)) -} - -func (*SQLiteDialect) GetBooleanValue(value bool) interface{} { - if value { - return 1 - } - return 0 -} - -func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string { - if value { - return fmt.Sprintf("%s = 1", d.GetJSONExtract(path)) - } - return fmt.Sprintf("%s = 0", d.GetJSONExtract(path)) -} - -func (d *SQLiteDialect) GetBooleanCheck(path string) string { - return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path)) -} - -func (d *SQLiteDialect) GetTimestampComparison(field string) string { - return fmt.Sprintf("%s.`%s`", d.GetTablePrefix("memo"), field) -} - -func (*SQLiteDialect) GetCurrentTimestamp() string { - return "strftime('%s', 'now')" -} - -// MySQLDialect implements SQLDialect for MySQL. -type MySQLDialect struct{} - -func (*MySQLDialect) GetTablePrefix(entityName string) string { - return fmt.Sprintf("`%s`", entityName) -} - -func (*MySQLDialect) GetParameterPlaceholder(_ int) string { - return "?" -} - -func (d *MySQLDialect) GetJSONExtract(path string) string { - return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path) -} - -func (d *MySQLDialect) GetJSONArrayLength(path string) string { - return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path)) -} - -func (d *MySQLDialect) GetJSONContains(path, _ string) string { - return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path)) -} - -func (d *MySQLDialect) GetJSONLike(path, _ string) string { - return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path)) -} - -func (*MySQLDialect) GetBooleanValue(value bool) interface{} { - return value -} - -func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string { - if value { - return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path)) - } - return fmt.Sprintf("%s != CAST('true' AS JSON)", d.GetJSONExtract(path)) -} - -func (d *MySQLDialect) GetBooleanCheck(path string) string { - return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path)) -} - -func (d *MySQLDialect) GetTimestampComparison(field string) string { - return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix("memo"), field) -} - -func (*MySQLDialect) GetCurrentTimestamp() string { - return "UNIX_TIMESTAMP()" -} - -// PostgreSQLDialect implements SQLDialect for PostgreSQL. -type PostgreSQLDialect struct{} - -func (*PostgreSQLDialect) GetTablePrefix(entityName string) string { - return entityName -} - -func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string { - return fmt.Sprintf("$%d", index) -} - -func (d *PostgreSQLDialect) GetJSONExtract(path string) string { - // Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList' - parts := strings.Split(strings.TrimPrefix(path, "$."), ".") - result := fmt.Sprintf("%s.payload", d.GetTablePrefix("memo")) - for i, part := range parts { - if i == len(parts)-1 { - result += fmt.Sprintf("->>'%s'", part) - } else { - result += fmt.Sprintf("->'%s'", part) - } - } - return result -} - -func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string { - jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) - return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix("memo"), jsonPath) -} - -func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string { - jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) - return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath) -} - -func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string { - jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) - return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath) -} - -func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} { - return value -} - -func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string { - // Note: The parameter placeholder will be replaced by the caller - return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path)) -} - -func (d *PostgreSQLDialect) GetBooleanCheck(path string) string { - return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path)) -} - -func (d *PostgreSQLDialect) GetTimestampComparison(field string) string { - return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix("memo"), field) -} - -func (*PostgreSQLDialect) GetCurrentTimestamp() string { - return "EXTRACT(EPOCH FROM NOW())" -} diff --git a/plugin/filter/engine.go b/plugin/filter/engine.go new file mode 100644 index 000000000..25b5485f9 --- /dev/null +++ b/plugin/filter/engine.go @@ -0,0 +1,180 @@ +package filter + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/google/cel-go/cel" + "github.com/pkg/errors" +) + +// Engine parses CEL filters into a dialect-agnostic condition tree. +type Engine struct { + schema Schema + env *cel.Env +} + +// NewEngine builds a new Engine for the provided schema. +func NewEngine(schema Schema) (*Engine, error) { + env, err := cel.NewEnv(schema.EnvOptions...) + if err != nil { + return nil, errors.Wrap(err, "failed to create CEL environment") + } + return &Engine{ + schema: schema, + env: env, + }, nil +} + +// Program stores a compiled filter condition. +type Program struct { + schema Schema + condition Condition +} + +// ConditionTree exposes the underlying condition tree. +func (p *Program) ConditionTree() Condition { + return p.condition +} + +// Compile parses the filter string into an executable program. +func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) { + if strings.TrimSpace(filter) == "" { + return nil, errors.New("filter expression is empty") + } + + filter = normalizeLegacyFilter(filter) + + ast, issues := e.env.Compile(filter) + if issues != nil && issues.Err() != nil { + return nil, errors.Wrap(issues.Err(), "failed to compile filter") + } + parsed, err := cel.AstToParsedExpr(ast) + if err != nil { + return nil, errors.Wrap(err, "failed to convert AST") + } + + cond, err := buildCondition(parsed.GetExpr(), e.schema) + if err != nil { + return nil, err + } + + return &Program{ + schema: e.schema, + condition: cond, + }, nil +} + +// CompileToStatement compiles and renders the filter in a single step. +func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) { + program, err := e.Compile(ctx, filter) + if err != nil { + return Statement{}, err + } + return program.Render(opts) +} + +// RenderOptions configure SQL rendering. +type RenderOptions struct { + Dialect DialectName + PlaceholderOffset int + DisableNullChecks bool +} + +// Statement contains the rendered SQL fragment and its args. +type Statement struct { + SQL string + Args []any +} + +// Render converts the program into a dialect-specific SQL fragment. +func (p *Program) Render(opts RenderOptions) (Statement, error) { + renderer := newRenderer(p.schema, opts) + return renderer.Render(p.condition) +} + +var ( + defaultOnce sync.Once + defaultInst *Engine + defaultErr error +) + +// DefaultEngine returns the process-wide memo filter engine. +func DefaultEngine() (*Engine, error) { + defaultOnce.Do(func() { + defaultInst, defaultErr = NewEngine(NewSchema()) + }) + return defaultInst, defaultErr +} + +func normalizeLegacyFilter(expr string) string { + expr = rewriteNumericLogicalOperand(expr, "&&") + expr = rewriteNumericLogicalOperand(expr, "||") + return expr +} + +func rewriteNumericLogicalOperand(expr, op string) string { + var builder strings.Builder + n := len(expr) + i := 0 + var inQuote rune + + for i < n { + ch := expr[i] + + if inQuote != 0 { + builder.WriteByte(ch) + if ch == '\\' && i+1 < n { + builder.WriteByte(expr[i+1]) + i += 2 + continue + } + if ch == byte(inQuote) { + inQuote = 0 + } + i++ + continue + } + + if ch == '\'' || ch == '"' { + inQuote = rune(ch) + builder.WriteByte(ch) + i++ + continue + } + + if strings.HasPrefix(expr[i:], op) { + builder.WriteString(op) + i += len(op) + + // Preserve whitespace following the operator. + wsStart := i + for i < n && (expr[i] == ' ' || expr[i] == '\t') { + i++ + } + builder.WriteString(expr[wsStart:i]) + + signStart := i + if i < n && (expr[i] == '+' || expr[i] == '-') { + i++ + } + for i < n && expr[i] >= '0' && expr[i] <= '9' { + i++ + } + if i > signStart { + numLiteral := expr[signStart:i] + builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral)) + } else { + builder.WriteString(expr[signStart:i]) + } + continue + } + + builder.WriteByte(ch) + i++ + } + + return builder.String() +} diff --git a/plugin/filter/expr.go b/plugin/filter/expr.go deleted file mode 100644 index 01ce5395e..000000000 --- a/plugin/filter/expr.go +++ /dev/null @@ -1,127 +0,0 @@ -package filter - -import ( - "errors" - "time" - - exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" -) - -// GetConstValue returns the constant value of the expression. -func GetConstValue(expr *exprv1.Expr) (any, error) { - v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr) - if !ok { - return nil, errors.New("invalid constant expression") - } - - switch v.ConstExpr.ConstantKind.(type) { - case *exprv1.Constant_StringValue: - return v.ConstExpr.GetStringValue(), nil - case *exprv1.Constant_Int64Value: - return v.ConstExpr.GetInt64Value(), nil - case *exprv1.Constant_Uint64Value: - return v.ConstExpr.GetUint64Value(), nil - case *exprv1.Constant_DoubleValue: - return v.ConstExpr.GetDoubleValue(), nil - case *exprv1.Constant_BoolValue: - return v.ConstExpr.GetBoolValue(), nil - default: - return nil, errors.New("unexpected constant type") - } -} - -// GetIdentExprName returns the name of the identifier expression. -func GetIdentExprName(expr *exprv1.Expr) (string, error) { - _, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr) - if !ok { - return "", errors.New("invalid identifier expression") - } - return expr.GetIdentExpr().GetName(), nil -} - -// GetFunctionValue evaluates CEL function calls and returns their value. -// This is specifically for time functions like now(). -func GetFunctionValue(expr *exprv1.Expr) (any, error) { - callExpr, ok := expr.ExprKind.(*exprv1.Expr_CallExpr) - if !ok { - return nil, errors.New("invalid function call expression") - } - - switch callExpr.CallExpr.Function { - case "now": - if len(callExpr.CallExpr.Args) != 0 { - return nil, errors.New("now() function takes no arguments") - } - return time.Now().Unix(), nil - case "_-_": - // Handle subtraction for expressions like "now() - 60 * 60 * 24" - if len(callExpr.CallExpr.Args) != 2 { - return nil, errors.New("subtraction requires exactly two arguments") - } - left, err := GetExprValue(callExpr.CallExpr.Args[0]) - if err != nil { - return nil, err - } - right, err := GetExprValue(callExpr.CallExpr.Args[1]) - if err != nil { - return nil, err - } - leftInt, ok1 := left.(int64) - rightInt, ok2 := right.(int64) - if !ok1 || !ok2 { - return nil, errors.New("subtraction operands must be integers") - } - return leftInt - rightInt, nil - case "_*_": - // Handle multiplication for expressions like "60 * 60 * 24" - if len(callExpr.CallExpr.Args) != 2 { - return nil, errors.New("multiplication requires exactly two arguments") - } - left, err := GetExprValue(callExpr.CallExpr.Args[0]) - if err != nil { - return nil, err - } - right, err := GetExprValue(callExpr.CallExpr.Args[1]) - if err != nil { - return nil, err - } - leftInt, ok1 := left.(int64) - rightInt, ok2 := right.(int64) - if !ok1 || !ok2 { - return nil, errors.New("multiplication operands must be integers") - } - return leftInt * rightInt, nil - case "_+_": - // Handle addition - if len(callExpr.CallExpr.Args) != 2 { - return nil, errors.New("addition requires exactly two arguments") - } - left, err := GetExprValue(callExpr.CallExpr.Args[0]) - if err != nil { - return nil, err - } - right, err := GetExprValue(callExpr.CallExpr.Args[1]) - if err != nil { - return nil, err - } - leftInt, ok1 := left.(int64) - rightInt, ok2 := right.(int64) - if !ok1 || !ok2 { - return nil, errors.New("addition operands must be integers") - } - return leftInt + rightInt, nil - default: - return nil, errors.New("unsupported function: " + callExpr.CallExpr.Function) - } -} - -// GetExprValue attempts to get a value from an expression, trying constants first, then functions. -func GetExprValue(expr *exprv1.Expr) (any, error) { - // Try to get constant value first - if constValue, err := GetConstValue(expr); err == nil { - return constValue, nil - } - - // If not a constant, try to evaluate as a function - return GetFunctionValue(expr) -} diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go deleted file mode 100644 index dc4190deb..000000000 --- a/plugin/filter/filter.go +++ /dev/null @@ -1,66 +0,0 @@ -package filter - -import ( - "time" - - "github.com/google/cel-go/cel" - "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" - "github.com/pkg/errors" - exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" -) - -// MemoFilterCELAttributes are the CEL attributes for memo. -var MemoFilterCELAttributes = []cel.EnvOption{ - cel.Variable("content", cel.StringType), - cel.Variable("creator_id", cel.IntType), - cel.Variable("created_ts", cel.IntType), - cel.Variable("updated_ts", cel.IntType), - cel.Variable("pinned", cel.BoolType), - cel.Variable("tag", cel.StringType), - cel.Variable("tags", cel.ListType(cel.StringType)), - cel.Variable("visibility", cel.StringType), - cel.Variable("has_task_list", cel.BoolType), - cel.Variable("has_link", cel.BoolType), - cel.Variable("has_code", cel.BoolType), - cel.Variable("has_incomplete_tasks", cel.BoolType), - // Current timestamp function. - cel.Function("now", - cel.Overload("now", - []*cel.Type{}, - cel.IntType, - cel.FunctionBinding(func(_ ...ref.Val) ref.Val { - return types.Int(time.Now().Unix()) - }), - ), - ), -} - -// ReactionFilterCELAttributes are the CEL attributes for reaction. -var ReactionFilterCELAttributes = []cel.EnvOption{ - cel.Variable("content_id", cel.StringType), -} - -// UserFilterCELAttributes are the CEL attributes for user. -var UserFilterCELAttributes = []cel.EnvOption{ - cel.Variable("username", cel.StringType), -} - -// AttachmentFilterCELAttributes are the CEL attributes for user. -var AttachmentFilterCELAttributes = []cel.EnvOption{ - cel.Variable("memo_id", cel.StringType), -} - -// Parse parses the filter string and returns the parsed expression. -// The filter string should be a CEL expression. -func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) { - e, err := cel.NewEnv(opts...) - if err != nil { - return nil, errors.Wrap(err, "failed to create CEL environment") - } - ast, issues := e.Compile(filter) - if issues != nil { - return nil, errors.Errorf("failed to compile filter: %v", issues) - } - return cel.AstToParsedExpr(ast) -} diff --git a/plugin/filter/helpers.go b/plugin/filter/helpers.go new file mode 100644 index 000000000..1d3c3fe9d --- /dev/null +++ b/plugin/filter/helpers.go @@ -0,0 +1,25 @@ +package filter + +import ( + "context" + "fmt" +) + +// AppendConditions compiles the provided filters and appends the resulting SQL fragments and args. +func AppendConditions(ctx context.Context, engine *Engine, filters []string, dialect DialectName, where *[]string, args *[]any) error { + for _, filterStr := range filters { + stmt, err := engine.CompileToStatement(ctx, filterStr, RenderOptions{ + Dialect: dialect, + PlaceholderOffset: len(*args), + }) + if err != nil { + return err + } + if stmt.SQL == "" { + continue + } + *where = append(*where, fmt.Sprintf("(%s)", stmt.SQL)) + *args = append(*args, stmt.Args...) + } + return nil +} diff --git a/plugin/filter/ir.go b/plugin/filter/ir.go new file mode 100644 index 000000000..cfdefc9d4 --- /dev/null +++ b/plugin/filter/ir.go @@ -0,0 +1,116 @@ +package filter + +// Condition represents a boolean expression derived from the CEL filter. +type Condition interface { + isCondition() +} + +// LogicalOperator enumerates the supported logical operators. +type LogicalOperator string + +const ( + LogicalAnd LogicalOperator = "AND" + LogicalOr LogicalOperator = "OR" +) + +// LogicalCondition composes two conditions with a logical operator. +type LogicalCondition struct { + Operator LogicalOperator + Left Condition + Right Condition +} + +func (*LogicalCondition) isCondition() {} + +// NotCondition negates a child condition. +type NotCondition struct { + Expr Condition +} + +func (*NotCondition) isCondition() {} + +// FieldPredicateCondition asserts that a field evaluates to true. +type FieldPredicateCondition struct { + Field string +} + +func (*FieldPredicateCondition) isCondition() {} + +// ComparisonOperator lists supported comparison operators. +type ComparisonOperator string + +const ( + CompareEq ComparisonOperator = "=" + CompareNeq ComparisonOperator = "!=" + CompareLt ComparisonOperator = "<" + CompareLte ComparisonOperator = "<=" + CompareGt ComparisonOperator = ">" + CompareGte ComparisonOperator = ">=" +) + +// ComparisonCondition represents a binary comparison. +type ComparisonCondition struct { + Left ValueExpr + Operator ComparisonOperator + Right ValueExpr +} + +func (*ComparisonCondition) isCondition() {} + +// InCondition represents an IN predicate with literal list values. +type InCondition struct { + Left ValueExpr + Values []ValueExpr +} + +func (*InCondition) isCondition() {} + +// ElementInCondition represents the CEL syntax `"value" in field`. +type ElementInCondition struct { + Element ValueExpr + Field string +} + +func (*ElementInCondition) isCondition() {} + +// ContainsCondition models the .contains() call. +type ContainsCondition struct { + Field string + Value string +} + +func (*ContainsCondition) isCondition() {} + +// ConstantCondition captures a literal boolean outcome. +type ConstantCondition struct { + Value bool +} + +func (*ConstantCondition) isCondition() {} + +// ValueExpr models arithmetic or scalar expressions whose result feeds a comparison. +type ValueExpr interface { + isValueExpr() +} + +// FieldRef references a named schema field. +type FieldRef struct { + Name string +} + +func (*FieldRef) isValueExpr() {} + +// LiteralValue holds a literal scalar. +type LiteralValue struct { + Value interface{} +} + +func (*LiteralValue) isValueExpr() {} + +// FunctionValue captures simple function calls like size(tags). +type FunctionValue struct { + Name string + Args []ValueExpr +} + +func (*FunctionValue) isValueExpr() {} diff --git a/plugin/filter/parser.go b/plugin/filter/parser.go new file mode 100644 index 000000000..5af924c09 --- /dev/null +++ b/plugin/filter/parser.go @@ -0,0 +1,413 @@ +package filter + +import ( + "time" + + "github.com/pkg/errors" + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) { + switch v := expr.ExprKind.(type) { + case *exprv1.Expr_CallExpr: + return buildCallCondition(v.CallExpr, schema) + case *exprv1.Expr_ConstExpr: + val, err := getConstValue(expr) + if err != nil { + return nil, err + } + switch v := val.(type) { + case bool: + return &ConstantCondition{Value: v}, nil + case int64: + return &ConstantCondition{Value: v != 0}, nil + case float64: + return &ConstantCondition{Value: v != 0}, nil + default: + return nil, errors.New("filter must evaluate to a boolean value") + } + case *exprv1.Expr_IdentExpr: + name := v.IdentExpr.GetName() + field, ok := schema.Field(name) + if !ok { + return nil, errors.Errorf("unknown identifier %q", name) + } + if field.Type != FieldTypeBool { + return nil, errors.Errorf("identifier %q is not boolean", name) + } + return &FieldPredicateCondition{Field: name}, nil + default: + return nil, errors.New("unsupported top-level expression") + } +} + +func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) { + switch call.Function { + case "_&&_": + if len(call.Args) != 2 { + return nil, errors.New("logical AND expects two arguments") + } + left, err := buildCondition(call.Args[0], schema) + if err != nil { + return nil, err + } + right, err := buildCondition(call.Args[1], schema) + if err != nil { + return nil, err + } + return &LogicalCondition{ + Operator: LogicalAnd, + Left: left, + Right: right, + }, nil + case "_||_": + if len(call.Args) != 2 { + return nil, errors.New("logical OR expects two arguments") + } + left, err := buildCondition(call.Args[0], schema) + if err != nil { + return nil, err + } + right, err := buildCondition(call.Args[1], schema) + if err != nil { + return nil, err + } + return &LogicalCondition{ + Operator: LogicalOr, + Left: left, + Right: right, + }, nil + case "!_": + if len(call.Args) != 1 { + return nil, errors.New("logical NOT expects one argument") + } + child, err := buildCondition(call.Args[0], schema) + if err != nil { + return nil, err + } + return &NotCondition{Expr: child}, nil + case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": + return buildComparisonCondition(call, schema) + case "@in": + return buildInCondition(call, schema) + case "contains": + return buildContainsCondition(call, schema) + default: + val, ok, err := evaluateBool(call) + if err != nil { + return nil, err + } + if ok { + return &ConstantCondition{Value: val}, nil + } + return nil, errors.Errorf("unsupported call expression %q", call.Function) + } +} + +func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) { + if len(call.Args) != 2 { + return nil, errors.New("comparison expects two arguments") + } + op, err := toComparisonOperator(call.Function) + if err != nil { + return nil, err + } + + left, err := buildValueExpr(call.Args[0], schema) + if err != nil { + return nil, err + } + right, err := buildValueExpr(call.Args[1], schema) + if err != nil { + return nil, err + } + + // If the left side is a field, validate allowed operators. + if field, ok := left.(*FieldRef); ok { + def, exists := schema.Field(field.Name) + if !exists { + return nil, errors.Errorf("unknown identifier %q", field.Name) + } + if def.Kind == FieldKindVirtualAlias { + def, exists = schema.ResolveAlias(field.Name) + if !exists { + return nil, errors.Errorf("invalid alias %q", field.Name) + } + } + if def.AllowedComparisonOps != nil { + if _, allowed := def.AllowedComparisonOps[op]; !allowed { + return nil, errors.Errorf("operator %s not allowed for field %q", op, field.Name) + } + } + } + + return &ComparisonCondition{ + Left: left, + Operator: op, + Right: right, + }, nil +} + +func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) { + if len(call.Args) != 2 { + return nil, errors.New("in operator expects two arguments") + } + + // Handle identifier in list syntax. + if identName, err := getIdentName(call.Args[0]); err == nil { + if field, ok := schema.Field(identName); ok && field.Kind == FieldKindVirtualAlias { + if _, aliasOk := schema.ResolveAlias(identName); !aliasOk { + return nil, errors.Errorf("invalid alias %q", identName) + } + } else if !ok { + return nil, errors.Errorf("unknown identifier %q", identName) + } + + if listExpr := call.Args[1].GetListExpr(); listExpr != nil { + values := make([]ValueExpr, 0, len(listExpr.Elements)) + for _, element := range listExpr.Elements { + value, err := buildValueExpr(element, schema) + if err != nil { + return nil, err + } + values = append(values, value) + } + return &InCondition{ + Left: &FieldRef{Name: identName}, + Values: values, + }, nil + } + } + + // Handle "value in identifier" syntax. + if identName, err := getIdentName(call.Args[1]); err == nil { + if _, ok := schema.Field(identName); !ok { + return nil, errors.Errorf("unknown identifier %q", identName) + } + element, err := buildValueExpr(call.Args[0], schema) + if err != nil { + return nil, err + } + return &ElementInCondition{ + Element: element, + Field: identName, + }, nil + } + + return nil, errors.New("invalid use of in operator") +} + +func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) { + if call.Target == nil { + return nil, errors.New("contains requires a target") + } + targetName, err := getIdentName(call.Target) + if err != nil { + return nil, err + } + + field, ok := schema.Field(targetName) + if !ok { + return nil, errors.Errorf("unknown identifier %q", targetName) + } + if !field.SupportsContains { + return nil, errors.Errorf("identifier %q does not support contains()", targetName) + } + if len(call.Args) != 1 { + return nil, errors.New("contains expects exactly one argument") + } + value, err := getConstValue(call.Args[0]) + if err != nil { + return nil, errors.Wrap(err, "contains only supports literal arguments") + } + str, ok := value.(string) + if !ok { + return nil, errors.New("contains argument must be a string") + } + return &ContainsCondition{ + Field: targetName, + Value: str, + }, nil +} + +func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) { + if identName, err := getIdentName(expr); err == nil { + if _, ok := schema.Field(identName); !ok { + return nil, errors.Errorf("unknown identifier %q", identName) + } + return &FieldRef{Name: identName}, nil + } + + if literal, err := getConstValue(expr); err == nil { + return &LiteralValue{Value: literal}, nil + } + + if value, ok, err := evaluateNumeric(expr); err != nil { + return nil, err + } else if ok { + return &LiteralValue{Value: value}, nil + } + + if boolVal, ok, err := evaluateBoolExpr(expr); err != nil { + return nil, err + } else if ok { + return &LiteralValue{Value: boolVal}, nil + } + + if call := expr.GetCallExpr(); call != nil { + switch call.Function { + case "size": + if len(call.Args) != 1 { + return nil, errors.New("size() expects one argument") + } + arg, err := buildValueExpr(call.Args[0], schema) + if err != nil { + return nil, err + } + return &FunctionValue{ + Name: "size", + Args: []ValueExpr{arg}, + }, nil + case "now": + return &LiteralValue{Value: timeNowUnix()}, nil + case "_+_", "_-_", "_*_": + value, ok, err := evaluateNumeric(expr) + if err != nil { + return nil, err + } + if ok { + return &LiteralValue{Value: value}, nil + } + } + } + + return nil, errors.New("unsupported value expression") +} + +func toComparisonOperator(fn string) (ComparisonOperator, error) { + switch fn { + case "_==_": + return CompareEq, nil + case "_!=_": + return CompareNeq, nil + case "_<_": + return CompareLt, nil + case "_>_": + return CompareGt, nil + case "_<=_": + return CompareLte, nil + case "_>=_": + return CompareGte, nil + default: + return "", errors.Errorf("unsupported comparison operator %q", fn) + } +} + +func getIdentName(expr *exprv1.Expr) (string, error) { + if ident := expr.GetIdentExpr(); ident != nil { + return ident.GetName(), nil + } + return "", errors.New("expression is not an identifier") +} + +func getConstValue(expr *exprv1.Expr) (interface{}, error) { + v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr) + if !ok { + return nil, errors.New("expression is not a literal") + } + switch x := v.ConstExpr.ConstantKind.(type) { + case *exprv1.Constant_StringValue: + return v.ConstExpr.GetStringValue(), nil + case *exprv1.Constant_Int64Value: + return v.ConstExpr.GetInt64Value(), nil + case *exprv1.Constant_Uint64Value: + return int64(v.ConstExpr.GetUint64Value()), nil + case *exprv1.Constant_DoubleValue: + return v.ConstExpr.GetDoubleValue(), nil + case *exprv1.Constant_BoolValue: + return v.ConstExpr.GetBoolValue(), nil + case *exprv1.Constant_NullValue: + return nil, nil + default: + return nil, errors.Errorf("unsupported constant %T", x) + } +} + +func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) { + val, ok, err := evaluateBoolExpr(&exprv1.Expr{ExprKind: &exprv1.Expr_CallExpr{CallExpr: call}}) + return val, ok, err +} + +func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) { + if literal, err := getConstValue(expr); err == nil { + if b, ok := literal.(bool); ok { + return b, true, nil + } + return false, false, nil + } + if call := expr.GetCallExpr(); call != nil && call.Function == "!_" { + if len(call.Args) != 1 { + return false, false, errors.New("NOT expects exactly one argument") + } + val, ok, err := evaluateBoolExpr(call.Args[0]) + if err != nil || !ok { + return false, false, err + } + return !val, true, nil + } + return false, false, nil +} + +func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) { + if literal, err := getConstValue(expr); err == nil { + switch v := literal.(type) { + case int64: + return v, true, nil + case float64: + return int64(v), true, nil + } + return 0, false, nil + } + + call := expr.GetCallExpr() + if call == nil { + return 0, false, nil + } + + switch call.Function { + case "now": + return timeNowUnix(), true, nil + case "_+_", "_-_", "_*_": + if len(call.Args) != 2 { + return 0, false, errors.New("arithmetic requires two arguments") + } + left, ok, err := evaluateNumeric(call.Args[0]) + if err != nil { + return 0, false, err + } + if !ok { + return 0, false, nil + } + right, ok, err := evaluateNumeric(call.Args[1]) + if err != nil { + return 0, false, err + } + if !ok { + return 0, false, nil + } + switch call.Function { + case "_+_": + return left + right, true, nil + case "_-_": + return left - right, true, nil + case "_*_": + return left * right, true, nil + } + } + + return 0, false, nil +} + +func timeNowUnix() int64 { + return time.Now().Unix() +} diff --git a/plugin/filter/render.go b/plugin/filter/render.go new file mode 100644 index 000000000..6adf68fad --- /dev/null +++ b/plugin/filter/render.go @@ -0,0 +1,626 @@ +package filter + +import ( + "fmt" + "strings" + + "github.com/pkg/errors" +) + +type renderer struct { + schema Schema + dialect DialectName + placeholderOffset int + placeholderCounter int + args []any +} + +type renderResult struct { + sql string + trivial bool + unsatisfiable bool +} + +func newRenderer(schema Schema, opts RenderOptions) *renderer { + return &renderer{ + schema: schema, + dialect: opts.Dialect, + placeholderOffset: opts.PlaceholderOffset, + } +} + +func (r *renderer) Render(cond Condition) (Statement, error) { + result, err := r.renderCondition(cond) + if err != nil { + return Statement{}, err + } + args := r.args + if args == nil { + args = []any{} + } + + switch { + case result.unsatisfiable: + return Statement{ + SQL: "1 = 0", + Args: args, + }, nil + case result.trivial: + return Statement{ + SQL: "", + Args: args, + }, nil + default: + return Statement{ + SQL: result.sql, + Args: args, + }, nil + } +} + +func (r *renderer) renderCondition(cond Condition) (renderResult, error) { + switch c := cond.(type) { + case *LogicalCondition: + return r.renderLogicalCondition(c) + case *NotCondition: + return r.renderNotCondition(c) + case *FieldPredicateCondition: + return r.renderFieldPredicate(c) + case *ComparisonCondition: + return r.renderComparison(c) + case *InCondition: + return r.renderInCondition(c) + case *ElementInCondition: + return r.renderElementInCondition(c) + case *ContainsCondition: + return r.renderContainsCondition(c) + case *ConstantCondition: + if c.Value { + return renderResult{trivial: true}, nil + } + return renderResult{sql: "1 = 0", unsatisfiable: true}, nil + default: + return renderResult{}, errors.Errorf("unsupported condition type %T", c) + } +} + +func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) { + left, err := r.renderCondition(cond.Left) + if err != nil { + return renderResult{}, err + } + right, err := r.renderCondition(cond.Right) + if err != nil { + return renderResult{}, err + } + + switch cond.Operator { + case LogicalAnd: + return combineAnd(left, right), nil + case LogicalOr: + return combineOr(left, right), nil + default: + return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator) + } +} + +func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) { + child, err := r.renderCondition(cond.Expr) + if err != nil { + return renderResult{}, err + } + + if child.trivial { + return renderResult{sql: "1 = 0", unsatisfiable: true}, nil + } + if child.unsatisfiable { + return renderResult{trivial: true}, nil + } + return renderResult{ + sql: fmt.Sprintf("NOT (%s)", child.sql), + }, nil +} + +func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) { + field, ok := r.schema.Field(cond.Field) + if !ok { + return renderResult{}, errors.Errorf("unknown field %q", cond.Field) + } + + switch field.Kind { + case FieldKindBoolColumn: + column := qualifyColumn(r.dialect, field.Column) + return renderResult{ + sql: fmt.Sprintf("%s IS TRUE", column), + }, nil + case FieldKindJSONBool: + sql, err := r.jsonBoolPredicate(field) + if err != nil { + return renderResult{}, err + } + return renderResult{sql: sql}, nil + default: + return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field) + } +} + +func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) { + switch left := cond.Left.(type) { + case *FieldRef: + field, ok := r.schema.Field(left.Name) + if !ok { + return renderResult{}, errors.Errorf("unknown field %q", left.Name) + } + switch field.Kind { + case FieldKindBoolColumn: + return r.renderBoolColumnComparison(field, cond.Operator, cond.Right) + case FieldKindJSONBool: + return r.renderJSONBoolComparison(field, cond.Operator, cond.Right) + case FieldKindScalar: + return r.renderScalarComparison(field, cond.Operator, cond.Right) + default: + return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name) + } + case *FunctionValue: + return r.renderFunctionComparison(left, cond.Operator, cond.Right) + default: + return renderResult{}, errors.New("comparison must start with a field reference or supported function") + } +} + +func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) { + if fn.Name != "size" { + return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name) + } + if len(fn.Args) != 1 { + return renderResult{}, errors.New("size() expects one argument") + } + fieldArg, ok := fn.Args[0].(*FieldRef) + if !ok { + return renderResult{}, errors.New("size() argument must be a field") + } + + field, ok := r.schema.Field(fieldArg.Name) + if !ok { + return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name) + } + if field.Kind != FieldKindJSONList { + return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name) + } + + value, err := expectNumericLiteral(right) + if err != nil { + return renderResult{}, err + } + + expr := jsonArrayLengthExpr(r.dialect, field) + placeholder := r.addArg(value) + return renderResult{ + sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder), + }, nil +} + +func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) { + lit, err := expectLiteral(right) + if err != nil { + return renderResult{}, err + } + + columnExpr := field.columnExpr(r.dialect) + placeholder := "" + switch field.Type { + case FieldTypeString: + value, ok := lit.(string) + if !ok { + return renderResult{}, errors.Errorf("field %q expects string value", field.Name) + } + placeholder = r.addArg(value) + case FieldTypeInt, FieldTypeTimestamp: + num, err := toInt64(lit) + if err != nil { + return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name) + } + placeholder = r.addArg(num) + default: + return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name) + } + + return renderResult{ + sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder), + }, nil +} + +func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) { + value, err := expectBool(right) + if err != nil { + return renderResult{}, err + } + placeholder := r.addBoolArg(value) + column := qualifyColumn(r.dialect, field.Column) + return renderResult{ + sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder), + }, nil +} + +func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) { + value, err := expectBool(right) + if err != nil { + return renderResult{}, err + } + + jsonExpr := jsonExtractExpr(r.dialect, field) + switch r.dialect { + case DialectSQLite: + switch op { + case CompareEq: + if field.Name == "has_task_list" { + target := "0" + if value { + target = "1" + } + return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil + } + if value { + return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil + } + return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil + case CompareNeq: + if field.Name == "has_task_list" { + target := "0" + if value { + target = "1" + } + return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil + } + if value { + return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil + } + return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil + default: + return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op) + } + case DialectMySQL: + boolStr := "false" + if value { + boolStr = "true" + } + return renderResult{ + sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr), + }, nil + case DialectPostgres: + placeholder := r.addArg(value) + return renderResult{ + sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder), + }, nil + default: + return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect) + } +} + +func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) { + fieldRef, ok := cond.Left.(*FieldRef) + if !ok { + return renderResult{}, errors.New("IN operator requires a field on the left-hand side") + } + + if fieldRef.Name == "tag" { + return r.renderTagInList(cond.Values) + } + + field, ok := r.schema.Field(fieldRef.Name) + if !ok { + return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name) + } + + if field.Kind != FieldKindScalar { + return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name) + } + + return r.renderScalarInCondition(field, cond.Values) +} + +func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) { + field, ok := r.schema.ResolveAlias("tag") + if !ok { + return renderResult{}, errors.New("tag attribute is not configured") + } + + conditions := make([]string, 0, len(values)) + for _, v := range values { + lit, err := expectLiteral(v) + if err != nil { + return renderResult{}, err + } + str, ok := lit.(string) + if !ok { + return renderResult{}, errors.New("tags must be compared with string literals") + } + + switch r.dialect { + case DialectSQLite: + expr := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str))) + conditions = append(conditions, expr) + case DialectMySQL: + expr := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str))) + conditions = append(conditions, expr) + case DialectPostgres: + expr := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str))) + conditions = append(conditions, expr) + default: + return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect) + } + } + + if len(conditions) == 1 { + return renderResult{sql: conditions[0]}, nil + } + return renderResult{ + sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), + }, nil +} + +func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) { + field, ok := r.schema.Field(cond.Field) + if !ok { + return renderResult{}, errors.Errorf("unknown field %q", cond.Field) + } + if field.Kind != FieldKindJSONList { + return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field) + } + + lit, err := expectLiteral(cond.Element) + if err != nil { + return renderResult{}, err + } + str, ok := lit.(string) + if !ok { + return renderResult{}, errors.New("tags membership requires string literal") + } + + switch r.dialect { + case DialectSQLite: + sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str))) + return renderResult{sql: sql}, nil + case DialectMySQL: + sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(str)) + return renderResult{sql: sql}, nil + case DialectPostgres: + sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(str)) + return renderResult{sql: sql}, nil + default: + return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect) + } +} + +func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) { + placeholders := make([]string, 0, len(values)) + + for _, v := range values { + lit, err := expectLiteral(v) + if err != nil { + return renderResult{}, err + } + switch field.Type { + case FieldTypeString: + str, ok := lit.(string) + if !ok { + return renderResult{}, errors.Errorf("field %q expects string values", field.Name) + } + placeholders = append(placeholders, r.addArg(str)) + case FieldTypeInt: + num, err := toInt64(lit) + if err != nil { + return renderResult{}, err + } + placeholders = append(placeholders, r.addArg(num)) + default: + return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name) + } + } + + column := field.columnExpr(r.dialect) + return renderResult{ + sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), + }, nil +} + +func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) { + field, ok := r.schema.Field(cond.Field) + if !ok { + return renderResult{}, errors.Errorf("unknown field %q", cond.Field) + } + column := field.columnExpr(r.dialect) + arg := fmt.Sprintf("%%%s%%", cond.Value) + switch r.dialect { + case DialectPostgres: + sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg)) + return renderResult{sql: sql}, nil + default: + sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg)) + return renderResult{sql: sql}, nil + } +} + +func (r *renderer) jsonBoolPredicate(field Field) (string, error) { + expr := jsonExtractExpr(r.dialect, field) + switch r.dialect { + case DialectSQLite: + return fmt.Sprintf("%s IS TRUE", expr), nil + case DialectMySQL: + return fmt.Sprintf("%s = CAST('true' AS JSON)", expr), nil + case DialectPostgres: + return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil + default: + return "", errors.Errorf("unsupported dialect %s", r.dialect) + } +} + +func combineAnd(left, right renderResult) renderResult { + if left.unsatisfiable || right.unsatisfiable { + return renderResult{sql: "1 = 0", unsatisfiable: true} + } + if left.trivial { + return right + } + if right.trivial { + return left + } + return renderResult{ + sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql), + } +} + +func combineOr(left, right renderResult) renderResult { + if left.trivial || right.trivial { + return renderResult{trivial: true} + } + if left.unsatisfiable { + return right + } + if right.unsatisfiable { + return left + } + return renderResult{ + sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql), + } +} + +func (r *renderer) addArg(value any) string { + r.placeholderCounter++ + r.args = append(r.args, value) + if r.dialect == DialectPostgres { + return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter) + } + return "?" +} + +func (r *renderer) addBoolArg(value bool) string { + var v any + switch r.dialect { + case DialectSQLite: + if value { + v = 1 + } else { + v = 0 + } + default: + v = value + } + return r.addArg(v) +} + +func expectLiteral(expr ValueExpr) (any, error) { + lit, ok := expr.(*LiteralValue) + if !ok { + return nil, errors.New("expression must be a literal") + } + return lit.Value, nil +} + +func expectBool(expr ValueExpr) (bool, error) { + lit, err := expectLiteral(expr) + if err != nil { + return false, err + } + value, ok := lit.(bool) + if !ok { + return false, errors.New("boolean literal required") + } + return value, nil +} + +func expectNumericLiteral(expr ValueExpr) (int64, error) { + lit, err := expectLiteral(expr) + if err != nil { + return 0, err + } + return toInt64(lit) +} + +func toInt64(value any) (int64, error) { + switch v := value.(type) { + case int: + return int64(v), nil + case int32: + return int64(v), nil + case int64: + return v, nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + case float32: + return int64(v), nil + case float64: + return int64(v), nil + default: + return 0, errors.Errorf("cannot convert %T to int64", value) + } +} + +func sqlOperator(op ComparisonOperator) string { + return string(op) +} + +func qualifyColumn(d DialectName, col Column) string { + switch d { + case DialectPostgres: + return fmt.Sprintf("%s.%s", col.Table, col.Name) + default: + return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name) + } +} + +func jsonPath(field Field) string { + return "$." + strings.Join(field.JSONPath, ".") +} + +func jsonExtractExpr(d DialectName, field Field) string { + column := qualifyColumn(d, field.Column) + switch d { + case DialectSQLite, DialectMySQL: + return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field)) + case DialectPostgres: + return buildPostgresJSONAccessor(column, field.JSONPath, true) + default: + return "" + } +} + +func jsonArrayExpr(d DialectName, field Field) string { + column := qualifyColumn(d, field.Column) + switch d { + case DialectSQLite, DialectMySQL: + return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field)) + case DialectPostgres: + return buildPostgresJSONAccessor(column, field.JSONPath, false) + default: + return "" + } +} + +func jsonArrayLengthExpr(d DialectName, field Field) string { + arrayExpr := jsonArrayExpr(d, field) + switch d { + case DialectSQLite: + return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr) + case DialectMySQL: + return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr) + case DialectPostgres: + return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr) + default: + return "" + } +} + +func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string { + expr := base + for idx, part := range path { + if idx == len(path)-1 && terminalText { + expr = fmt.Sprintf("%s->>'%s'", expr, part) + } else { + expr = fmt.Sprintf("%s->'%s'", expr, part) + } + } + return expr +} diff --git a/plugin/filter/schema.go b/plugin/filter/schema.go new file mode 100644 index 000000000..4d5e3b4dc --- /dev/null +++ b/plugin/filter/schema.go @@ -0,0 +1,254 @@ +package filter + +import ( + "fmt" + "time" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// DialectName enumerates supported SQL dialects. +type DialectName string + +const ( + DialectSQLite DialectName = "sqlite" + DialectMySQL DialectName = "mysql" + DialectPostgres DialectName = "postgres" +) + +// FieldType represents the logical type of a field. +type FieldType string + +const ( + FieldTypeString FieldType = "string" + FieldTypeInt FieldType = "int" + FieldTypeBool FieldType = "bool" + FieldTypeTimestamp FieldType = "timestamp" +) + +// FieldKind describes how a field is stored. +type FieldKind string + +const ( + FieldKindScalar FieldKind = "scalar" + FieldKindBoolColumn FieldKind = "bool_column" + FieldKindJSONBool FieldKind = "json_bool" + FieldKindJSONList FieldKind = "json_list" + FieldKindVirtualAlias FieldKind = "virtual_alias" +) + +// Column identifies the backing table column. +type Column struct { + Table string + Name string +} + +// Field captures the schema metadata for an exposed CEL identifier. +type Field struct { + Name string + Kind FieldKind + Type FieldType + Column Column + JSONPath []string + AliasFor string + SupportsContains bool + Expressions map[DialectName]string + AllowedComparisonOps map[ComparisonOperator]bool +} + +// Schema collects CEL environment options and field metadata. +type Schema struct { + Name string + Fields map[string]Field + EnvOptions []cel.EnvOption +} + +// Field returns the field metadata if present. +func (s Schema) Field(name string) (Field, bool) { + f, ok := s.Fields[name] + return f, ok +} + +// ResolveAlias resolves a virtual alias to its target field. +func (s Schema) ResolveAlias(name string) (Field, bool) { + field, ok := s.Fields[name] + if !ok { + return Field{}, false + } + if field.Kind == FieldKindVirtualAlias { + target, ok := s.Fields[field.AliasFor] + if !ok { + return Field{}, false + } + return target, true + } + return field, true +} + +var nowFunction = cel.Function("now", + cel.Overload("now", + []*cel.Type{}, + cel.IntType, + cel.FunctionBinding(func(_ ...ref.Val) ref.Val { + return types.Int(time.Now().Unix()) + }), + ), +) + +// NewSchema constructs the memo filter schema and CEL environment. +func NewSchema() Schema { + fields := map[string]Field{ + "content": { + Name: "content", + Kind: FieldKindScalar, + Type: FieldTypeString, + Column: Column{Table: "memo", Name: "content"}, + SupportsContains: true, + Expressions: map[DialectName]string{}, + }, + "creator_id": { + Name: "creator_id", + Kind: FieldKindScalar, + Type: FieldTypeInt, + Column: Column{Table: "memo", Name: "creator_id"}, + Expressions: map[DialectName]string{}, + AllowedComparisonOps: map[ComparisonOperator]bool{ + CompareEq: true, + CompareNeq: true, + }, + }, + "created_ts": { + Name: "created_ts", + Kind: FieldKindScalar, + Type: FieldTypeTimestamp, + Column: Column{Table: "memo", Name: "created_ts"}, + Expressions: map[DialectName]string{ + DialectMySQL: "UNIX_TIMESTAMP(%s)", + DialectPostgres: "EXTRACT(EPOCH FROM TO_TIMESTAMP(%s))", + }, + }, + "updated_ts": { + Name: "updated_ts", + Kind: FieldKindScalar, + Type: FieldTypeTimestamp, + Column: Column{Table: "memo", Name: "updated_ts"}, + Expressions: map[DialectName]string{ + DialectMySQL: "UNIX_TIMESTAMP(%s)", + DialectPostgres: "EXTRACT(EPOCH FROM TO_TIMESTAMP(%s))", + }, + }, + "pinned": { + Name: "pinned", + Kind: FieldKindBoolColumn, + Type: FieldTypeBool, + Column: Column{Table: "memo", Name: "pinned"}, + Expressions: map[DialectName]string{}, + AllowedComparisonOps: map[ComparisonOperator]bool{ + CompareEq: true, + CompareNeq: true, + }, + }, + "visibility": { + Name: "visibility", + Kind: FieldKindScalar, + Type: FieldTypeString, + Column: Column{Table: "memo", Name: "visibility"}, + Expressions: map[DialectName]string{}, + AllowedComparisonOps: map[ComparisonOperator]bool{ + CompareEq: true, + CompareNeq: true, + }, + }, + "tags": { + Name: "tags", + Kind: FieldKindJSONList, + Type: FieldTypeString, + Column: Column{Table: "memo", Name: "payload"}, + JSONPath: []string{"tags"}, + }, + "tag": { + Name: "tag", + Kind: FieldKindVirtualAlias, + Type: FieldTypeString, + AliasFor: "tags", + }, + "has_task_list": { + Name: "has_task_list", + Kind: FieldKindJSONBool, + Type: FieldTypeBool, + Column: Column{Table: "memo", Name: "payload"}, + JSONPath: []string{"property", "hasTaskList"}, + AllowedComparisonOps: map[ComparisonOperator]bool{ + CompareEq: true, + CompareNeq: true, + }, + }, + "has_link": { + Name: "has_link", + Kind: FieldKindJSONBool, + Type: FieldTypeBool, + Column: Column{Table: "memo", Name: "payload"}, + JSONPath: []string{"property", "hasLink"}, + AllowedComparisonOps: map[ComparisonOperator]bool{ + CompareEq: true, + CompareNeq: true, + }, + }, + "has_code": { + Name: "has_code", + Kind: FieldKindJSONBool, + Type: FieldTypeBool, + Column: Column{Table: "memo", Name: "payload"}, + JSONPath: []string{"property", "hasCode"}, + AllowedComparisonOps: map[ComparisonOperator]bool{ + CompareEq: true, + CompareNeq: true, + }, + }, + "has_incomplete_tasks": { + Name: "has_incomplete_tasks", + Kind: FieldKindJSONBool, + Type: FieldTypeBool, + Column: Column{Table: "memo", Name: "payload"}, + JSONPath: []string{"property", "hasIncompleteTasks"}, + AllowedComparisonOps: map[ComparisonOperator]bool{ + CompareEq: true, + CompareNeq: true, + }, + }, + } + + envOptions := []cel.EnvOption{ + cel.Variable("content", cel.StringType), + cel.Variable("creator_id", cel.IntType), + cel.Variable("created_ts", cel.IntType), + cel.Variable("updated_ts", cel.IntType), + cel.Variable("pinned", cel.BoolType), + cel.Variable("tag", cel.StringType), + cel.Variable("tags", cel.ListType(cel.StringType)), + cel.Variable("visibility", cel.StringType), + cel.Variable("has_task_list", cel.BoolType), + cel.Variable("has_link", cel.BoolType), + cel.Variable("has_code", cel.BoolType), + cel.Variable("has_incomplete_tasks", cel.BoolType), + nowFunction, + } + + return Schema{ + Name: "memo", + Fields: fields, + EnvOptions: envOptions, + } +} + +// columnExpr returns the field expression for the given dialect, applying +// any schema-specific overrides (e.g. UNIX timestamp conversions). +func (f Field) columnExpr(d DialectName) string { + base := qualifyColumn(d, f.Column) + if expr, ok := f.Expressions[d]; ok && expr != "" { + return fmt.Sprintf(expr, base) + } + return base +} diff --git a/plugin/filter/templates.go b/plugin/filter/templates.go deleted file mode 100644 index 73e1f1df3..000000000 --- a/plugin/filter/templates.go +++ /dev/null @@ -1,146 +0,0 @@ -package filter - -import ( - "fmt" -) - -// SQLTemplate holds database-specific SQL fragments. -type SQLTemplate struct { - SQLite string - MySQL string - PostgreSQL string -} - -// TemplateDBType represents the database type for templates. -type TemplateDBType string - -const ( - SQLiteTemplate TemplateDBType = "sqlite" - MySQLTemplate TemplateDBType = "mysql" - PostgreSQLTemplate TemplateDBType = "postgres" -) - -// SQLTemplates contains common SQL patterns for different databases. -var SQLTemplates = map[string]SQLTemplate{ - "json_extract": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')", - MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')", - PostgreSQL: "memo.payload%s", - }, - "json_array_length": { - SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))", - MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))", - PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))", - }, - "json_contains_element": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?", - MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)", - PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)", - }, - "json_contains_tag": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?", - MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)", - PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)", - }, - "boolean_true": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1", - MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)", - PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true", - }, - "boolean_false": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0", - MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)", - PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false", - }, - "boolean_not_true": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1", - MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)", - PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true", - }, - "boolean_not_false": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0", - MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)", - PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false", - }, - "boolean_compare": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?", - MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)", - PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?", - }, - "boolean_check": { - SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE", - MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)", - PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE", - }, - "table_prefix": { - SQLite: "`memo`", - MySQL: "`memo`", - PostgreSQL: "memo", - }, - "timestamp_field": { - SQLite: "`memo`.`%s`", - MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)", - PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)", - }, - "content_like": { - SQLite: "`memo`.`content` LIKE ?", - MySQL: "`memo`.`content` LIKE ?", - PostgreSQL: "memo.content ILIKE ?", - }, - "visibility_in": { - SQLite: "`memo`.`visibility` IN (%s)", - MySQL: "`memo`.`visibility` IN (%s)", - PostgreSQL: "memo.visibility IN (%s)", - }, -} - -// GetSQL returns the appropriate SQL for the given template and database type. -func GetSQL(templateName string, dbType TemplateDBType) string { - template, exists := SQLTemplates[templateName] - if !exists { - return "" - } - - switch dbType { - case SQLiteTemplate: - return template.SQLite - case MySQLTemplate: - return template.MySQL - case PostgreSQLTemplate: - return template.PostgreSQL - default: - return template.SQLite - } -} - -// GetParameterPlaceholder returns the appropriate parameter placeholder for the database. -func GetParameterPlaceholder(dbType TemplateDBType, index int) string { - switch dbType { - case PostgreSQLTemplate: - return fmt.Sprintf("$%d", index) - default: - return "?" - } -} - -// GetParameterValue returns the appropriate parameter value for the database. -func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} { - switch templateName { - case "json_contains_element", "json_contains_tag": - if dbType == SQLiteTemplate { - return fmt.Sprintf(`%%"%s"%%`, value) - } - return value - default: - return value - } -} - -// FormatPlaceholders formats a list of placeholders for the given database type. -func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string { - placeholders := make([]string, count) - for i := 0; i < count; i++ { - placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i) - } - return placeholders -} diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index aa111338f..797cfd6ed 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -200,20 +200,18 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq } reactionMap := make(map[string][]*store.Reaction) - memoNames := make([]string, 0, len(memos)) + contentIDs := make([]string, 0, len(memos)) attachmentMap := make(map[int32][]*store.Attachment) - memoIDs := make([]string, 0, len(memos)) + memoIDs := make([]int32, 0, len(memos)) for _, m := range memos { - memoNames = append(memoNames, fmt.Sprintf("'%s%s'", MemoNamePrefix, m.UID)) - memoIDs = append(memoIDs, fmt.Sprintf("'%d'", m.ID)) + contentIDs = append(contentIDs, fmt.Sprintf("%s%s", MemoNamePrefix, m.UID)) + memoIDs = append(memoIDs, m.ID) } // REACTIONS - reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ - Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNames, ", "))}, - }) + reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to list reactions") } @@ -222,9 +220,7 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq } // ATTACHMENTS - attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ - Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDs, ", "))}, - }) + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDs}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to list attachments") } @@ -630,30 +626,26 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM return response, nil } - memoRelationIDs := make([]string, 0, len(memoRelations)) + memoRelationIDs := make([]int32, 0, len(memoRelations)) for _, m := range memoRelations { - memoRelationIDs = append(memoRelationIDs, fmt.Sprintf("%d", m.MemoID)) + memoRelationIDs = append(memoRelationIDs, m.MemoID) } - memos, err := s.Store.ListMemos(ctx, &store.FindMemo{ - Filters: []string{fmt.Sprintf("id in [%s]", strings.Join(memoRelationIDs, ", "))}, - }) + memos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: memoRelationIDs}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to list memos") } memoIDToNameMap := make(map[int32]string) - memoNamesForQuery := make([]string, 0, len(memos)) - memoIDsForQuery := make([]string, 0, len(memos)) + contentIDs := make([]string, 0, len(memos)) + memoIDsForAttachments := make([]int32, 0, len(memos)) for _, memo := range memos { memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID) memoIDToNameMap[memo.ID] = memoName - memoNamesForQuery = append(memoNamesForQuery, fmt.Sprintf("'%s'", memoName)) - memoIDsForQuery = append(memoIDsForQuery, fmt.Sprintf("'%d'", memo.ID)) + contentIDs = append(contentIDs, memoName) + memoIDsForAttachments = append(memoIDsForAttachments, memo.ID) } - reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ - Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNamesForQuery, ", "))}, - }) + reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to list reactions") } @@ -663,9 +655,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM memoReactionsMap[reaction.ContentID] = append(memoReactionsMap[reaction.ContentID], reaction) } - attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ - Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDsForQuery, ", "))}, - }) + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDsForAttachments}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to list attachments") } diff --git a/server/router/api/v1/shortcut_service.go b/server/router/api/v1/shortcut_service.go index 46d1b3849..0d4467dd6 100644 --- a/server/router/api/v1/shortcut_service.go +++ b/server/router/api/v1/shortcut_service.go @@ -319,35 +319,30 @@ func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteS return &emptypb.Empty{}, nil } -func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error { +func (s *APIV1Service) validateFilter(ctx context.Context, filterStr string) error { if filterStr == "" { return errors.New("filter cannot be empty") } - // Validate the filter. - parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...) - if err != nil { - return errors.Wrap(err, "failed to parse filter") - } - convertCtx := filter.NewConvertContext() - // Determine the dialect based on the actual database driver - var dialect filter.SQLDialect + engine, err := filter.DefaultEngine() + if err != nil { + return err + } + + var dialect filter.DialectName switch s.Profile.Driver { case "sqlite": - dialect = &filter.SQLiteDialect{} + dialect = filter.DialectSQLite case "mysql": - dialect = &filter.MySQLDialect{} + dialect = filter.DialectMySQL case "postgres": - dialect = &filter.PostgreSQLDialect{} + dialect = filter.DialectPostgres default: - // Default to SQLite for unknown drivers - dialect = &filter.SQLiteDialect{} + dialect = filter.DialectSQLite } - converter := filter.NewCommonSQLConverter(dialect) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - if err != nil { - return errors.Wrap(err, "failed to convert filter to SQL") + if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil { + return errors.Wrap(err, "failed to compile filter") } return nil } diff --git a/server/router/api/v1/user_filter_test.go b/server/router/api/v1/user_filter_test.go deleted file mode 100644 index 94c01855f..000000000 --- a/server/router/api/v1/user_filter_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package v1 - -import ( - "testing" - - "github.com/usememos/memos/plugin/filter" -) - -func TestUserFilterValidation(t *testing.T) { - testCases := []struct { - name string - filter string - expectErr bool - }{ - { - name: "valid username filter with equals", - filter: `username == "testuser"`, - expectErr: false, - }, - { - name: "valid username filter with contains", - filter: `username.contains("admin")`, - expectErr: false, - }, - { - name: "invalid filter - unknown field", - filter: `invalid_field == "test"`, - expectErr: true, - }, - { - name: "empty filter", - filter: "", - expectErr: true, - }, - { - name: "invalid syntax", - filter: `username ==`, - expectErr: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test the filter parsing directly - _, err := filter.Parse(tc.filter, filter.UserFilterCELAttributes...) - - if tc.expectErr && err == nil { - t.Errorf("Expected error for filter %q, but got none", tc.filter) - } - if !tc.expectErr && err != nil { - t.Errorf("Expected no error for filter %q, but got: %v", tc.filter, err) - } - }) - } -} - -func TestUserFilterCELAttributes(t *testing.T) { - // Test that our UserFilterCELAttributes contains the username variable - expectedAttributes := map[string]bool{ - "username": true, - } - - // This is a basic test to ensure the attributes are defined - // In a real test, you would create a CEL environment and verify the attributes - for attrName := range expectedAttributes { - t.Logf("Expected attribute %s should be available in UserFilterCELAttributes", attrName) - } -} diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index b1f0a96ec..1da7191c1 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -25,7 +25,6 @@ import ( "github.com/usememos/memos/internal/base" "github.com/usememos/memos/internal/util" - "github.com/usememos/memos/plugin/filter" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" @@ -49,7 +48,6 @@ func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersReq if err := s.validateUserFilter(ctx, request.Filter); err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) } - userFind.Filters = append(userFind.Filters, request.Filter) } users, err := s.Store.ListUsers(ctx, userFind) @@ -1368,34 +1366,8 @@ func extractWebhookIDFromName(name string) string { // validateUserFilter validates the user filter string. func (s *APIV1Service) validateUserFilter(_ context.Context, filterStr string) error { - if filterStr == "" { - return errors.New("filter cannot be empty") - } - // Validate the filter. - parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...) - if err != nil { - return errors.Wrap(err, "failed to parse filter") - } - convertCtx := filter.NewConvertContext() - - // Determine the dialect based on the actual database driver - var dialect filter.SQLDialect - switch s.Profile.Driver { - case "sqlite": - dialect = &filter.SQLiteDialect{} - case "mysql": - dialect = &filter.MySQLDialect{} - case "postgres": - dialect = &filter.PostgreSQLDialect{} - default: - // Default to SQLite for unknown drivers - dialect = &filter.SQLiteDialect{} - } - - converter := filter.NewUserSQLConverter(dialect) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - if err != nil { - return errors.Wrap(err, "failed to convert filter to SQL") + if strings.TrimSpace(filterStr) != "" { + return errors.New("user filters are not supported") } return nil } diff --git a/store/attachment.go b/store/attachment.go index 5b57d8c31..7146f75ca 100644 --- a/store/attachment.go +++ b/store/attachment.go @@ -48,11 +48,11 @@ type FindAttachment struct { Filename *string FilenameSearch *string MemoID *int32 + MemoIDList []int32 HasRelatedMemo bool StorageType *storepb.AttachmentStorageType Limit *int Offset *int - Filters []string } type UpdateAttachment struct { diff --git a/store/db/mysql/attachment.go b/store/db/mysql/attachment.go index 525bc1573..6eb59338d 100644 --- a/store/db/mysql/attachment.go +++ b/store/db/mysql/attachment.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" - "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -49,26 +48,6 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } - } - if v := find.ID; v != nil { where, args = append(where, "`resource`.`id` = ?"), append(args, *v) } @@ -87,6 +66,16 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ if v := find.MemoID; v != nil { where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v) } + if len(find.MemoIDList) > 0 { + placeholders := make([]string, 0, len(find.MemoIDList)) + for range find.MemoIDList { + placeholders = append(placeholders, "?") + } + where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")") + for _, id := range find.MemoIDList { + args = append(args, id) + } + } if find.HasRelatedMemo { where = append(where, "`resource`.`memo_id` IS NOT NULL") } diff --git a/store/db/mysql/attachment_filter_test.go b/store/db/mysql/attachment_filter_test.go deleted file mode 100644 index ea43b8bb0..000000000 --- a/store/db/mysql/attachment_filter_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package mysql - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestAttachmentConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`, - want: "`resource`.`memo_id` IN (?)", - args: []any{"5atZAj8GcvkSuUA3X2KLaY"}, - }, - { - filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`, - want: "`resource`.`memo_id` IN (?,?)", - args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"}, - }, - } - - for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...) - require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) - } -} diff --git a/store/db/mysql/memo.go b/store/db/mysql/memo.go index 354aab545..5eea60b30 100644 --- a/store/db/mysql/memo.go +++ b/store/db/mysql/memo.go @@ -50,31 +50,39 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { where, having, args := []string{"1 = 1"}, []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } + engine, err := filter.DefaultEngine() + if err != nil { + return nil, err + } + if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectMySQL, &where, &args); err != nil { + return nil, err } if v := find.ID; v != nil { where, args = append(where, "`memo`.`id` = ?"), append(args, *v) } + if len(find.IDList) > 0 { + placeholders := make([]string, 0, len(find.IDList)) + for range find.IDList { + placeholders = append(placeholders, "?") + } + where = append(where, "`memo`.`id` IN ("+strings.Join(placeholders, ",")+")") + for _, id := range find.IDList { + args = append(args, id) + } + } if v := find.UID; v != nil { where, args = append(where, "`memo`.`uid` = ?"), append(args, *v) } + if len(find.UIDList) > 0 { + placeholders := make([]string, 0, len(find.UIDList)) + for range find.UIDList { + placeholders = append(placeholders, "?") + } + where = append(where, "`memo`.`uid` IN ("+strings.Join(placeholders, ",")+")") + for _, uid := range find.UIDList { + args = append(args, uid) + } + } if v := find.CreatorID; v != nil { where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v) } diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go index 1c87d0a82..52fc697b9 100644 --- a/store/db/mysql/memo_filter_test.go +++ b/store/db/mysql/memo_filter_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "testing" "time" @@ -147,14 +148,15 @@ func TestConvertExprToSQL(t *testing.T) { }, } + engine, err := filter.DefaultEngine() + require.NoError(t, err) + for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) + stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{ + Dialect: filter.DialectMySQL, + }) require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) + require.Equal(t, tt.want, stmt.SQL) + require.Equal(t, tt.args, stmt.Args) } } diff --git a/store/db/mysql/memo_relation.go b/store/db/mysql/memo_relation.go index a57bda8eb..3116903e0 100644 --- a/store/db/mysql/memo_relation.go +++ b/store/db/mysql/memo_relation.go @@ -43,23 +43,21 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation where, args = append(where, "`type` = ?"), append(args, find.Type) } if find.MemoFilter != nil { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...) + engine, err := filter.DefaultEngine() if err != nil { return nil, err } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{ + Dialect: filter.DialectMySQL, + PlaceholderOffset: 0, + }) + if err != nil { return nil, err } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition)) - where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition)) - args = append(args, append(convertCtx.Args, convertCtx.Args...)...) + if stmt.SQL != "" { + where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL)) + where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL)) + args = append(args, append(stmt.Args, stmt.Args...)...) } } diff --git a/store/db/mysql/reaction.go b/store/db/mysql/reaction.go index b2878b4e4..40c15ef77 100644 --- a/store/db/mysql/reaction.go +++ b/store/db/mysql/reaction.go @@ -2,12 +2,10 @@ package mysql import ( "context" - "fmt" "strings" "github.com/pkg/errors" - "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -37,27 +35,7 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store } func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) { - where, args := []string{"1 = 1"}, []interface{}{} - - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } - } + where, args := []string{"1 = 1"}, []any{} if find.ID != nil { where, args = append(where, "`id` = ?"), append(args, *find.ID) @@ -68,6 +46,14 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st if find.ContentID != nil { where, args = append(where, "`content_id` = ?"), append(args, *find.ContentID) } + if len(find.ContentIDList) > 0 { + placeholders := make([]string, 0, len(find.ContentIDList)) + for _, id := range find.ContentIDList { + placeholders = append(placeholders, "?") + args = append(args, id) + } + where = append(where, "`content_id` IN ("+strings.Join(placeholders, ",")+")") + } rows, err := d.db.QueryContext(ctx, ` SELECT diff --git a/store/db/mysql/reaction_filter_test.go b/store/db/mysql/reaction_filter_test.go deleted file mode 100644 index 1ea4621df..000000000 --- a/store/db/mysql/reaction_filter_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package mysql - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestReactionConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`, - want: "`reaction`.`content_id` IN (?)", - args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"}, - }, - { - filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`, - want: "`reaction`.`content_id` IN (?,?)", - args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"}, - }, - } - - for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...) - require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) - } -} diff --git a/store/db/mysql/user.go b/store/db/mysql/user.go index e2bddfd0d..4403e07ff 100644 --- a/store/db/mysql/user.go +++ b/store/db/mysql/user.go @@ -7,7 +7,6 @@ import ( "github.com/pkg/errors" - "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -85,24 +84,8 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewUserSQLConverter(&filter.MySQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } + if len(find.Filters) > 0 { + return nil, errors.Errorf("user filters are not supported") } if v := find.ID; v != nil { diff --git a/store/db/postgres/attachment.go b/store/db/postgres/attachment.go index 90311c92d..44362cc00 100644 --- a/store/db/postgres/attachment.go +++ b/store/db/postgres/attachment.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" - "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -40,26 +39,6 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } - } - if v := find.ID; v != nil { where, args = append(where, "resource.id = "+placeholder(len(args)+1)), append(args, *v) } @@ -78,6 +57,16 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ if v := find.MemoID; v != nil { where, args = append(where, "resource.memo_id = "+placeholder(len(args)+1)), append(args, *v) } + if len(find.MemoIDList) > 0 { + holders := make([]string, 0, len(find.MemoIDList)) + for range find.MemoIDList { + holders = append(holders, placeholder(len(args)+1)) + } + where = append(where, "resource.memo_id IN ("+strings.Join(holders, ", ")+")") + for _, id := range find.MemoIDList { + args = append(args, id) + } + } if find.HasRelatedMemo { where = append(where, "resource.memo_id IS NOT NULL") } diff --git a/store/db/postgres/attachment_filter_test.go b/store/db/postgres/attachment_filter_test.go deleted file mode 100644 index 788962d68..000000000 --- a/store/db/postgres/attachment_filter_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package postgres - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestAttachmentConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`, - want: "resource.memo_id IN ($1)", - args: []any{"5atZAj8GcvkSuUA3X2KLaY"}, - }, - { - filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`, - want: "resource.memo_id IN ($1,$2)", - args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"}, - }, - } - - for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...) - require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) - } -} diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go index 0400165a7..0d35d6af0 100644 --- a/store/db/postgres/memo.go +++ b/store/db/postgres/memo.go @@ -41,32 +41,39 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - convertCtx.ArgsOffset = len(args) - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args)) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } + engine, err := filter.DefaultEngine() + if err != nil { + return nil, err + } + if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectPostgres, &where, &args); err != nil { + return nil, err } if v := find.ID; v != nil { where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v) } + if len(find.IDList) > 0 { + holders := make([]string, 0, len(find.IDList)) + for range find.IDList { + holders = append(holders, placeholder(len(args)+1)) + } + where = append(where, "memo.id IN ("+strings.Join(holders, ", ")+")") + for _, id := range find.IDList { + args = append(args, id) + } + } if v := find.UID; v != nil { where, args = append(where, "memo.uid = "+placeholder(len(args)+1)), append(args, *v) } + if len(find.UIDList) > 0 { + holders := make([]string, 0, len(find.UIDList)) + for range find.UIDList { + holders = append(holders, placeholder(len(args)+1)) + } + where = append(where, "memo.uid IN ("+strings.Join(holders, ", ")+")") + for _, uid := range find.UIDList { + args = append(args, uid) + } + } if v := find.CreatorID; v != nil { where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v) } diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go index fe32ef7e7..d9fe2636e 100644 --- a/store/db/postgres/memo_filter_test.go +++ b/store/db/postgres/memo_filter_test.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "testing" "time" @@ -147,14 +148,13 @@ func TestConvertExprToSQL(t *testing.T) { }, } + engine, err := filter.DefaultEngine() + require.NoError(t, err) + for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) + stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectPostgres}) require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args)) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) + require.Equal(t, tt.want, stmt.SQL) + require.Equal(t, tt.args, stmt.Args) } } diff --git a/store/db/postgres/memo_relation.go b/store/db/postgres/memo_relation.go index 5cc1cbd07..881291b8a 100644 --- a/store/db/postgres/memo_relation.go +++ b/store/db/postgres/memo_relation.go @@ -49,24 +49,32 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type) } if find.MemoFilter != nil { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...) + engine, err := filter.DefaultEngine() if err != nil { return nil, err } - convertCtx := filter.NewConvertContext() - convertCtx.ArgsOffset = len(args) - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args)) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{ + Dialect: filter.DialectPostgres, + PlaceholderOffset: len(args), + }) + if err != nil { return nil, err } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition)) - where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition)) - args = append(args, convertCtx.Args...) + if stmt.SQL != "" { + where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL)) + args = append(args, stmt.Args...) + + stmtRelated, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{ + Dialect: filter.DialectPostgres, + PlaceholderOffset: len(args), + }) + if err != nil { + return nil, err + } + if stmtRelated.SQL != "" { + where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmtRelated.SQL)) + args = append(args, stmtRelated.Args...) + } } } diff --git a/store/db/postgres/reaction.go b/store/db/postgres/reaction.go index 4bfb9f7df..3ff6354cb 100644 --- a/store/db/postgres/reaction.go +++ b/store/db/postgres/reaction.go @@ -2,10 +2,8 @@ package postgres import ( "context" - "fmt" "strings" - "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -25,27 +23,7 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store } func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) { - where, args := []string{"1 = 1"}, []interface{}{} - - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } - } + where, args := []string{"1 = 1"}, []any{} if find.ID != nil { where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID) @@ -56,6 +34,18 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st if find.ContentID != nil { where, args = append(where, "content_id = "+placeholder(len(args)+1)), append(args, *find.ContentID) } + if len(find.ContentIDList) > 0 { + holders := make([]string, 0, len(find.ContentIDList)) + for range find.ContentIDList { + holders = append(holders, placeholder(len(args)+1)) + } + if len(holders) > 0 { + where = append(where, "content_id IN ("+strings.Join(holders, ", ")+")") + for _, id := range find.ContentIDList { + args = append(args, id) + } + } + } rows, err := d.db.QueryContext(ctx, ` SELECT diff --git a/store/db/postgres/reaction_filter_test.go b/store/db/postgres/reaction_filter_test.go deleted file mode 100644 index 05f801699..000000000 --- a/store/db/postgres/reaction_filter_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package postgres - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestReactionConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`, - want: "reaction.content_id IN ($1)", - args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"}, - }, - { - filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`, - want: "reaction.content_id IN ($1,$2)", - args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"}, - }, - } - - for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...) - require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) - } -} diff --git a/store/db/postgres/user.go b/store/db/postgres/user.go index b582cff24..0be4aa8b8 100644 --- a/store/db/postgres/user.go +++ b/store/db/postgres/user.go @@ -5,7 +5,8 @@ import ( "fmt" "strings" - "github.com/usememos/memos/plugin/filter" + "github.com/pkg/errors" + "github.com/usememos/memos/store" ) @@ -86,24 +87,8 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewUserSQLConverter(&filter.PostgreSQLDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } + if len(find.Filters) > 0 { + return nil, errors.Errorf("user filters are not supported") } if v := find.ID; v != nil { diff --git a/store/db/sqlite/attachment.go b/store/db/sqlite/attachment.go index 34aaac0b7..c35257338 100644 --- a/store/db/sqlite/attachment.go +++ b/store/db/sqlite/attachment.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" - "github.com/usememos/memos/plugin/filter" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -42,26 +41,6 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } - } - if v := find.ID; v != nil { where, args = append(where, "`resource`.`id` = ?"), append(args, *v) } @@ -80,6 +59,16 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ if v := find.MemoID; v != nil { where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v) } + if len(find.MemoIDList) > 0 { + placeholders := make([]string, 0, len(find.MemoIDList)) + for range find.MemoIDList { + placeholders = append(placeholders, "?") + } + where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")") + for _, id := range find.MemoIDList { + args = append(args, id) + } + } if find.HasRelatedMemo { where = append(where, "`resource`.`memo_id` IS NOT NULL") } diff --git a/store/db/sqlite/attachment_filter_test.go b/store/db/sqlite/attachment_filter_test.go deleted file mode 100644 index efe7b0c6f..000000000 --- a/store/db/sqlite/attachment_filter_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package sqlite - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestAttachmentConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`, - want: "`resource`.`memo_id` IN (?)", - args: []any{"5atZAj8GcvkSuUA3X2KLaY"}, - }, - { - filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`, - want: "`resource`.`memo_id` IN (?,?)", - args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"}, - }, - } - - for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...) - require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) - } -} diff --git a/store/db/sqlite/memo.go b/store/db/sqlite/memo.go index 74e0b87b6..ef2f103f3 100644 --- a/store/db/sqlite/memo.go +++ b/store/db/sqlite/memo.go @@ -42,31 +42,39 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } + engine, err := filter.DefaultEngine() + if err != nil { + return nil, err + } + if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectSQLite, &where, &args); err != nil { + return nil, err } if v := find.ID; v != nil { where, args = append(where, "`memo`.`id` = ?"), append(args, *v) } + if len(find.IDList) > 0 { + placeholders := make([]string, 0, len(find.IDList)) + for range find.IDList { + placeholders = append(placeholders, "?") + } + where = append(where, "`memo`.`id` IN ("+strings.Join(placeholders, ",")+")") + for _, id := range find.IDList { + args = append(args, id) + } + } if v := find.UID; v != nil { where, args = append(where, "`memo`.`uid` = ?"), append(args, *v) } + if len(find.UIDList) > 0 { + placeholders := make([]string, 0, len(find.UIDList)) + for range find.UIDList { + placeholders = append(placeholders, "?") + } + where = append(where, "`memo`.`uid` IN ("+strings.Join(placeholders, ",")+")") + for _, uid := range find.UIDList { + args = append(args, uid) + } + } if v := find.CreatorID; v != nil { where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v) } diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index 6c67daab7..cea5ab558 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "testing" "time" @@ -152,14 +153,13 @@ func TestConvertExprToSQL(t *testing.T) { }, } + engine, err := filter.DefaultEngine() + require.NoError(t, err) + for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) + stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectSQLite}) require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) + require.Equal(t, tt.want, stmt.SQL) + require.Equal(t, tt.args, stmt.Args) } } diff --git a/store/db/sqlite/memo_relation.go b/store/db/sqlite/memo_relation.go index 56182e7f4..3e63c7002 100644 --- a/store/db/sqlite/memo_relation.go +++ b/store/db/sqlite/memo_relation.go @@ -49,23 +49,18 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation where, args = append(where, "type = ?"), append(args, find.Type) } if find.MemoFilter != nil { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...) + engine, err := filter.DefaultEngine() if err != nil { return nil, err } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{Dialect: filter.DialectSQLite}) + if err != nil { return nil, err } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition)) - where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition)) - args = append(args, append(convertCtx.Args, convertCtx.Args...)...) + if stmt.SQL != "" { + where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL)) + where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL)) + args = append(args, append(stmt.Args, stmt.Args...)...) } } diff --git a/store/db/sqlite/reaction.go b/store/db/sqlite/reaction.go index 10f86bbd8..a6f87cdc5 100644 --- a/store/db/sqlite/reaction.go +++ b/store/db/sqlite/reaction.go @@ -2,10 +2,8 @@ package sqlite import ( "context" - "fmt" "strings" - "github.com/usememos/memos/plugin/filter" "github.com/usememos/memos/store" ) @@ -26,27 +24,7 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store } func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) { - where, args := []string{"1 = 1"}, []interface{}{} - - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } - } + where, args := []string{"1 = 1"}, []any{} if find.ID != nil { where, args = append(where, "id = ?"), append(args, *find.ID) @@ -57,6 +35,18 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st if find.ContentID != nil { where, args = append(where, "content_id = ?"), append(args, *find.ContentID) } + if len(find.ContentIDList) > 0 { + placeholders := make([]string, 0, len(find.ContentIDList)) + for range find.ContentIDList { + placeholders = append(placeholders, "?") + } + if len(placeholders) > 0 { + where = append(where, "content_id IN ("+strings.Join(placeholders, ",")+")") + for _, id := range find.ContentIDList { + args = append(args, id) + } + } + } rows, err := d.db.QueryContext(ctx, ` SELECT diff --git a/store/db/sqlite/reaction_filter_test.go b/store/db/sqlite/reaction_filter_test.go deleted file mode 100644 index d07f7cbbb..000000000 --- a/store/db/sqlite/reaction_filter_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package sqlite - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestReactionConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`, - want: "`reaction`.`content_id` IN (?)", - args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"}, - }, - { - filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`, - want: "`reaction`.`content_id` IN (?,?)", - args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"}, - }, - } - - for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...) - require.NoError(t, err) - convertCtx := filter.NewConvertContext() - converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{}) - err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - require.NoError(t, err) - require.Equal(t, tt.want, convertCtx.Buffer.String()) - require.Equal(t, tt.args, convertCtx.Args) - } -} diff --git a/store/db/sqlite/user.go b/store/db/sqlite/user.go index 2ebf64dd7..b5cb906bd 100644 --- a/store/db/sqlite/user.go +++ b/store/db/sqlite/user.go @@ -5,7 +5,8 @@ import ( "fmt" "strings" - "github.com/usememos/memos/plugin/filter" + "github.com/pkg/errors" + "github.com/usememos/memos/store" ) @@ -87,24 +88,8 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { where, args := []string{"1 = 1"}, []any{} - for _, filterStr := range find.Filters { - // Parse filter string and return the parsed expression. - // The filter string should be a CEL expression. - parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...) - if err != nil { - return nil, err - } - convertCtx := filter.NewConvertContext() - // ConvertExprToSQL converts the parsed expression to a SQL condition string. - converter := filter.NewUserSQLConverter(&filter.SQLiteDialect{}) - if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { - return nil, err - } - condition := convertCtx.Buffer.String() - if condition != "" { - where = append(where, fmt.Sprintf("(%s)", condition)) - args = append(args, convertCtx.Args...) - } + if len(find.Filters) > 0 { + return nil, errors.Errorf("user filters are not supported") } if v := find.ID; v != nil { diff --git a/store/memo.go b/store/memo.go index 9d075339a..a32a9320e 100644 --- a/store/memo.go +++ b/store/memo.go @@ -60,6 +60,9 @@ type FindMemo struct { ID *int32 UID *string + IDList []int32 + UIDList []string + // Standard fields RowStatus *RowStatus CreatorID *int32 diff --git a/store/reaction.go b/store/reaction.go index 7354cd9e0..a10093128 100644 --- a/store/reaction.go +++ b/store/reaction.go @@ -14,10 +14,10 @@ type Reaction struct { } type FindReaction struct { - ID *int32 - CreatorID *int32 - ContentID *string - Filters []string + ID *int32 + CreatorID *int32 + ContentID *string + ContentIDList []string } type DeleteReaction struct {