mirror of https://github.com/usememos/memos.git
feat: filter/method for reactions by content_id (#4969)
This commit is contained in:
parent
931ddb7c1c
commit
f4bdfa28a0
|
|
@ -205,7 +205,7 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp
|
|||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
if !slices.Contains([]string{"tag", "visibility", "content_id"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
|
|
@ -222,6 +222,8 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp
|
|||
return c.handleTagInList(ctx, values)
|
||||
} else if identifier == "visibility" {
|
||||
return c.handleVisibilityInList(ctx, values)
|
||||
} else if identifier == "content_id" {
|
||||
return c.handleContentIDInList(ctx, values)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -292,7 +294,7 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values
|
|||
c.paramIndex++
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
|
|
@ -307,6 +309,28 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values []any) error {
|
||||
placeholders := []string{}
|
||||
for range values {
|
||||
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("reaction")
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
|
|
@ -326,7 +350,7 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp
|
|||
return err
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
|
||||
// PostgreSQL uses ILIKE and no backticks
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
|
|
@ -353,7 +377,7 @@ func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *ex
|
|||
}
|
||||
|
||||
if identifier == "pinned" {
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
|
|
@ -411,7 +435,7 @@ func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field,
|
|||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
// PostgreSQL doesn't use backticks
|
||||
|
|
@ -447,7 +471,7 @@ func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, ope
|
|||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
// PostgreSQL doesn't use backticks
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import (
|
|||
// SQLDialect defines database-specific SQL generation methods.
|
||||
type SQLDialect interface {
|
||||
// Basic field access
|
||||
GetTablePrefix() string
|
||||
GetTablePrefix(entityName string) string
|
||||
GetParameterPlaceholder(index int) string
|
||||
|
||||
// JSON operations
|
||||
|
|
@ -53,8 +53,8 @@ func GetDialect(dbType DatabaseType) SQLDialect {
|
|||
// SQLiteDialect implements SQLDialect for SQLite.
|
||||
type SQLiteDialect struct{}
|
||||
|
||||
func (*SQLiteDialect) GetTablePrefix() string {
|
||||
return "`memo`"
|
||||
func (*SQLiteDialect) GetTablePrefix(entityName string) string {
|
||||
return fmt.Sprintf("`%s`", entityName)
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
||||
|
|
@ -62,7 +62,7 @@ func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
|||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
|
||||
|
|
@ -96,7 +96,7 @@ func (d *SQLiteDialect) GetBooleanCheck(path string) string {
|
|||
}
|
||||
|
||||
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field)
|
||||
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix("memo"), field)
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetCurrentTimestamp() string {
|
||||
|
|
@ -106,8 +106,8 @@ func (*SQLiteDialect) GetCurrentTimestamp() string {
|
|||
// MySQLDialect implements SQLDialect for MySQL.
|
||||
type MySQLDialect struct{}
|
||||
|
||||
func (*MySQLDialect) GetTablePrefix() string {
|
||||
return "`memo`"
|
||||
func (*MySQLDialect) GetTablePrefix(entityName string) string {
|
||||
return fmt.Sprintf("`%s`", entityName)
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
||||
|
|
@ -115,7 +115,7 @@ func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
|||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
|
||||
|
|
@ -146,7 +146,7 @@ func (d *MySQLDialect) GetBooleanCheck(path string) string {
|
|||
}
|
||||
|
||||
func (d *MySQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field)
|
||||
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix("memo"), field)
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetCurrentTimestamp() string {
|
||||
|
|
@ -156,8 +156,8 @@ func (*MySQLDialect) GetCurrentTimestamp() string {
|
|||
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
|
||||
type PostgreSQLDialect struct{}
|
||||
|
||||
func (*PostgreSQLDialect) GetTablePrefix() string {
|
||||
return "memo"
|
||||
func (*PostgreSQLDialect) GetTablePrefix(entityName string) string {
|
||||
return entityName
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
||||
|
|
@ -167,7 +167,7 @@ func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
|||
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||
// Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList'
|
||||
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
|
||||
result := fmt.Sprintf("%s.payload", d.GetTablePrefix("memo"))
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
result += fmt.Sprintf("->>'%s'", part)
|
||||
|
|
@ -180,17 +180,17 @@ func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
|||
|
||||
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath)
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix("memo"), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix(), jsonPath)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix(), jsonPath)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
||||
|
|
@ -207,7 +207,7 @@ func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
|
|||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix(), field)
|
||||
return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix("memo"), field)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ var MemoFilterCELAttributes = []cel.EnvOption{
|
|||
),
|
||||
}
|
||||
|
||||
// ReactionFilterCELAttributes are the CEL attributes for reaction.
|
||||
var ReactionFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content_id", cel.StringType),
|
||||
}
|
||||
|
||||
// UserFilterCELAttributes are the CEL attributes for user.
|
||||
var UserFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("username", cel.StringType),
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
|||
}
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -178,8 +178,28 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
|
|||
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
reactionMap := make(map[string][]*store.Reaction)
|
||||
|
||||
memoNames := make([]string, 0, len(memos))
|
||||
for _, m := range memos {
|
||||
memoNames = append(memoNames, fmt.Sprintf("'%s/%s'", MemoNamePrefix, m.UID))
|
||||
}
|
||||
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNames, ", "))},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
|
||||
for _, reaction := range reactions {
|
||||
reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction)
|
||||
}
|
||||
|
||||
for _, memo := range memos {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
name := fmt.Sprintf("'%s/%s'", MemoNamePrefix, memo.UID)
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactionMap[name])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -220,7 +240,7 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest
|
|||
}
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -339,7 +359,7 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -375,7 +395,7 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
|||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil {
|
||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo, nil); err == nil {
|
||||
// Try to dispatch webhook when memo is deleted.
|
||||
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
|
||||
|
|
@ -530,7 +550,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
|
|||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if memo != nil {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/gomark/parser"
|
||||
|
|
@ -16,7 +18,7 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) {
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction) (*v1pb.Memo, error) {
|
||||
displayTs := memo.CreatedTs
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -61,11 +63,24 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
|
|||
}
|
||||
memoMessage.Attachments = listMemoAttachmentsResponse.Attachments
|
||||
|
||||
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo reactions")
|
||||
if len(reactions) > 0 {
|
||||
for _, reaction := range reactions {
|
||||
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||
}
|
||||
memoMessage.Reactions = append(memoMessage.Reactions, reactionMessage)
|
||||
}
|
||||
} else {
|
||||
// done for backwards compatibility
|
||||
// can remove once convertMemoFromStore is only responsible for mapping
|
||||
// and all related DB entities are passed in as arguments purely for converting to request entities
|
||||
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo reactions")
|
||||
}
|
||||
memoMessage.Reactions = listMemoReactionsResponse.Reactions
|
||||
}
|
||||
memoMessage.Reactions = listMemoReactionsResponse.Reactions
|
||||
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ package mysql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
|
|
@ -36,6 +38,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
|
|||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
|
||||
for _, filterStr := range find.Filters {
|
||||
// Parse filter string and return the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
if condition != "" {
|
||||
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||
args = append(args, convertCtx.Args...)
|
||||
}
|
||||
}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "`reaction`.`content_id` IN (?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "`reaction`.`content_id` IN (?,?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,8 +2,10 @@ package postgres
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
|
|
@ -24,6 +26,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
|
|||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
|
||||
for _, filterStr := range find.Filters {
|
||||
// Parse filter string and return the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
if condition != "" {
|
||||
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||
args = append(args, convertCtx.Args...)
|
||||
}
|
||||
}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "reaction.content_id IN ($1)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "reaction.content_id IN ($1,$2)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,8 +2,10 @@ package sqlite
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
|
|
@ -25,6 +27,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
|
|||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
|
||||
for _, filterStr := range find.Filters {
|
||||
// Parse filter string and return the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
if condition != "" {
|
||||
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||
args = append(args, convertCtx.Args...)
|
||||
}
|
||||
}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *find.ID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "`reaction`.`content_id` IN (?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "`reaction`.`content_id` IN (?,?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ type FindReaction struct {
|
|||
ID *int32
|
||||
CreatorID *int32
|
||||
ContentID *string
|
||||
Filters []string
|
||||
}
|
||||
|
||||
type DeleteReaction struct {
|
||||
|
|
|
|||
Loading…
Reference in New Issue