mirror of https://github.com/usememos/memos.git
refactor: memo filter
This commit is contained in:
parent
52a5ca2ef4
commit
778a5eb184
|
|
@ -0,0 +1,448 @@
|
||||||
|
package filter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CommonSQLConverter handles the common CEL to SQL conversion logic
|
||||||
|
type CommonSQLConverter struct {
|
||||||
|
dialect SQLDialect
|
||||||
|
paramIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCommonSQLConverter creates a new converter with the specified dialect
|
||||||
|
func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
|
||||||
|
return &CommonSQLConverter{
|
||||||
|
dialect: dialect,
|
||||||
|
paramIndex: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect
|
||||||
|
func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
|
switch v.CallExpr.Function {
|
||||||
|
case "_||_", "_&&_":
|
||||||
|
return c.handleLogicalOperator(ctx, v.CallExpr)
|
||||||
|
case "!_":
|
||||||
|
return c.handleNotOperator(ctx, v.CallExpr)
|
||||||
|
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||||
|
return c.handleComparisonOperator(ctx, v.CallExpr)
|
||||||
|
case "@in":
|
||||||
|
return c.handleInOperator(ctx, v.CallExpr)
|
||||||
|
case "contains":
|
||||||
|
return c.handleContainsOperator(ctx, v.CallExpr)
|
||||||
|
}
|
||||||
|
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||||
|
return c.handleIdentifier(ctx, v.IdentExpr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||||
|
if len(callExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
operator := "AND"
|
||||||
|
if callExpr.Function == "_||_" {
|
||||||
|
operator = "OR"
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||||
|
if len(callExpr.Args) != 1 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||||
|
if len(callExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the left side is a function call like size(tags)
|
||||||
|
if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
|
if leftCallExpr.CallExpr.Function == "size" {
|
||||||
|
return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||||
|
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := GetExprValue(callExpr.Args[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
operator := c.getComparisonOperator(callExpr.Function)
|
||||||
|
|
||||||
|
switch identifier {
|
||||||
|
case "created_ts", "updated_ts":
|
||||||
|
return c.handleTimestampComparison(ctx, identifier, operator, value)
|
||||||
|
case "visibility", "content":
|
||||||
|
return c.handleStringComparison(ctx, identifier, operator, value)
|
||||||
|
case "creator_id":
|
||||||
|
return c.handleIntComparison(ctx, identifier, operator, value)
|
||||||
|
case "has_task_list":
|
||||||
|
return c.handleBooleanComparison(ctx, identifier, operator, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error {
|
||||||
|
if len(sizeCall.Args) != 1 {
|
||||||
|
return errors.New("size function requires exactly one argument")
|
||||||
|
}
|
||||||
|
|
||||||
|
identifier, err := GetIdentExprName(sizeCall.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier != "tags" {
|
||||||
|
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := GetExprValue(callExpr.Args[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
valueInt, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("size comparison value must be an integer")
|
||||||
|
}
|
||||||
|
|
||||||
|
operator := c.getComparisonOperator(callExpr.Function)
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
|
||||||
|
c.dialect.GetJSONArrayLength("$.tags"),
|
||||||
|
operator,
|
||||||
|
c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
c.paramIndex++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||||
|
if len(callExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is "element in collection" syntax
|
||||||
|
if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil {
|
||||||
|
if identifier == "tags" {
|
||||||
|
return c.handleElementInTags(ctx, callExpr.Args[0])
|
||||||
|
}
|
||||||
|
return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original logic for "identifier in [list]" syntax
|
||||||
|
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||||
|
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := []any{}
|
||||||
|
for _, element := range callExpr.Args[1].GetListExpr().Elements {
|
||||||
|
value, err := GetConstValue(element)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier == "tag" {
|
||||||
|
return c.handleTagInList(ctx, values)
|
||||||
|
} else if identifier == "visibility" {
|
||||||
|
return c.handleVisibilityInList(ctx, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error {
|
||||||
|
element, err := GetConstValue(elementExpr)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use dialect-specific JSON contains logic
|
||||||
|
sqlExpr := c.dialect.GetJSONContains("$.tags", "element")
|
||||||
|
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// For SQLite, we need a different approach since it uses LIKE
|
||||||
|
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||||
|
ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element))
|
||||||
|
} else {
|
||||||
|
ctx.Args = append(ctx.Args, element)
|
||||||
|
}
|
||||||
|
c.paramIndex++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error {
|
||||||
|
subconditions := []string{}
|
||||||
|
args := []any{}
|
||||||
|
|
||||||
|
for _, v := range values {
|
||||||
|
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||||
|
subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern"))
|
||||||
|
args = append(args, fmt.Sprintf(`%%"%s"%%`, v))
|
||||||
|
} else {
|
||||||
|
subconditions = append(subconditions, c.dialect.GetJSONContains("$.tags", "element"))
|
||||||
|
args = append(args, v)
|
||||||
|
}
|
||||||
|
c.paramIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(subconditions) == 1 {
|
||||||
|
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, args...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error {
|
||||||
|
placeholders := []string{}
|
||||||
|
for range values {
|
||||||
|
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||||
|
c.paramIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
tablePrefix := c.dialect.GetTablePrefix()
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, values...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||||
|
if len(callExpr.Args) != 1 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
identifier, err := GetIdentExprName(callExpr.Target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier != "content" {
|
||||||
|
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
arg, err := GetConstValue(callExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tablePrefix := c.dialect.GetTablePrefix()
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||||
|
c.paramIndex++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
|
||||||
|
identifier := identExpr.GetName()
|
||||||
|
|
||||||
|
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||||
|
return errors.Errorf("invalid identifier %s", identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier == "pinned" {
|
||||||
|
tablePrefix := c.dialect.GetTablePrefix()
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if identifier == "has_task_list" {
|
||||||
|
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||||
|
valueInt, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid integer timestamp value")
|
||||||
|
}
|
||||||
|
|
||||||
|
timestampField := c.dialect.GetTimestampComparison(field)
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
c.paramIndex++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||||
|
if operator != "=" && operator != "!=" {
|
||||||
|
return errors.Errorf("invalid operator for %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
valueStr, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid string value")
|
||||||
|
}
|
||||||
|
|
||||||
|
tablePrefix := c.dialect.GetTablePrefix()
|
||||||
|
fieldName := field
|
||||||
|
if field == "visibility" {
|
||||||
|
fieldName = "`visibility`"
|
||||||
|
} else if field == "content" {
|
||||||
|
fieldName = "`content`"
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, valueStr)
|
||||||
|
c.paramIndex++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||||
|
if operator != "=" && operator != "!=" {
|
||||||
|
return errors.Errorf("invalid operator for %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
valueInt, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid int value")
|
||||||
|
}
|
||||||
|
|
||||||
|
tablePrefix := c.dialect.GetTablePrefix()
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
c.paramIndex++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||||
|
if operator != "=" && operator != "!=" {
|
||||||
|
return errors.Errorf("invalid operator for %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
valueBool, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid boolean value for has_task_list")
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlExpr := c.dialect.GetBooleanComparison("$.property.hasTaskList", valueBool)
|
||||||
|
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// For dialects that need parameters (PostgreSQL)
|
||||||
|
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||||
|
ctx.Args = append(ctx.Args, valueBool)
|
||||||
|
c.paramIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) getComparisonOperator(function string) string {
|
||||||
|
switch function {
|
||||||
|
case "_==_":
|
||||||
|
return "="
|
||||||
|
case "_!=_":
|
||||||
|
return "!="
|
||||||
|
case "_<_":
|
||||||
|
return "<"
|
||||||
|
case "_>_":
|
||||||
|
return ">"
|
||||||
|
case "_<=_":
|
||||||
|
return "<="
|
||||||
|
case "_>=_":
|
||||||
|
return ">="
|
||||||
|
default:
|
||||||
|
return "="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,212 @@
|
||||||
|
package filter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLDialect defines database-specific SQL generation methods
|
||||||
|
type SQLDialect interface {
|
||||||
|
// Basic field access
|
||||||
|
GetTablePrefix() string
|
||||||
|
GetParameterPlaceholder(index int) string
|
||||||
|
|
||||||
|
// JSON operations
|
||||||
|
GetJSONExtract(path string) string
|
||||||
|
GetJSONArrayLength(path string) string
|
||||||
|
GetJSONContains(path, element string) string
|
||||||
|
GetJSONLike(path, pattern string) string
|
||||||
|
|
||||||
|
// Boolean operations
|
||||||
|
GetBooleanValue(value bool) interface{}
|
||||||
|
GetBooleanComparison(path string, value bool) string
|
||||||
|
GetBooleanCheck(path string) string
|
||||||
|
|
||||||
|
// Timestamp operations
|
||||||
|
GetTimestampComparison(field string) string
|
||||||
|
GetCurrentTimestamp() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseType represents the type of database
|
||||||
|
type DatabaseType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SQLite DatabaseType = "sqlite"
|
||||||
|
MySQL DatabaseType = "mysql"
|
||||||
|
PostgreSQL DatabaseType = "postgres"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDialect returns the appropriate dialect for the database type
|
||||||
|
func GetDialect(dbType DatabaseType) SQLDialect {
|
||||||
|
switch dbType {
|
||||||
|
case SQLite:
|
||||||
|
return &SQLiteDialect{}
|
||||||
|
case MySQL:
|
||||||
|
return &MySQLDialect{}
|
||||||
|
case PostgreSQL:
|
||||||
|
return &PostgreSQLDialect{}
|
||||||
|
default:
|
||||||
|
return &SQLiteDialect{} // default fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLiteDialect implements SQLDialect for SQLite
|
||||||
|
type SQLiteDialect struct{}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetTablePrefix() string {
|
||||||
|
return "`memo`"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetParameterPlaceholder(index int) string {
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetJSONExtract(path string) string {
|
||||||
|
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
|
||||||
|
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetJSONContains(path, element string) string {
|
||||||
|
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetJSONLike(path, pattern string) string {
|
||||||
|
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetBooleanValue(value bool) interface{} {
|
||||||
|
if value {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
|
||||||
|
return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
|
||||||
|
return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
|
||||||
|
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) GetCurrentTimestamp() string {
|
||||||
|
return "strftime('%s', 'now')"
|
||||||
|
}
|
||||||
|
|
||||||
|
// MySQLDialect implements SQLDialect for MySQL
|
||||||
|
type MySQLDialect struct{}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetTablePrefix() string {
|
||||||
|
return "`memo`"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetParameterPlaceholder(index int) string {
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetJSONExtract(path string) string {
|
||||||
|
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
|
||||||
|
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetJSONContains(path, element string) string {
|
||||||
|
return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetJSONLike(path, pattern string) string {
|
||||||
|
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetBooleanValue(value bool) interface{} {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
|
||||||
|
boolStr := "false"
|
||||||
|
if value {
|
||||||
|
boolStr = "true"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetBooleanCheck(path string) string {
|
||||||
|
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetTimestampComparison(field string) string {
|
||||||
|
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MySQLDialect) GetCurrentTimestamp() string {
|
||||||
|
return "UNIX_TIMESTAMP()"
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostgreSQLDialect implements SQLDialect for PostgreSQL
|
||||||
|
type PostgreSQLDialect struct{}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetTablePrefix() string {
|
||||||
|
return "memo"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
||||||
|
return fmt.Sprintf("$%d", index)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||||
|
// Convert $.property.hasTaskList to payload->'property'->>'hasTaskList'
|
||||||
|
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||||
|
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
|
||||||
|
for i, part := range parts {
|
||||||
|
if i == len(parts)-1 {
|
||||||
|
result += fmt.Sprintf("->>'%s'", part)
|
||||||
|
} else {
|
||||||
|
result += fmt.Sprintf("->'%s'", part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
|
||||||
|
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||||
|
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetJSONContains(path, element string) string {
|
||||||
|
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||||
|
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetJSONLike(path, pattern string) string {
|
||||||
|
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||||
|
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetBooleanComparison(path string, value bool) string {
|
||||||
|
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
|
||||||
|
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
|
||||||
|
return fmt.Sprintf("EXTRACT(EPOCH FROM %s.%s)", d.GetTablePrefix(), field)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PostgreSQLDialect) GetCurrentTimestamp() string {
|
||||||
|
return "EXTRACT(EPOCH FROM NOW())"
|
||||||
|
}
|
||||||
|
|
@ -18,6 +18,7 @@ var MemoFilterCELAttributes = []cel.EnvOption{
|
||||||
cel.Variable("updated_ts", cel.IntType),
|
cel.Variable("updated_ts", cel.IntType),
|
||||||
cel.Variable("pinned", cel.BoolType),
|
cel.Variable("pinned", cel.BoolType),
|
||||||
cel.Variable("tag", cel.StringType),
|
cel.Variable("tag", cel.StringType),
|
||||||
|
cel.Variable("tags", cel.ListType(cel.StringType)),
|
||||||
cel.Variable("visibility", cel.StringType),
|
cel.Variable("visibility", cel.StringType),
|
||||||
cel.Variable("has_task_list", cel.BoolType),
|
cel.Variable("has_task_list", cel.BoolType),
|
||||||
// Current timestamp function.
|
// Current timestamp function.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,146 @@
|
||||||
|
package filter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLTemplate holds database-specific SQL fragments
|
||||||
|
type SQLTemplate struct {
|
||||||
|
SQLite string
|
||||||
|
MySQL string
|
||||||
|
PostgreSQL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TemplateDBType represents the database type for templates
|
||||||
|
type TemplateDBType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SQLiteTemplate TemplateDBType = "sqlite"
|
||||||
|
MySQLTemplate TemplateDBType = "mysql"
|
||||||
|
PostgreSQLTemplate TemplateDBType = "postgres"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLTemplates contains common SQL patterns for different databases
|
||||||
|
var SQLTemplates = map[string]SQLTemplate{
|
||||||
|
"json_extract": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||||
|
MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||||
|
PostgreSQL: "memo.payload%s",
|
||||||
|
},
|
||||||
|
"json_array_length": {
|
||||||
|
SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||||
|
MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||||
|
PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))",
|
||||||
|
},
|
||||||
|
"json_contains_element": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||||
|
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||||
|
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||||
|
},
|
||||||
|
"json_contains_tag": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||||
|
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||||
|
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||||
|
},
|
||||||
|
"boolean_true": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
|
||||||
|
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||||
|
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true",
|
||||||
|
},
|
||||||
|
"boolean_false": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
|
||||||
|
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
|
||||||
|
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false",
|
||||||
|
},
|
||||||
|
"boolean_not_true": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1",
|
||||||
|
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)",
|
||||||
|
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true",
|
||||||
|
},
|
||||||
|
"boolean_not_false": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
|
||||||
|
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
|
||||||
|
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false",
|
||||||
|
},
|
||||||
|
"boolean_compare": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?",
|
||||||
|
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)",
|
||||||
|
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?",
|
||||||
|
},
|
||||||
|
"boolean_check": {
|
||||||
|
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
|
||||||
|
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||||
|
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
|
||||||
|
},
|
||||||
|
"table_prefix": {
|
||||||
|
SQLite: "`memo`",
|
||||||
|
MySQL: "`memo`",
|
||||||
|
PostgreSQL: "memo",
|
||||||
|
},
|
||||||
|
"timestamp_field": {
|
||||||
|
SQLite: "`memo`.`%s`",
|
||||||
|
MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)",
|
||||||
|
PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)",
|
||||||
|
},
|
||||||
|
"content_like": {
|
||||||
|
SQLite: "`memo`.`content` LIKE ?",
|
||||||
|
MySQL: "`memo`.`content` LIKE ?",
|
||||||
|
PostgreSQL: "memo.content ILIKE ?",
|
||||||
|
},
|
||||||
|
"visibility_in": {
|
||||||
|
SQLite: "`memo`.`visibility` IN (%s)",
|
||||||
|
MySQL: "`memo`.`visibility` IN (%s)",
|
||||||
|
PostgreSQL: "memo.visibility IN (%s)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSQL returns the appropriate SQL for the given template and database type
|
||||||
|
func GetSQL(templateName string, dbType TemplateDBType) string {
|
||||||
|
template, exists := SQLTemplates[templateName]
|
||||||
|
if !exists {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch dbType {
|
||||||
|
case SQLiteTemplate:
|
||||||
|
return template.SQLite
|
||||||
|
case MySQLTemplate:
|
||||||
|
return template.MySQL
|
||||||
|
case PostgreSQLTemplate:
|
||||||
|
return template.PostgreSQL
|
||||||
|
default:
|
||||||
|
return template.SQLite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParameterPlaceholder returns the appropriate parameter placeholder for the database
|
||||||
|
func GetParameterPlaceholder(dbType TemplateDBType, index int) string {
|
||||||
|
switch dbType {
|
||||||
|
case PostgreSQLTemplate:
|
||||||
|
return fmt.Sprintf("$%d", index)
|
||||||
|
default:
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParameterValue returns the appropriate parameter value for the database
|
||||||
|
func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} {
|
||||||
|
switch templateName {
|
||||||
|
case "json_contains_element", "json_contains_tag":
|
||||||
|
if dbType == SQLiteTemplate {
|
||||||
|
return fmt.Sprintf(`%%"%s"%%`, value)
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
default:
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatPlaceholders formats a list of placeholders for the given database type
|
||||||
|
func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string {
|
||||||
|
placeholders := make([]string, count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i)
|
||||||
|
}
|
||||||
|
return placeholders
|
||||||
|
}
|
||||||
|
|
@ -12,6 +12,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
return d.convertWithTemplates(ctx, expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
const dbType = filter.MySQLTemplate
|
||||||
|
|
||||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
case "_||_", "_&&_":
|
case "_||_", "_&&_":
|
||||||
|
|
@ -21,7 +27,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
operator := "AND"
|
operator := "AND"
|
||||||
|
|
@ -31,7 +37,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
|
@ -44,7 +50,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
|
@ -54,6 +60,39 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
// Check if the left side is a function call like size(tags)
|
||||||
|
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
|
if leftCallExpr.CallExpr.Function == "size" {
|
||||||
|
// Handle size(tags) comparison
|
||||||
|
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||||
|
return errors.New("size function requires exactly one argument")
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if identifier != "tags" {
|
||||||
|
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||||
|
}
|
||||||
|
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
valueInt, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("size comparison value must be an integer")
|
||||||
|
}
|
||||||
|
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?",
|
||||||
|
filter.GetSQL("json_array_length", dbType), operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -65,38 +104,19 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
operator := "="
|
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||||
switch v.CallExpr.Function {
|
|
||||||
case "_==_":
|
|
||||||
operator = "="
|
|
||||||
case "_!=_":
|
|
||||||
operator = "!="
|
|
||||||
case "_<_":
|
|
||||||
operator = "<"
|
|
||||||
case "_>_":
|
|
||||||
operator = ">"
|
|
||||||
case "_<=_":
|
|
||||||
operator = "<="
|
|
||||||
case "_>=_":
|
|
||||||
operator = ">="
|
|
||||||
}
|
|
||||||
|
|
||||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||||
timestampInt, ok := value.(int64)
|
valueInt, ok := value.(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid timestamp value")
|
return errors.New("invalid integer timestamp value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||||
if identifier == "created_ts" {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil {
|
||||||
factor = "UNIX_TIMESTAMP(`memo`.`created_ts`)"
|
|
||||||
} else if identifier == "updated_ts" {
|
|
||||||
factor = "UNIX_TIMESTAMP(`memo`.`updated_ts`)"
|
|
||||||
}
|
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, timestampInt)
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
} else if identifier == "visibility" || identifier == "content" {
|
} else if identifier == "visibility" || identifier == "content" {
|
||||||
if operator != "=" && operator != "!=" {
|
if operator != "=" && operator != "!=" {
|
||||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||||
|
|
@ -106,13 +126,13 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
return errors.New("invalid string value")
|
return errors.New("invalid string value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
var sqlTemplate string
|
||||||
if identifier == "visibility" {
|
if identifier == "visibility" {
|
||||||
factor = "`memo`.`visibility`"
|
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`"
|
||||||
} else if identifier == "content" {
|
} else if identifier == "content" {
|
||||||
factor = "`memo`.`content`"
|
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`"
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueStr)
|
ctx.Args = append(ctx.Args, valueStr)
|
||||||
|
|
@ -125,11 +145,8 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
return errors.New("invalid int value")
|
return errors.New("invalid int value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`"
|
||||||
if identifier == "creator_id" {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||||
factor = "`memo`.`creator_id`"
|
|
||||||
}
|
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueInt)
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
|
@ -141,15 +158,22 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid boolean value for has_task_list")
|
return errors.New("invalid boolean value for has_task_list")
|
||||||
}
|
}
|
||||||
|
// Use template for boolean comparison
|
||||||
// In MySQL, we can use JSON_EXTRACT to get the value and compare it to 'true' or 'false'
|
var sqlTemplate string
|
||||||
compareValue := "false"
|
if operator == "=" {
|
||||||
if valueBool {
|
if valueBool {
|
||||||
compareValue = "true"
|
sqlTemplate = filter.GetSQL("boolean_true", dbType)
|
||||||
|
} else {
|
||||||
|
sqlTemplate = filter.GetSQL("boolean_false", dbType)
|
||||||
|
}
|
||||||
|
} else { // operator == "!="
|
||||||
|
if valueBool {
|
||||||
|
sqlTemplate = filter.GetSQL("boolean_not_true", dbType)
|
||||||
|
} else {
|
||||||
|
sqlTemplate = filter.GetSQL("boolean_not_false", dbType)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||||
// MySQL uses -> as a shorthand for JSON_EXTRACT
|
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST('%s' AS JSON)", operator, compareValue)); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -157,6 +181,29 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if this is "element in collection" syntax
|
||||||
|
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||||
|
// This is "element in collection" - the second argument is the collection
|
||||||
|
if !slices.Contains([]string{"tags"}, identifier) {
|
||||||
|
return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier == "tags" {
|
||||||
|
// Handle "element" in tags
|
||||||
|
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original logic for "identifier in [list]" syntax
|
||||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -174,27 +221,26 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
if identifier == "tag" {
|
if identifier == "tag" {
|
||||||
subcodition := []string{}
|
subconditions := []string{}
|
||||||
args := []any{}
|
args := []any{}
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
subcodition, args = append(subcodition, "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)"), append(args, v)
|
subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType))
|
||||||
|
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||||
}
|
}
|
||||||
if len(subcodition) == 1 {
|
if len(subconditions) == 1 {
|
||||||
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
|
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, args...)
|
ctx.Args = append(ctx.Args, args...)
|
||||||
} else if identifier == "visibility" {
|
} else if identifier == "visibility" {
|
||||||
placeholder := []string{}
|
placeholders := filter.FormatPlaceholders(dbType, len(values), 1)
|
||||||
for range values {
|
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||||
placeholder = append(placeholder, "?")
|
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||||
}
|
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, values...)
|
ctx.Args = append(ctx.Args, values...)
|
||||||
|
|
@ -214,7 +260,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil {
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||||
|
|
@ -222,17 +268,37 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||||
identifier := v.IdentExpr.GetName()
|
identifier := v.IdentExpr.GetName()
|
||||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||||
return errors.Errorf("invalid identifier for %s", identifier)
|
return errors.Errorf("invalid identifier %s", identifier)
|
||||||
}
|
}
|
||||||
if identifier == "pinned" {
|
if identifier == "pinned" {
|
||||||
if _, err := ctx.Buffer.WriteString("`memo`.`pinned` IS TRUE"); err != nil {
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if identifier == "has_task_list" {
|
} else if identifier == "has_task_list" {
|
||||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)"); err != nil {
|
// Handle has_task_list as a standalone boolean identifier
|
||||||
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DB) getComparisonOperator(function string) string {
|
||||||
|
switch function {
|
||||||
|
case "_==_":
|
||||||
|
return "="
|
||||||
|
case "_!=_":
|
||||||
|
return "!="
|
||||||
|
case "_<_":
|
||||||
|
return "<"
|
||||||
|
case "_>_":
|
||||||
|
return ">"
|
||||||
|
case "_<=_":
|
||||||
|
return "<="
|
||||||
|
case "_>=_":
|
||||||
|
return ">="
|
||||||
|
default:
|
||||||
|
return "="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,26 @@ func TestConvertExprToSQL(t *testing.T) {
|
||||||
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?",
|
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?",
|
||||||
args: []any{time.Now().Unix() - 60*60*24},
|
args: []any{time.Now().Unix() - 60*60*24},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) == 0`,
|
||||||
|
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||||
|
args: []any{int64(0)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) > 0`,
|
||||||
|
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
|
||||||
|
args: []any{int64(0)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `"work" in tags`,
|
||||||
|
want: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||||
|
args: []any{"work"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) == 2`,
|
||||||
|
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||||
|
args: []any{int64(2)},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
|
||||||
|
|
@ -12,219 +12,315 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
const dbType = filter.PostgreSQLTemplate
|
||||||
|
_, err := d.convertWithParameterIndex(ctx, expr, dbType, len(ctx.Args)+1)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.Expr, dbType filter.TemplateDBType, paramIndex int) (int, error) {
|
||||||
|
|
||||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
case "_||_", "_&&_":
|
case "_||_", "_&&_":
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex)
|
||||||
return err
|
if err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
operator := "AND"
|
operator := "AND"
|
||||||
if v.CallExpr.Function == "_||_" {
|
if v.CallExpr.Function == "_||_" {
|
||||||
operator = "OR"
|
operator = "OR"
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
newParamIndex, err = d.convertWithParameterIndex(ctx, v.CallExpr.Args[1], dbType, newParamIndex)
|
||||||
return err
|
if err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
|
return newParamIndex, nil
|
||||||
case "!_":
|
case "!_":
|
||||||
if len(v.CallExpr.Args) != 1 {
|
if len(v.CallExpr.Args) != 1 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex)
|
||||||
return err
|
if err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
|
return newParamIndex, nil
|
||||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
// Check if the left side is a function call like size(tags)
|
||||||
|
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
|
if leftCallExpr.CallExpr.Function == "size" {
|
||||||
|
// Handle size(tags) comparison
|
||||||
|
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||||
|
return paramIndex, errors.New("size function requires exactly one argument")
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return paramIndex, err
|
||||||
|
}
|
||||||
|
if identifier != "tags" {
|
||||||
|
return paramIndex, errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||||
|
}
|
||||||
|
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||||
|
if err != nil {
|
||||||
|
return paramIndex, err
|
||||||
|
}
|
||||||
|
valueInt, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return paramIndex, errors.New("size comparison value must be an integer")
|
||||||
|
}
|
||||||
|
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
|
||||||
|
filter.GetSQL("json_array_length", dbType), operator,
|
||||||
|
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
return paramIndex + 1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
|
||||||
operator := "="
|
|
||||||
switch v.CallExpr.Function {
|
|
||||||
case "_==_":
|
|
||||||
operator = "="
|
|
||||||
case "_!=_":
|
|
||||||
operator = "!="
|
|
||||||
case "_<_":
|
|
||||||
operator = "<"
|
|
||||||
case "_>_":
|
|
||||||
operator = ">"
|
|
||||||
case "_<=_":
|
|
||||||
operator = "<="
|
|
||||||
case "_>=_":
|
|
||||||
operator = ">="
|
|
||||||
}
|
}
|
||||||
|
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||||
|
|
||||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||||
timestampInt, ok := value.(int64)
|
valueInt, ok := value.(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid timestamp value")
|
return paramIndex, errors.New("invalid integer timestamp value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||||
if identifier == "created_ts" {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampSQL, operator,
|
||||||
factor = "EXTRACT(EPOCH FROM memo.created_ts)"
|
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||||
} else if identifier == "updated_ts" {
|
return paramIndex, err
|
||||||
factor = "EXTRACT(EPOCH FROM memo.updated_ts)"
|
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil {
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
return err
|
return paramIndex + 1, nil
|
||||||
}
|
|
||||||
ctx.Args = append(ctx.Args, timestampInt)
|
|
||||||
} else if identifier == "visibility" || identifier == "content" {
|
} else if identifier == "visibility" || identifier == "content" {
|
||||||
if operator != "=" && operator != "!=" {
|
if operator != "=" && operator != "!=" {
|
||||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
valueStr, ok := value.(string)
|
valueStr, ok := value.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid string value")
|
return paramIndex, errors.New("invalid string value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
var sqlTemplate string
|
||||||
if identifier == "visibility" {
|
if identifier == "visibility" {
|
||||||
factor = "memo.visibility"
|
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".visibility"
|
||||||
} else if identifier == "content" {
|
} else if identifier == "content" {
|
||||||
factor = "memo.content"
|
sqlTemplate = filter.GetSQL("content_like", dbType)
|
||||||
|
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", valueStr))
|
||||||
|
return paramIndex + 1, nil
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator,
|
||||||
return err
|
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueStr)
|
ctx.Args = append(ctx.Args, valueStr)
|
||||||
|
return paramIndex + 1, nil
|
||||||
} else if identifier == "creator_id" {
|
} else if identifier == "creator_id" {
|
||||||
if operator != "=" && operator != "!=" {
|
if operator != "=" && operator != "!=" {
|
||||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
valueInt, ok := value.(int64)
|
valueInt, ok := value.(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid int value")
|
return paramIndex, errors.New("invalid int value")
|
||||||
}
|
}
|
||||||
|
|
||||||
factor := "memo.creator_id"
|
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".creator_id"
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator,
|
||||||
return err
|
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueInt)
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
return paramIndex + 1, nil
|
||||||
} else if identifier == "has_task_list" {
|
} else if identifier == "has_task_list" {
|
||||||
if operator != "=" && operator != "!=" {
|
if operator != "=" && operator != "!=" {
|
||||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
valueBool, ok := value.(bool)
|
valueBool, ok := value.(bool)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid boolean value for has_task_list")
|
return paramIndex, errors.New("invalid boolean value for has_task_list")
|
||||||
}
|
}
|
||||||
|
// Use parameterized template for boolean comparison (PostgreSQL only)
|
||||||
// In PostgreSQL, extract the boolean from the JSON and compare it
|
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(memo.payload->'property'->>'hasTaskList')::boolean %s %s", operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil {
|
sqlTemplate := fmt.Sprintf(filter.GetSQL("boolean_compare", dbType), operator)
|
||||||
return err
|
sqlTemplate = strings.Replace(sqlTemplate, "?", placeholder, 1)
|
||||||
|
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueBool)
|
ctx.Args = append(ctx.Args, valueBool)
|
||||||
|
return paramIndex + 1, nil
|
||||||
}
|
}
|
||||||
case "@in":
|
case "@in":
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if this is "element in collection" syntax
|
||||||
|
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||||
|
// This is "element in collection" - the second argument is the collection
|
||||||
|
if !slices.Contains([]string{"tags"}, identifier) {
|
||||||
|
return paramIndex, errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier == "tags" {
|
||||||
|
// Handle "element" in tags
|
||||||
|
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return paramIndex, errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||||
|
}
|
||||||
|
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||||
|
sql := strings.Replace(filter.GetSQL("json_contains_element", dbType), "?", placeholder, 1)
|
||||||
|
if _, err := ctx.Buffer.WriteString(sql); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||||
|
return paramIndex + 1, nil
|
||||||
|
}
|
||||||
|
return paramIndex, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original logic for "identifier in [list]" syntax
|
||||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
|
||||||
values := []any{}
|
values := []any{}
|
||||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||||
value, err := filter.GetConstValue(element)
|
value, err := filter.GetConstValue(element)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
if identifier == "tag" {
|
if identifier == "tag" {
|
||||||
subcodition := []string{}
|
subconditions := []string{}
|
||||||
args := []any{}
|
args := []any{}
|
||||||
|
currentParamIndex := paramIndex
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
subcodition, args = append(subcodition, fmt.Sprintf(`memo.payload->'tags' @> jsonb_build_array(%s)`, placeholder(len(ctx.Args)+ctx.ArgsOffset+len(args)+1))), append(args, v)
|
// Use parameter index for each placeholder
|
||||||
|
placeholder := filter.GetParameterPlaceholder(dbType, currentParamIndex)
|
||||||
|
subcondition := strings.Replace(filter.GetSQL("json_contains_tag", dbType), "?", placeholder, 1)
|
||||||
|
subconditions = append(subconditions, subcondition)
|
||||||
|
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||||
|
currentParamIndex++
|
||||||
}
|
}
|
||||||
if len(subcodition) == 1 {
|
if len(subconditions) == 1 {
|
||||||
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
|
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, args...)
|
ctx.Args = append(ctx.Args, args...)
|
||||||
|
return paramIndex + len(args), nil
|
||||||
} else if identifier == "visibility" {
|
} else if identifier == "visibility" {
|
||||||
placeholders := []string{}
|
placeholders := filter.FormatPlaceholders(dbType, len(values), paramIndex)
|
||||||
for i := range values {
|
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||||
placeholders = append(placeholders, placeholder(len(ctx.Args)+ctx.ArgsOffset+i+1))
|
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||||
}
|
return paramIndex, err
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("memo.visibility IN (%s)", strings.Join(placeholders, ","))); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, values...)
|
ctx.Args = append(ctx.Args, values...)
|
||||||
|
return paramIndex + len(values), nil
|
||||||
}
|
}
|
||||||
case "contains":
|
case "contains":
|
||||||
if len(v.CallExpr.Args) != 1 {
|
if len(v.CallExpr.Args) != 1 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if identifier != "content" {
|
if identifier != "content" {
|
||||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString("memo.content ILIKE " + placeholder(len(ctx.Args)+ctx.ArgsOffset+1)); err != nil {
|
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||||
return err
|
sql := strings.Replace(filter.GetSQL("content_like", dbType), "?", placeholder, 1)
|
||||||
|
if _, err := ctx.Buffer.WriteString(sql); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||||
|
return paramIndex + 1, nil
|
||||||
}
|
}
|
||||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||||
identifier := v.IdentExpr.GetName()
|
identifier := v.IdentExpr.GetName()
|
||||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||||
return errors.Errorf("invalid identifier %s", identifier)
|
return paramIndex, errors.Errorf("invalid identifier %s", identifier)
|
||||||
}
|
}
|
||||||
if identifier == "pinned" {
|
if identifier == "pinned" {
|
||||||
if _, err := ctx.Buffer.WriteString("memo.pinned IS TRUE"); err != nil {
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".pinned IS TRUE"); err != nil {
|
||||||
return err
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
} else if identifier == "has_task_list" {
|
} else if identifier == "has_task_list" {
|
||||||
if _, err := ctx.Buffer.WriteString("(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE"); err != nil {
|
// Handle has_task_list as a standalone boolean identifier
|
||||||
return err
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||||
|
return paramIndex, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return paramIndex, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DB) getComparisonOperator(function string) string {
|
||||||
|
switch function {
|
||||||
|
case "_==_":
|
||||||
|
return "="
|
||||||
|
case "_!=_":
|
||||||
|
return "!="
|
||||||
|
case "_<_":
|
||||||
|
return "<"
|
||||||
|
case "_>_":
|
||||||
|
return ">"
|
||||||
|
case "_<=_":
|
||||||
|
return "<="
|
||||||
|
case "_>=_":
|
||||||
|
return ">="
|
||||||
|
default:
|
||||||
|
return "="
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,26 @@ func TestRestoreExprToSQL(t *testing.T) {
|
||||||
want: "EXTRACT(EPOCH FROM memo.created_ts) > $1",
|
want: "EXTRACT(EPOCH FROM memo.created_ts) > $1",
|
||||||
args: []any{time.Now().Unix() - 60*60*24},
|
args: []any{time.Now().Unix() - 60*60*24},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) == 0`,
|
||||||
|
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
|
||||||
|
args: []any{int64(0)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) > 0`,
|
||||||
|
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) > $1",
|
||||||
|
args: []any{int64(0)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `"work" in tags`,
|
||||||
|
want: "memo.payload->'tags' @> jsonb_build_array($1)",
|
||||||
|
args: []any{"work"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) == 2`,
|
||||||
|
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
|
||||||
|
args: []any{int64(2)},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
return d.convertWithTemplates(ctx, expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
const dbType = filter.SQLiteTemplate
|
||||||
|
|
||||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
case "_||_", "_&&_":
|
case "_||_", "_&&_":
|
||||||
|
|
@ -21,7 +27,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
operator := "AND"
|
operator := "AND"
|
||||||
|
|
@ -31,7 +37,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
|
@ -44,7 +50,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
|
@ -54,6 +60,39 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
// Check if the left side is a function call like size(tags)
|
||||||
|
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
|
if leftCallExpr.CallExpr.Function == "size" {
|
||||||
|
// Handle size(tags) comparison
|
||||||
|
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||||
|
return errors.New("size function requires exactly one argument")
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if identifier != "tags" {
|
||||||
|
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||||
|
}
|
||||||
|
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
valueInt, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("size comparison value must be an integer")
|
||||||
|
}
|
||||||
|
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||||
|
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?",
|
||||||
|
filter.GetSQL("json_array_length", dbType), operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -65,21 +104,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
operator := "="
|
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||||
switch v.CallExpr.Function {
|
|
||||||
case "_==_":
|
|
||||||
operator = "="
|
|
||||||
case "_!=_":
|
|
||||||
operator = "!="
|
|
||||||
case "_<_":
|
|
||||||
operator = "<"
|
|
||||||
case "_>_":
|
|
||||||
operator = ">"
|
|
||||||
case "_<=_":
|
|
||||||
operator = "<="
|
|
||||||
case "_>=_":
|
|
||||||
operator = ">="
|
|
||||||
}
|
|
||||||
|
|
||||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||||
valueInt, ok := value.(int64)
|
valueInt, ok := value.(int64)
|
||||||
|
|
@ -87,13 +112,8 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
return errors.New("invalid integer timestamp value")
|
return errors.New("invalid integer timestamp value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||||
if identifier == "created_ts" {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil {
|
||||||
factor = "`memo`.`created_ts`"
|
|
||||||
} else if identifier == "updated_ts" {
|
|
||||||
factor = "`memo`.`updated_ts`"
|
|
||||||
}
|
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueInt)
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
|
@ -106,13 +126,13 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
return errors.New("invalid string value")
|
return errors.New("invalid string value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
var sqlTemplate string
|
||||||
if identifier == "visibility" {
|
if identifier == "visibility" {
|
||||||
factor = "`memo`.`visibility`"
|
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`"
|
||||||
} else if identifier == "content" {
|
} else if identifier == "content" {
|
||||||
factor = "`memo`.`content`"
|
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`"
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueStr)
|
ctx.Args = append(ctx.Args, valueStr)
|
||||||
|
|
@ -125,11 +145,8 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
return errors.New("invalid int value")
|
return errors.New("invalid int value")
|
||||||
}
|
}
|
||||||
|
|
||||||
var factor string
|
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`"
|
||||||
if identifier == "creator_id" {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||||
factor = "`memo`.`creator_id`"
|
|
||||||
}
|
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, valueInt)
|
ctx.Args = append(ctx.Args, valueInt)
|
||||||
|
|
@ -141,12 +158,22 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid boolean value for has_task_list")
|
return errors.New("invalid boolean value for has_task_list")
|
||||||
}
|
}
|
||||||
// In SQLite JSON boolean values are 1 for true and 0 for false
|
// Use template for boolean comparison
|
||||||
compareValue := 0
|
var sqlTemplate string
|
||||||
if valueBool {
|
if operator == "=" {
|
||||||
compareValue = 1
|
if valueBool {
|
||||||
|
sqlTemplate = filter.GetSQL("boolean_true", dbType)
|
||||||
|
} else {
|
||||||
|
sqlTemplate = filter.GetSQL("boolean_false", dbType)
|
||||||
|
}
|
||||||
|
} else { // operator == "!="
|
||||||
|
if valueBool {
|
||||||
|
sqlTemplate = filter.GetSQL("boolean_not_true", dbType)
|
||||||
|
} else {
|
||||||
|
sqlTemplate = filter.GetSQL("boolean_not_false", dbType)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s %d", operator, compareValue)); err != nil {
|
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -154,6 +181,29 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if this is "element in collection" syntax
|
||||||
|
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||||
|
// This is "element in collection" - the second argument is the collection
|
||||||
|
if !slices.Contains([]string{"tags"}, identifier) {
|
||||||
|
return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier == "tags" {
|
||||||
|
// Handle "element" in tags
|
||||||
|
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original logic for "identifier in [list]" syntax
|
||||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -171,27 +221,26 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
if identifier == "tag" {
|
if identifier == "tag" {
|
||||||
subcodition := []string{}
|
subconditions := []string{}
|
||||||
args := []any{}
|
args := []any{}
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
subcodition, args = append(subcodition, "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?"), append(args, fmt.Sprintf(`%%"%s"%%`, v))
|
subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType))
|
||||||
|
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||||
}
|
}
|
||||||
if len(subcodition) == 1 {
|
if len(subconditions) == 1 {
|
||||||
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
|
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, args...)
|
ctx.Args = append(ctx.Args, args...)
|
||||||
} else if identifier == "visibility" {
|
} else if identifier == "visibility" {
|
||||||
placeholder := []string{}
|
placeholders := filter.FormatPlaceholders(dbType, len(values), 1)
|
||||||
for range values {
|
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||||
placeholder = append(placeholder, "?")
|
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||||
}
|
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, values...)
|
ctx.Args = append(ctx.Args, values...)
|
||||||
|
|
@ -211,7 +260,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil {
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||||
|
|
@ -222,15 +271,34 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
|
||||||
return errors.Errorf("invalid identifier %s", identifier)
|
return errors.Errorf("invalid identifier %s", identifier)
|
||||||
}
|
}
|
||||||
if identifier == "pinned" {
|
if identifier == "pinned" {
|
||||||
if _, err := ctx.Buffer.WriteString("`memo`.`pinned` IS TRUE"); err != nil {
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if identifier == "has_task_list" {
|
} else if identifier == "has_task_list" {
|
||||||
// Handle has_task_list as a standalone boolean identifier
|
// Handle has_task_list as a standalone boolean identifier
|
||||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE"); err != nil {
|
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DB) getComparisonOperator(function string) string {
|
||||||
|
switch function {
|
||||||
|
case "_==_":
|
||||||
|
return "="
|
||||||
|
case "_!=_":
|
||||||
|
return "!="
|
||||||
|
case "_<_":
|
||||||
|
return "<"
|
||||||
|
case "_>_":
|
||||||
|
return ">"
|
||||||
|
case "_<=_":
|
||||||
|
return "<="
|
||||||
|
case "_>=_":
|
||||||
|
return ">="
|
||||||
|
default:
|
||||||
|
return "="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -110,14 +110,40 @@ func TestConvertExprToSQL(t *testing.T) {
|
||||||
want: "`memo`.`created_ts` > ?",
|
want: "`memo`.`created_ts` > ?",
|
||||||
args: []any{time.Now().Unix() - 60*60*24},
|
args: []any{time.Now().Unix() - 60*60*24},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) == 0`,
|
||||||
|
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||||
|
args: []any{int64(0)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) > 0`,
|
||||||
|
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
|
||||||
|
args: []any{int64(0)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `"work" in tags`,
|
||||||
|
want: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||||
|
args: []any{`%"work"%`},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `size(tags) == 2`,
|
||||||
|
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||||
|
args: []any{int64(2)},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
db := &DB{}
|
db := &DB{}
|
||||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to parse filter: %s, error: %v", tt.filter, err)
|
||||||
|
}
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
convertCtx := filter.NewConvertContext()
|
convertCtx := filter.NewConvertContext()
|
||||||
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to convert filter: %s, error: %v", tt.filter, err)
|
||||||
|
}
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
require.Equal(t, tt.args, convertCtx.Args)
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue