mirror of https://github.com/usememos/memos.git
feat: filter/method for reactions by content_id (#4969)
This commit is contained in:
parent
931ddb7c1c
commit
f4bdfa28a0
|
|
@ -205,7 +205,7 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
if !slices.Contains([]string{"tag", "visibility", "content_id"}, identifier) {
|
||||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -222,6 +222,8 @@ func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exp
|
||||||
return c.handleTagInList(ctx, values)
|
return c.handleTagInList(ctx, values)
|
||||||
} else if identifier == "visibility" {
|
} else if identifier == "visibility" {
|
||||||
return c.handleVisibilityInList(ctx, values)
|
return c.handleVisibilityInList(ctx, values)
|
||||||
|
} else if identifier == "content_id" {
|
||||||
|
return c.handleContentIDInList(ctx, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -292,7 +294,7 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values
|
||||||
c.paramIndex++
|
c.paramIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
tablePrefix := c.dialect.GetTablePrefix()
|
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -307,6 +309,28 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values []any) error {
|
||||||
|
placeholders := []string{}
|
||||||
|
for range values {
|
||||||
|
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||||
|
c.paramIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
tablePrefix := c.dialect.GetTablePrefix("reaction")
|
||||||
|
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Args = append(ctx.Args, values...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||||
if len(callExpr.Args) != 1 {
|
if len(callExpr.Args) != 1 {
|
||||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||||
|
|
@ -326,7 +350,7 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tablePrefix := c.dialect.GetTablePrefix()
|
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||||
|
|
||||||
// PostgreSQL uses ILIKE and no backticks
|
// PostgreSQL uses ILIKE and no backticks
|
||||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||||
|
|
@ -353,7 +377,7 @@ func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *ex
|
||||||
}
|
}
|
||||||
|
|
||||||
if identifier == "pinned" {
|
if identifier == "pinned" {
|
||||||
tablePrefix := c.dialect.GetTablePrefix()
|
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -411,7 +435,7 @@ func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field,
|
||||||
return errors.New("invalid string value")
|
return errors.New("invalid string value")
|
||||||
}
|
}
|
||||||
|
|
||||||
tablePrefix := c.dialect.GetTablePrefix()
|
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||||
|
|
||||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||||
// PostgreSQL doesn't use backticks
|
// PostgreSQL doesn't use backticks
|
||||||
|
|
@ -447,7 +471,7 @@ func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, ope
|
||||||
return errors.New("invalid int value")
|
return errors.New("invalid int value")
|
||||||
}
|
}
|
||||||
|
|
||||||
tablePrefix := c.dialect.GetTablePrefix()
|
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||||
|
|
||||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||||
// PostgreSQL doesn't use backticks
|
// PostgreSQL doesn't use backticks
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import (
|
||||||
// SQLDialect defines database-specific SQL generation methods.
|
// SQLDialect defines database-specific SQL generation methods.
|
||||||
type SQLDialect interface {
|
type SQLDialect interface {
|
||||||
// Basic field access
|
// Basic field access
|
||||||
GetTablePrefix() string
|
GetTablePrefix(entityName string) string
|
||||||
GetParameterPlaceholder(index int) string
|
GetParameterPlaceholder(index int) string
|
||||||
|
|
||||||
// JSON operations
|
// JSON operations
|
||||||
|
|
@ -53,8 +53,8 @@ func GetDialect(dbType DatabaseType) SQLDialect {
|
||||||
// SQLiteDialect implements SQLDialect for SQLite.
|
// SQLiteDialect implements SQLDialect for SQLite.
|
||||||
type SQLiteDialect struct{}
|
type SQLiteDialect struct{}
|
||||||
|
|
||||||
func (*SQLiteDialect) GetTablePrefix() string {
|
func (*SQLiteDialect) GetTablePrefix(entityName string) string {
|
||||||
return "`memo`"
|
return fmt.Sprintf("`%s`", entityName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
||||||
|
|
@ -62,7 +62,7 @@ func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SQLiteDialect) GetJSONExtract(path string) string {
|
func (d *SQLiteDialect) GetJSONExtract(path string) string {
|
||||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
|
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
|
||||||
|
|
@ -96,7 +96,7 @@ func (d *SQLiteDialect) GetBooleanCheck(path string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
|
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
|
||||||
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field)
|
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix("memo"), field)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*SQLiteDialect) GetCurrentTimestamp() string {
|
func (*SQLiteDialect) GetCurrentTimestamp() string {
|
||||||
|
|
@ -106,8 +106,8 @@ func (*SQLiteDialect) GetCurrentTimestamp() string {
|
||||||
// MySQLDialect implements SQLDialect for MySQL.
|
// MySQLDialect implements SQLDialect for MySQL.
|
||||||
type MySQLDialect struct{}
|
type MySQLDialect struct{}
|
||||||
|
|
||||||
func (*MySQLDialect) GetTablePrefix() string {
|
func (*MySQLDialect) GetTablePrefix(entityName string) string {
|
||||||
return "`memo`"
|
return fmt.Sprintf("`%s`", entityName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
||||||
|
|
@ -115,7 +115,7 @@ func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *MySQLDialect) GetJSONExtract(path string) string {
|
func (d *MySQLDialect) GetJSONExtract(path string) string {
|
||||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
|
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
|
||||||
|
|
@ -146,7 +146,7 @@ func (d *MySQLDialect) GetBooleanCheck(path string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *MySQLDialect) GetTimestampComparison(field string) string {
|
func (d *MySQLDialect) GetTimestampComparison(field string) string {
|
||||||
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field)
|
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix("memo"), field)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*MySQLDialect) GetCurrentTimestamp() string {
|
func (*MySQLDialect) GetCurrentTimestamp() string {
|
||||||
|
|
@ -156,8 +156,8 @@ func (*MySQLDialect) GetCurrentTimestamp() string {
|
||||||
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
|
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
|
||||||
type PostgreSQLDialect struct{}
|
type PostgreSQLDialect struct{}
|
||||||
|
|
||||||
func (*PostgreSQLDialect) GetTablePrefix() string {
|
func (*PostgreSQLDialect) GetTablePrefix(entityName string) string {
|
||||||
return "memo"
|
return entityName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
||||||
|
|
@ -167,7 +167,7 @@ func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
||||||
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||||
// Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList'
|
// Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList'
|
||||||
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||||
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
|
result := fmt.Sprintf("%s.payload", d.GetTablePrefix("memo"))
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
if i == len(parts)-1 {
|
if i == len(parts)-1 {
|
||||||
result += fmt.Sprintf("->>'%s'", part)
|
result += fmt.Sprintf("->>'%s'", part)
|
||||||
|
|
@ -180,17 +180,17 @@ func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||||
|
|
||||||
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
|
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
|
||||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath)
|
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix("memo"), jsonPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
|
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
|
||||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix(), jsonPath)
|
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
|
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
|
||||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix(), jsonPath)
|
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
||||||
|
|
@ -207,7 +207,7 @@ func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
|
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
|
||||||
return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix(), field)
|
return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix("memo"), field)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
|
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,11 @@ var MemoFilterCELAttributes = []cel.EnvOption{
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReactionFilterCELAttributes are the CEL attributes for reaction.
|
||||||
|
var ReactionFilterCELAttributes = []cel.EnvOption{
|
||||||
|
cel.Variable("content_id", cel.StringType),
|
||||||
|
}
|
||||||
|
|
||||||
// UserFilterCELAttributes are the CEL attributes for user.
|
// UserFilterCELAttributes are the CEL attributes for user.
|
||||||
var UserFilterCELAttributes = []cel.EnvOption{
|
var UserFilterCELAttributes = []cel.EnvOption{
|
||||||
cel.Variable("username", cel.StringType),
|
cel.Variable("username", cel.StringType),
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to convert memo")
|
return nil, errors.Wrap(err, "failed to convert memo")
|
||||||
}
|
}
|
||||||
|
|
@ -178,8 +178,28 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
|
||||||
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
|
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reactionMap := make(map[string][]*store.Reaction)
|
||||||
|
|
||||||
|
memoNames := make([]string, 0, len(memos))
|
||||||
|
for _, m := range memos {
|
||||||
|
memoNames = append(memoNames, fmt.Sprintf("'%s/%s'", MemoNamePrefix, m.UID))
|
||||||
|
}
|
||||||
|
|
||||||
|
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||||
|
Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNames, ", "))},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, reaction := range reactions {
|
||||||
|
reactionMap[reaction.ContentID] = append(reactionMap[reaction.ContentID], reaction)
|
||||||
|
}
|
||||||
|
|
||||||
for _, memo := range memos {
|
for _, memo := range memos {
|
||||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
name := fmt.Sprintf("'%s/%s'", MemoNamePrefix, memo.UID)
|
||||||
|
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactionMap[name])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to convert memo")
|
return nil, errors.Wrap(err, "failed to convert memo")
|
||||||
}
|
}
|
||||||
|
|
@ -220,7 +240,7 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to convert memo")
|
return nil, errors.Wrap(err, "failed to convert memo")
|
||||||
}
|
}
|
||||||
|
|
@ -339,7 +359,7 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to get memo")
|
return nil, errors.Wrap(err, "failed to get memo")
|
||||||
}
|
}
|
||||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to convert memo")
|
return nil, errors.Wrap(err, "failed to convert memo")
|
||||||
}
|
}
|
||||||
|
|
@ -375,7 +395,7 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
||||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil {
|
if memoMessage, err := s.convertMemoFromStore(ctx, memo, nil); err == nil {
|
||||||
// Try to dispatch webhook when memo is deleted.
|
// Try to dispatch webhook when memo is deleted.
|
||||||
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
|
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
|
||||||
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
|
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
|
||||||
|
|
@ -530,7 +550,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
|
||||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||||
}
|
}
|
||||||
if memo != nil {
|
if memo != nil {
|
||||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to convert memo")
|
return nil, errors.Wrap(err, "failed to convert memo")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/usememos/gomark/parser"
|
"github.com/usememos/gomark/parser"
|
||||||
|
|
@ -16,7 +18,7 @@ import (
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) {
|
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction) (*v1pb.Memo, error) {
|
||||||
displayTs := memo.CreatedTs
|
displayTs := memo.CreatedTs
|
||||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -61,11 +63,24 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
|
||||||
}
|
}
|
||||||
memoMessage.Attachments = listMemoAttachmentsResponse.Attachments
|
memoMessage.Attachments = listMemoAttachmentsResponse.Attachments
|
||||||
|
|
||||||
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
|
if len(reactions) > 0 {
|
||||||
if err != nil {
|
for _, reaction := range reactions {
|
||||||
return nil, errors.Wrap(err, "failed to list memo reactions")
|
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||||
|
}
|
||||||
|
memoMessage.Reactions = append(memoMessage.Reactions, reactionMessage)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// done for backwards compatibility
|
||||||
|
// can remove once convertMemoFromStore is only responsible for mapping
|
||||||
|
// and all related DB entities are passed in as arguments purely for converting to request entities
|
||||||
|
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to list memo reactions")
|
||||||
|
}
|
||||||
|
memoMessage.Reactions = listMemoReactionsResponse.Reactions
|
||||||
}
|
}
|
||||||
memoMessage.Reactions = listMemoReactionsResponse.Reactions
|
|
||||||
|
|
||||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,12 @@ package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -36,6 +38,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
|
||||||
|
|
||||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||||
where, args := []string{"1 = 1"}, []interface{}{}
|
where, args := []string{"1 = 1"}, []interface{}{}
|
||||||
|
|
||||||
|
for _, filterStr := range find.Filters {
|
||||||
|
// Parse filter string and return the parsed expression.
|
||||||
|
// The filter string should be a CEL expression.
|
||||||
|
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
|
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||||
|
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
condition := convertCtx.Buffer.String()
|
||||||
|
if condition != "" {
|
||||||
|
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||||
|
args = append(args, convertCtx.Args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if find.ID != nil {
|
if find.ID != nil {
|
||||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filter string
|
||||||
|
want string
|
||||||
|
args []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||||
|
want: "`reaction`.`content_id` IN (?)",
|
||||||
|
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||||
|
want: "`reaction`.`content_id` IN (?,?)",
|
||||||
|
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||||
|
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,8 +2,10 @@ package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -24,6 +26,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
|
||||||
|
|
||||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||||
where, args := []string{"1 = 1"}, []interface{}{}
|
where, args := []string{"1 = 1"}, []interface{}{}
|
||||||
|
|
||||||
|
for _, filterStr := range find.Filters {
|
||||||
|
// Parse filter string and return the parsed expression.
|
||||||
|
// The filter string should be a CEL expression.
|
||||||
|
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
|
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
|
||||||
|
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
condition := convertCtx.Buffer.String()
|
||||||
|
if condition != "" {
|
||||||
|
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||||
|
args = append(args, convertCtx.Args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if find.ID != nil {
|
if find.ID != nil {
|
||||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filter string
|
||||||
|
want string
|
||||||
|
args []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||||
|
want: "reaction.content_id IN ($1)",
|
||||||
|
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||||
|
want: "reaction.content_id IN ($1,$2)",
|
||||||
|
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
|
||||||
|
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,8 +2,10 @@ package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -25,6 +27,27 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
|
||||||
|
|
||||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||||
where, args := []string{"1 = 1"}, []interface{}{}
|
where, args := []string{"1 = 1"}, []interface{}{}
|
||||||
|
|
||||||
|
for _, filterStr := range find.Filters {
|
||||||
|
// Parse filter string and return the parsed expression.
|
||||||
|
// The filter string should be a CEL expression.
|
||||||
|
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
|
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||||
|
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
condition := convertCtx.Buffer.String()
|
||||||
|
if condition != "" {
|
||||||
|
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||||
|
args = append(args, convertCtx.Args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if find.ID != nil {
|
if find.ID != nil {
|
||||||
where, args = append(where, "id = ?"), append(args, *find.ID)
|
where, args = append(where, "id = ?"), append(args, *find.ID)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filter string
|
||||||
|
want string
|
||||||
|
args []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||||
|
want: "`reaction`.`content_id` IN (?)",
|
||||||
|
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||||
|
want: "`reaction`.`content_id` IN (?,?)",
|
||||||
|
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||||
|
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -17,6 +17,7 @@ type FindReaction struct {
|
||||||
ID *int32
|
ID *int32
|
||||||
CreatorID *int32
|
CreatorID *int32
|
||||||
ContentID *string
|
ContentID *string
|
||||||
|
Filters []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeleteReaction struct {
|
type DeleteReaction struct {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue