diff --git a/store/db/sqlite/reaction.go b/store/db/sqlite/reaction.go index 47a582716..c4edfd613 100644 --- a/store/db/sqlite/reaction.go +++ b/store/db/sqlite/reaction.go @@ -2,6 +2,8 @@ package sqlite import ( "context" + "database/sql" + "errors" "strings" "github.com/usememos/memos/store" @@ -88,15 +90,43 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st } func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) { - list, err := d.ListReactions(ctx, find) - if err != nil { - return nil, err + where, args := []string{"1 = 1"}, []any{} + + if find.ID != nil { + where, args = append(where, "id = ?"), append(args, *find.ID) } - if len(list) == 0 { - return nil, nil + if find.CreatorID != nil { + where, args = append(where, "creator_id = ?"), append(args, *find.CreatorID) + } + if find.ContentID != nil { + where, args = append(where, "content_id = ?"), append(args, *find.ContentID) + } + + reaction := &store.Reaction{} + if err := d.db.QueryRowContext(ctx, ` + SELECT + id, + created_ts, + creator_id, + content_id, + reaction_type + FROM reaction + WHERE `+strings.Join(where, " AND ")+` + LIMIT 1`, + args..., + ).Scan( + &reaction.ID, + &reaction.CreatedTs, + &reaction.CreatorID, + &reaction.ContentID, + &reaction.ReactionType, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err } - reaction := list[0] return reaction, nil }