mirror of https://github.com/usememos/memos.git
fix(api): improve SSE hub design and fix double-broadcast on comments
- Fix duplicate SSE event on comment creation: CreateMemoComment now suppresses the redundant memo.created broadcast from the inner CreateMemo call, emitting only memo.comment.created - Extract reaction event-building IIFEs into buildMemoReactionSSEEvent helper, removing duplicated inline DB-fetch logic - Promote resolveSSEAudienceCreatorID from method to free function (resolveSSECreatorID) since it never used the receiver - Add userID to SSE connect/disconnect log lines for traceability - Change canReceive default from permissive (return true) to deny-with-warning for unknown visibility types - Add comprehensive tests covering all new helpers, visibility edge cases, slow-client drop behavior, and the double-broadcast fix Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d720efb6e6
commit
c53677fcba
|
|
@ -19,6 +19,19 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// suppressSSEKey is a context key used to suppress the SSE broadcast from
|
||||
// CreateMemo when it is called internally (e.g., from CreateMemoComment).
|
||||
type suppressSSEKey struct{}
|
||||
|
||||
func withSuppressSSE(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, suppressSSEKey{}, true)
|
||||
}
|
||||
|
||||
func isSSESuppressed(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(suppressSSEKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -136,11 +149,15 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
|||
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoCreated,
|
||||
Name: memoMessage.Name,
|
||||
})
|
||||
// Broadcast live refresh event (skipped when called from CreateMemoComment).
|
||||
if !isSSESuppressed(ctx) {
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoCreated,
|
||||
Name: memoMessage.Name,
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, nil),
|
||||
})
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
|
@ -501,6 +518,10 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
var parentMemo *store.Memo
|
||||
if memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
// Try to dispatch webhook when memo is updated.
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
|
|
@ -508,8 +529,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: memoMessage.Name,
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: memoMessage.Name,
|
||||
Parent: memoMessage.GetParent(),
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, parentMemo),
|
||||
})
|
||||
|
||||
return memoMessage, nil
|
||||
|
|
@ -583,8 +607,10 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
|||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoDeleted,
|
||||
Name: request.Name,
|
||||
Type: SSEEventMemoDeleted,
|
||||
Name: request.Name,
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, nil),
|
||||
})
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
|
|
@ -615,8 +641,9 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
|
|||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Create the memo comment first.
|
||||
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{
|
||||
// Create the memo comment first; suppress the generic memo.created SSE event
|
||||
// since CreateMemoComment broadcasts memo.comment.created for the parent instead.
|
||||
memoComment, err := s.CreateMemo(withSuppressSSE(ctx), &v1pb.CreateMemoRequest{
|
||||
Memo: request.Comment,
|
||||
MemoId: request.CommentId,
|
||||
})
|
||||
|
|
@ -674,8 +701,10 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
|
|||
|
||||
// Broadcast live refresh event for the parent memo so subscribers see the new comment.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoCommentCreated,
|
||||
Name: request.Name,
|
||||
Type: SSEEventMemoCommentCreated,
|
||||
Name: request.Name,
|
||||
Visibility: relatedMemo.Visibility,
|
||||
CreatorID: relatedMemo.CreatorID,
|
||||
})
|
||||
|
||||
return memoComment, nil
|
||||
|
|
|
|||
|
|
@ -104,10 +104,11 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
|
|||
}
|
||||
|
||||
// Broadcast live refresh event (reaction belongs to a memo).
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventReactionUpserted,
|
||||
Name: request.Reaction.ContentId,
|
||||
})
|
||||
var parentMemo *store.Memo
|
||||
if memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionUpserted, request.Reaction.ContentId, memo, parentMemo))
|
||||
|
||||
return reactionMessage, nil
|
||||
}
|
||||
|
|
@ -148,11 +149,21 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del
|
|||
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
|
||||
}
|
||||
|
||||
memoUID, err := ExtractMemoUIDFromName(reaction.ContentID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
// Broadcast live refresh event (reaction belongs to a memo).
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventReactionDeleted,
|
||||
Name: reaction.ContentID,
|
||||
})
|
||||
var parentMemo *store.Memo
|
||||
if memo != nil && memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionDeleted, reaction.ContentID, memo, parentMemo))
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
package v1
|
||||
|
||||
import "github.com/usememos/memos/store"
|
||||
|
||||
func buildMemoName(uid string) string {
|
||||
return MemoNamePrefix + uid
|
||||
}
|
||||
|
||||
// resolveSSECreatorID returns the CreatorID used for SSE delivery filtering.
|
||||
// For a comment memo, it returns the parent memo's CreatorID so that private
|
||||
// parent-memo events are scoped to the parent owner.
|
||||
func resolveSSECreatorID(memo *store.Memo, parentMemo *store.Memo) int32 {
|
||||
if memo == nil {
|
||||
return 0
|
||||
}
|
||||
if parentMemo != nil {
|
||||
return parentMemo.CreatorID
|
||||
}
|
||||
return memo.CreatorID
|
||||
}
|
||||
|
||||
// buildMemoReactionSSEEvent constructs an SSEEvent for a reaction on a memo.
|
||||
// Pass parentMemo when the memo is a comment (memo.ParentUID != nil).
|
||||
func buildMemoReactionSSEEvent(eventType SSEEventType, contentID string, memo *store.Memo, parentMemo *store.Memo) *SSEEvent {
|
||||
parent := ""
|
||||
if memo != nil && memo.ParentUID != nil {
|
||||
parent = buildMemoName(*memo.ParentUID)
|
||||
}
|
||||
visibility := store.Visibility("")
|
||||
if memo != nil {
|
||||
visibility = memo.Visibility
|
||||
}
|
||||
return &SSEEvent{
|
||||
Type: eventType,
|
||||
Name: contentID,
|
||||
Parent: parent,
|
||||
Visibility: visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, parentMemo),
|
||||
}
|
||||
}
|
||||
|
|
@ -17,10 +17,14 @@ const (
|
|||
sseHeartbeatInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
// RegisterSSERoutes registers the SSE endpoint on the given Echo instance.
|
||||
func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) {
|
||||
type sseRouteRegistrar interface {
|
||||
GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) echo.RouteInfo
|
||||
}
|
||||
|
||||
// RegisterSSERoutes registers the SSE endpoint on the given Echo router.
|
||||
func RegisterSSERoutes(router sseRouteRegistrar, hub *SSEHub, storeInstance *store.Store, secret string) {
|
||||
authenticator := auth.NewAuthenticator(storeInstance, secret)
|
||||
echoServer.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
router.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
return handleSSE(c, hub, authenticator)
|
||||
})
|
||||
}
|
||||
|
|
@ -34,6 +38,10 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
if result == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
userID, role := getSSEClientIdentity(result)
|
||||
if userID == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
|
||||
// Set SSE headers.
|
||||
w := c.Response()
|
||||
|
|
@ -49,7 +57,7 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
}
|
||||
|
||||
// Subscribe to the hub.
|
||||
client := hub.Subscribe()
|
||||
client := hub.Subscribe(userID, role)
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
// Create a ticker for heartbeat pings.
|
||||
|
|
@ -58,13 +66,13 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
|
||||
ctx := c.Request().Context()
|
||||
|
||||
slog.Debug("SSE client connected")
|
||||
slog.Debug("SSE client connected", "userID", userID)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Client disconnected.
|
||||
slog.Debug("SSE client disconnected")
|
||||
slog.Debug("SSE client disconnected", "userID", userID)
|
||||
return nil
|
||||
|
||||
case data, ok := <-client.events:
|
||||
|
|
@ -91,3 +99,16 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getSSEClientIdentity(result *auth.AuthResult) (int32, store.Role) {
|
||||
if result == nil {
|
||||
return 0, store.RoleUser
|
||||
}
|
||||
if result.Claims != nil {
|
||||
return result.Claims.UserID, store.Role(result.Claims.Role)
|
||||
}
|
||||
if result.User != nil {
|
||||
return result.User.ID, result.User.Role
|
||||
}
|
||||
return 0, store.RoleUser
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"encoding/json"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// SSEEventType represents the type of change event.
|
||||
|
|
@ -24,6 +26,11 @@ type SSEEvent struct {
|
|||
// Name is the affected resource name (e.g., "memos/xxxx").
|
||||
// For reaction events, this is the memo resource name that the reaction belongs to.
|
||||
Name string `json:"name"`
|
||||
// Parent is the parent memo resource name when the affected resource is a comment.
|
||||
Parent string `json:"parent,omitempty"`
|
||||
// Visibility and CreatorID are used only for server-side delivery filtering.
|
||||
Visibility store.Visibility `json:"-"`
|
||||
CreatorID int32 `json:"-"`
|
||||
}
|
||||
|
||||
// JSON returns the JSON representation of the event.
|
||||
|
|
@ -40,6 +47,8 @@ func (e *SSEEvent) JSON() []byte {
|
|||
// SSEClient represents a single SSE connection.
|
||||
type SSEClient struct {
|
||||
events chan []byte
|
||||
userID int32
|
||||
role store.Role
|
||||
}
|
||||
|
||||
// SSEHub manages SSE client connections and broadcasts events.
|
||||
|
|
@ -58,10 +67,12 @@ func NewSSEHub() *SSEHub {
|
|||
|
||||
// Subscribe registers a new client and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (h *SSEHub) Subscribe() *SSEClient {
|
||||
func (h *SSEHub) Subscribe(userID int32, role store.Role) *SSEClient {
|
||||
c := &SSEClient{
|
||||
// Buffer a few events so a slow client doesn't block broadcasting.
|
||||
events: make(chan []byte, 32),
|
||||
userID: userID,
|
||||
role: role,
|
||||
}
|
||||
h.mu.Lock()
|
||||
h.clients[c] = struct{}{}
|
||||
|
|
@ -90,6 +101,9 @@ func (h *SSEHub) Broadcast(event *SSEEvent) {
|
|||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
for c := range h.clients {
|
||||
if !c.canReceive(event) {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case c.events <- data:
|
||||
default:
|
||||
|
|
@ -97,3 +111,15 @@ func (h *SSEHub) Broadcast(event *SSEEvent) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SSEClient) canReceive(event *SSEEvent) bool {
|
||||
switch event.Visibility {
|
||||
case store.Private:
|
||||
return c.userID == event.CreatorID || c.role == store.RoleAdmin
|
||||
case store.Public, store.Protected, "":
|
||||
return true
|
||||
default:
|
||||
slog.Warn("SSE canReceive: unknown visibility type, denying event", "visibility", event.Visibility)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,12 +6,36 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// helpers shared by multiple tests in this file.
|
||||
|
||||
func mustReceive(t *testing.T, ch <-chan []byte, within time.Duration) []byte {
|
||||
t.Helper()
|
||||
select {
|
||||
case data := <-ch:
|
||||
return data
|
||||
case <-time.After(within):
|
||||
t.Fatal("timed out waiting for SSE event")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func mustNotReceive(t *testing.T, ch <-chan []byte, within time.Duration) {
|
||||
t.Helper()
|
||||
select {
|
||||
case data := <-ch:
|
||||
t.Fatalf("unexpected SSE event received: %s", data)
|
||||
case <-time.After(within):
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
|
||||
client := hub.Subscribe()
|
||||
client := hub.Subscribe(1, store.RoleUser)
|
||||
require.NotNil(t, client)
|
||||
require.NotNil(t, client.events)
|
||||
|
||||
|
|
@ -25,7 +49,7 @@ func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
|
|||
|
||||
func TestSSEHub_Broadcast(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
client := hub.Subscribe()
|
||||
client := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"}
|
||||
|
|
@ -42,9 +66,9 @@ func TestSSEHub_Broadcast(t *testing.T) {
|
|||
|
||||
func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
c1 := hub.Subscribe()
|
||||
c1 := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(c1)
|
||||
c2 := hub.Subscribe()
|
||||
c2 := hub.Subscribe(2, store.RoleUser)
|
||||
defer hub.Unsubscribe(c2)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoDeleted, Name: "memos/456"}
|
||||
|
|
@ -62,9 +86,144 @@ func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSSEEvent_JSON(t *testing.T) {
|
||||
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789"}
|
||||
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789", Parent: "memos/123"}
|
||||
data := e.JSON()
|
||||
require.NotEmpty(t, data)
|
||||
assert.Contains(t, string(data), `"type":"memo.updated"`)
|
||||
assert.Contains(t, string(data), `"name":"memos/789"`)
|
||||
assert.Contains(t, string(data), `"parent":"memos/123"`)
|
||||
}
|
||||
|
||||
func TestSSEHub_PrivateEventsAreScoped(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
owner := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(owner)
|
||||
other := hub.Subscribe(2, store.RoleUser)
|
||||
defer hub.Unsubscribe(other)
|
||||
admin := hub.Subscribe(3, store.RoleAdmin)
|
||||
defer hub.Unsubscribe(admin)
|
||||
|
||||
hub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: "memos/private",
|
||||
Visibility: store.Private,
|
||||
CreatorID: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-owner.events:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("owner should receive private event")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-admin.events:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("admin should receive private event")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-other.events:
|
||||
t.Fatal("non-owner should not receive private event")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEClient_CanReceive_UnknownVisibility(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
client := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
// An event with an unrecognised visibility value should be denied (safe default).
|
||||
hub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: "memos/unknown-vis",
|
||||
Visibility: store.Visibility("CUSTOM"),
|
||||
})
|
||||
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSSEHub_SlowClientEventsDropped(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
// Subscribe but never read, so the channel fills up.
|
||||
slow := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(slow)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/x"}
|
||||
// Send more events than the buffer capacity (32).
|
||||
for range 40 {
|
||||
hub.Broadcast(event) // must not block
|
||||
}
|
||||
|
||||
// At most 32 events should have been queued; the rest were silently dropped.
|
||||
assert.LessOrEqual(t, len(slow.events), 32)
|
||||
}
|
||||
|
||||
func TestResolveSSECreatorID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
memo *store.Memo
|
||||
parentMemo *store.Memo
|
||||
want int32
|
||||
}{
|
||||
{
|
||||
name: "nil memo returns 0",
|
||||
memo: nil, parentMemo: nil,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "memo without parent returns memo CreatorID",
|
||||
memo: &store.Memo{CreatorID: 5},
|
||||
parentMemo: nil,
|
||||
want: 5,
|
||||
},
|
||||
{
|
||||
name: "memo with parent returns parent CreatorID",
|
||||
memo: &store.Memo{CreatorID: 5},
|
||||
parentMemo: &store.Memo{CreatorID: 9},
|
||||
want: 9,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, resolveSSECreatorID(tc.memo, tc.parentMemo))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMemoReactionSSEEvent(t *testing.T) {
|
||||
parentUID := "parent-uid"
|
||||
|
||||
t.Run("top-level memo reaction", func(t *testing.T) {
|
||||
memo := &store.Memo{CreatorID: 10, Visibility: store.Public}
|
||||
event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", memo, nil)
|
||||
assert.Equal(t, SSEEventReactionUpserted, event.Type)
|
||||
assert.Equal(t, "memos/abc", event.Name)
|
||||
assert.Equal(t, "", event.Parent)
|
||||
assert.Equal(t, store.Public, event.Visibility)
|
||||
assert.Equal(t, int32(10), event.CreatorID)
|
||||
})
|
||||
|
||||
t.Run("reaction on comment is scoped to parent owner", func(t *testing.T) {
|
||||
memo := &store.Memo{
|
||||
CreatorID: 10,
|
||||
Visibility: store.Private,
|
||||
ParentUID: &parentUID,
|
||||
}
|
||||
parentMemo := &store.Memo{CreatorID: 7}
|
||||
event := buildMemoReactionSSEEvent(SSEEventReactionDeleted, "memos/abc", memo, parentMemo)
|
||||
assert.Equal(t, SSEEventReactionDeleted, event.Type)
|
||||
assert.Equal(t, MemoNamePrefix+parentUID, event.Parent)
|
||||
assert.Equal(t, store.Private, event.Visibility)
|
||||
assert.Equal(t, int32(7), event.CreatorID)
|
||||
})
|
||||
|
||||
t.Run("nil memo produces a safe zero-value event", func(t *testing.T) {
|
||||
event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", nil, nil)
|
||||
assert.Equal(t, "memos/abc", event.Name)
|
||||
assert.Equal(t, "", event.Parent)
|
||||
assert.Equal(t, store.Visibility(""), event.Visibility)
|
||||
assert.Equal(t, int32(0), event.CreatorID)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,203 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
teststore "github.com/usememos/memos/store/test"
|
||||
)
|
||||
|
||||
// newIntegrationService builds a minimal APIV1Service backed by an in-memory
|
||||
// SQLite database. The store is closed automatically via t.Cleanup.
|
||||
func newIntegrationService(t *testing.T) *APIV1Service {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
st := teststore.NewTestingStore(ctx, t)
|
||||
t.Cleanup(func() { st.Close() })
|
||||
p := &profile.Profile{Demo: true, Data: t.TempDir(), Driver: "sqlite", DSN: ":memory:"}
|
||||
return NewAPIV1Service("test-secret", p, st)
|
||||
}
|
||||
|
||||
// userCtx returns a context that authenticates as the given user.
|
||||
func userCtx(ctx context.Context, userID int32) context.Context {
|
||||
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
||||
}
|
||||
|
||||
// drainEvents reads all events currently buffered in the channel and returns
|
||||
// them as a string slice. It stops as soon as the channel is empty (non-blocking).
|
||||
func drainEvents(ch <-chan []byte) []string {
|
||||
var out []string
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
out = append(out, string(data))
|
||||
default:
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectEventsFor reads events from ch for the given duration and returns them.
|
||||
func collectEventsFor(ch <-chan []byte, d time.Duration) []string {
|
||||
var out []string
|
||||
deadline := time.After(d)
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
out = append(out, string(data))
|
||||
case <-deadline:
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- context suppression ----
|
||||
|
||||
func TestSuppressSSEContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("default context is not suppressed", func(t *testing.T) {
|
||||
assert.False(t, isSSESuppressed(ctx))
|
||||
})
|
||||
|
||||
t.Run("withSuppressSSE marks context as suppressed", func(t *testing.T) {
|
||||
assert.True(t, isSSESuppressed(withSuppressSSE(ctx)))
|
||||
})
|
||||
|
||||
t.Run("suppression does not bleed into parent context", func(t *testing.T) {
|
||||
suppressed := withSuppressSSE(ctx)
|
||||
_ = suppressed
|
||||
assert.False(t, isSSESuppressed(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// ---- CreateMemoComment double-broadcast fix ----
|
||||
|
||||
func TestCreateMemoComment_NoDuplicateSSEBroadcast(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
// Create an admin so the store is initialised, then a regular commenter.
|
||||
author, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "author", Role: store.RoleAdmin, Email: "author@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
commenter, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "commenter", Role: store.RoleUser, Email: "commenter@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authorCtx := userCtx(ctx, author.ID)
|
||||
commenterCtx := userCtx(ctx, commenter.ID)
|
||||
|
||||
// Create a public memo so the commenter can react.
|
||||
parent, err := svc.CreateMemo(authorCtx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "parent memo", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subscribe after the parent memo is created so the memo.created event
|
||||
// for the parent does not pollute the assertion window.
|
||||
client := svc.SSEHub.Subscribe(author.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
// Create a comment. Before the fix, this fired both memo.created (for the
|
||||
// comment memo) and memo.comment.created (for the parent).
|
||||
_, err = svc.CreateMemoComment(commenterCtx, &v1pb.CreateMemoCommentRequest{
|
||||
Name: parent.Name,
|
||||
Comment: &v1pb.Memo{Content: "a comment", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give the synchronous broadcast a moment to land in the buffer, then
|
||||
// collect everything that arrived.
|
||||
events := collectEventsFor(client.events, 150*time.Millisecond)
|
||||
|
||||
require.Len(t, events, 1, "expected exactly one SSE event for a comment creation, got: %v", events)
|
||||
assert.True(t, strings.Contains(events[0], `"memo.comment.created"`),
|
||||
"expected memo.comment.created, got: %s", events[0])
|
||||
}
|
||||
|
||||
// ---- Reaction SSE events carry correct visibility / parent fields ----
|
||||
|
||||
func TestUpsertMemoReaction_SSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &v1pb.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"reaction.upserted"`)
|
||||
assert.Contains(t, payload, memo.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestDeleteMemoReaction_SSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
reaction, err := svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &v1pb.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "❤️",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.DeleteMemoReaction(uctx, &v1pb.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"reaction.deleted"`)
|
||||
assert.Contains(t, payload, memo.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
|
@ -114,9 +114,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
|
|||
AllowOrigins: []string{"*"},
|
||||
}))
|
||||
// Register SSE endpoint with same CORS as rest of /api/v1.
|
||||
gwGroup.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
return handleSSE(c, s.SSEHub, auth.NewAuthenticator(s.Store, s.Secret))
|
||||
})
|
||||
RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret)
|
||||
handler := echo.WrapHandler(gwMux)
|
||||
|
||||
gwGroup.Any("/api/v1/*", handler)
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ const MemoView: React.FC<MemoViewProps> = (props: MemoViewProps) => {
|
|||
className="mb-2"
|
||||
cacheKey={`inline-memo-editor-${memoData.name}`}
|
||||
memo={memoData}
|
||||
parentMemoName={memoData.parent || undefined}
|
||||
onConfirm={closeEditor}
|
||||
onCancel={closeEditor}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -12,6 +12,15 @@ const INITIAL_RETRY_DELAY_MS = 1000;
|
|||
const MAX_RETRY_DELAY_MS = 30000;
|
||||
const RETRY_BACKOFF_MULTIPLIER = 2;
|
||||
|
||||
const SSE_EVENT_TYPES = {
|
||||
memoCreated: "memo.created",
|
||||
memoUpdated: "memo.updated",
|
||||
memoDeleted: "memo.deleted",
|
||||
memoCommentCreated: "memo.comment.created",
|
||||
reactionUpserted: "reaction.upserted",
|
||||
reactionDeleted: "reaction.deleted",
|
||||
} as const;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared connection status store (singleton)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -63,6 +72,7 @@ export function useLiveMemoRefresh() {
|
|||
const { currentUser } = useAuth();
|
||||
const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const hasConnectedOnceRef = useRef(false);
|
||||
|
||||
const currentUserName = currentUser?.name;
|
||||
const handleEvent = useCallback((event: SSEChangeEvent) => handleSSEEvent(event, queryClient), [queryClient]);
|
||||
|
|
@ -101,6 +111,13 @@ export function useLiveMemoRefresh() {
|
|||
// Successfully connected - reset retry delay.
|
||||
retryDelayRef.current = INITIAL_RETRY_DELAY_MS;
|
||||
setSSEStatus("connected");
|
||||
if (hasConnectedOnceRef.current) {
|
||||
// Resync active collaborative views after reconnect because the server may have
|
||||
// dropped events while the client was disconnected or backpressured.
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.all, refetchType: "active" });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats(), refetchType: "active" });
|
||||
}
|
||||
hasConnectedOnceRef.current = true;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
|
@ -175,37 +192,44 @@ export function useLiveMemoRefresh() {
|
|||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface SSEChangeEvent {
|
||||
type: string;
|
||||
type: (typeof SSE_EVENT_TYPES)[keyof typeof SSE_EVENT_TYPES];
|
||||
name: string;
|
||||
parent?: string;
|
||||
}
|
||||
|
||||
function handleSSEEvent(event: SSEChangeEvent, queryClient: ReturnType<typeof useQueryClient>) {
|
||||
switch (event.type) {
|
||||
case "memo.created":
|
||||
case SSE_EVENT_TYPES.memoCreated:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
|
||||
break;
|
||||
|
||||
case "memo.updated":
|
||||
case SSE_EVENT_TYPES.memoUpdated:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
if (event.parent) {
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) });
|
||||
}
|
||||
break;
|
||||
|
||||
case "memo.deleted":
|
||||
case SSE_EVENT_TYPES.memoDeleted:
|
||||
queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
|
||||
break;
|
||||
|
||||
case "memo.comment.created":
|
||||
case SSE_EVENT_TYPES.memoCommentCreated:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
break;
|
||||
|
||||
case "reaction.upserted":
|
||||
case "reaction.deleted":
|
||||
case SSE_EVENT_TYPES.reactionUpserted:
|
||||
case SSE_EVENT_TYPES.reactionDeleted:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
if (event.parent) {
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) });
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -117,6 +117,9 @@ export function useUpdateMemo() {
|
|||
queryClient.setQueryData(memoKeys.detail(updatedMemo.name), updatedMemo);
|
||||
// Invalidate lists to refresh
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
if (updatedMemo.parent) {
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.comments(updatedMemo.parent) });
|
||||
}
|
||||
// Invalidate user stats
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in New Issue