diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 03890ae15..6897c1ace 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -19,6 +19,19 @@ import ( "github.com/usememos/memos/store" ) +// suppressSSEKey is a context key used to suppress the SSE broadcast from +// CreateMemo when it is called internally (e.g., from CreateMemoComment). +type suppressSSEKey struct{} + +func withSuppressSSE(ctx context.Context) context.Context { + return context.WithValue(ctx, suppressSSEKey{}, true) +} + +func isSSESuppressed(ctx context.Context) bool { + v, _ := ctx.Value(suppressSSEKey{}).(bool) + return v +} + func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) { user, err := s.fetchCurrentUser(ctx) if err != nil { @@ -136,11 +149,15 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err)) } - // Broadcast live refresh event. - s.SSEHub.Broadcast(&SSEEvent{ - Type: SSEEventMemoCreated, - Name: memoMessage.Name, - }) + // Broadcast live refresh event (skipped when called from CreateMemoComment). + if !isSSESuppressed(ctx) { + s.SSEHub.Broadcast(&SSEEvent{ + Type: SSEEventMemoCreated, + Name: memoMessage.Name, + Visibility: memo.Visibility, + CreatorID: resolveSSECreatorID(memo, nil), + }) + } return memoMessage, nil } @@ -501,6 +518,10 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } + var parentMemo *store.Memo + if memo.ParentUID != nil { + parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID}) + } // Try to dispatch webhook when memo is updated. if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil { slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err)) @@ -508,8 +529,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR // Broadcast live refresh event. s.SSEHub.Broadcast(&SSEEvent{ - Type: SSEEventMemoUpdated, - Name: memoMessage.Name, + Type: SSEEventMemoUpdated, + Name: memoMessage.Name, + Parent: memoMessage.GetParent(), + Visibility: memo.Visibility, + CreatorID: resolveSSECreatorID(memo, parentMemo), }) return memoMessage, nil @@ -583,8 +607,10 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR // Broadcast live refresh event. s.SSEHub.Broadcast(&SSEEvent{ - Type: SSEEventMemoDeleted, - Name: request.Name, + Type: SSEEventMemoDeleted, + Name: request.Name, + Visibility: memo.Visibility, + CreatorID: resolveSSECreatorID(memo, nil), }) return &emptypb.Empty{}, nil @@ -615,8 +641,9 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea return nil, status.Errorf(codes.PermissionDenied, "permission denied") } - // Create the memo comment first. - memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{ + // Create the memo comment first; suppress the generic memo.created SSE event + // since CreateMemoComment broadcasts memo.comment.created for the parent instead. + memoComment, err := s.CreateMemo(withSuppressSSE(ctx), &v1pb.CreateMemoRequest{ Memo: request.Comment, MemoId: request.CommentId, }) @@ -674,8 +701,10 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea // Broadcast live refresh event for the parent memo so subscribers see the new comment. s.SSEHub.Broadcast(&SSEEvent{ - Type: SSEEventMemoCommentCreated, - Name: request.Name, + Type: SSEEventMemoCommentCreated, + Name: request.Name, + Visibility: relatedMemo.Visibility, + CreatorID: relatedMemo.CreatorID, }) return memoComment, nil diff --git a/server/router/api/v1/reaction_service.go b/server/router/api/v1/reaction_service.go index 624a55740..f25a44d4d 100644 --- a/server/router/api/v1/reaction_service.go +++ b/server/router/api/v1/reaction_service.go @@ -104,10 +104,11 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups } // Broadcast live refresh event (reaction belongs to a memo). - s.SSEHub.Broadcast(&SSEEvent{ - Type: SSEEventReactionUpserted, - Name: request.Reaction.ContentId, - }) + var parentMemo *store.Memo + if memo.ParentUID != nil { + parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID}) + } + s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionUpserted, request.Reaction.ContentId, memo, parentMemo)) return reactionMessage, nil } @@ -148,11 +149,21 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del return nil, status.Errorf(codes.Internal, "failed to delete reaction") } + memoUID, err := ExtractMemoUIDFromName(reaction.ContentID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) + } + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get memo") + } + // Broadcast live refresh event (reaction belongs to a memo). - s.SSEHub.Broadcast(&SSEEvent{ - Type: SSEEventReactionDeleted, - Name: reaction.ContentID, - }) + var parentMemo *store.Memo + if memo != nil && memo.ParentUID != nil { + parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID}) + } + s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionDeleted, reaction.ContentID, memo, parentMemo)) return &emptypb.Empty{}, nil } diff --git a/server/router/api/v1/sse_event_helpers.go b/server/router/api/v1/sse_event_helpers.go new file mode 100644 index 000000000..40a527451 --- /dev/null +++ b/server/router/api/v1/sse_event_helpers.go @@ -0,0 +1,40 @@ +package v1 + +import "github.com/usememos/memos/store" + +func buildMemoName(uid string) string { + return MemoNamePrefix + uid +} + +// resolveSSECreatorID returns the CreatorID used for SSE delivery filtering. +// For a comment memo, it returns the parent memo's CreatorID so that private +// parent-memo events are scoped to the parent owner. +func resolveSSECreatorID(memo *store.Memo, parentMemo *store.Memo) int32 { + if memo == nil { + return 0 + } + if parentMemo != nil { + return parentMemo.CreatorID + } + return memo.CreatorID +} + +// buildMemoReactionSSEEvent constructs an SSEEvent for a reaction on a memo. +// Pass parentMemo when the memo is a comment (memo.ParentUID != nil). +func buildMemoReactionSSEEvent(eventType SSEEventType, contentID string, memo *store.Memo, parentMemo *store.Memo) *SSEEvent { + parent := "" + if memo != nil && memo.ParentUID != nil { + parent = buildMemoName(*memo.ParentUID) + } + visibility := store.Visibility("") + if memo != nil { + visibility = memo.Visibility + } + return &SSEEvent{ + Type: eventType, + Name: contentID, + Parent: parent, + Visibility: visibility, + CreatorID: resolveSSECreatorID(memo, parentMemo), + } +} diff --git a/server/router/api/v1/sse_handler.go b/server/router/api/v1/sse_handler.go index 07b36d01c..a9f8f444f 100644 --- a/server/router/api/v1/sse_handler.go +++ b/server/router/api/v1/sse_handler.go @@ -17,10 +17,14 @@ const ( sseHeartbeatInterval = 30 * time.Second ) -// RegisterSSERoutes registers the SSE endpoint on the given Echo instance. -func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) { +type sseRouteRegistrar interface { + GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) echo.RouteInfo +} + +// RegisterSSERoutes registers the SSE endpoint on the given Echo router. +func RegisterSSERoutes(router sseRouteRegistrar, hub *SSEHub, storeInstance *store.Store, secret string) { authenticator := auth.NewAuthenticator(storeInstance, secret) - echoServer.GET("/api/v1/sse", func(c *echo.Context) error { + router.GET("/api/v1/sse", func(c *echo.Context) error { return handleSSE(c, hub, authenticator) }) } @@ -34,6 +38,10 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator) if result == nil { return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"}) } + userID, role := getSSEClientIdentity(result) + if userID == 0 { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"}) + } // Set SSE headers. w := c.Response() @@ -49,7 +57,7 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator) } // Subscribe to the hub. - client := hub.Subscribe() + client := hub.Subscribe(userID, role) defer hub.Unsubscribe(client) // Create a ticker for heartbeat pings. @@ -58,13 +66,13 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator) ctx := c.Request().Context() - slog.Debug("SSE client connected") + slog.Debug("SSE client connected", "userID", userID) for { select { case <-ctx.Done(): // Client disconnected. - slog.Debug("SSE client disconnected") + slog.Debug("SSE client disconnected", "userID", userID) return nil case data, ok := <-client.events: @@ -91,3 +99,16 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator) } } } + +func getSSEClientIdentity(result *auth.AuthResult) (int32, store.Role) { + if result == nil { + return 0, store.RoleUser + } + if result.Claims != nil { + return result.Claims.UserID, store.Role(result.Claims.Role) + } + if result.User != nil { + return result.User.ID, result.User.Role + } + return 0, store.RoleUser +} diff --git a/server/router/api/v1/sse_hub.go b/server/router/api/v1/sse_hub.go index 1fecad8af..a04c2474c 100644 --- a/server/router/api/v1/sse_hub.go +++ b/server/router/api/v1/sse_hub.go @@ -4,6 +4,8 @@ import ( "encoding/json" "log/slog" "sync" + + "github.com/usememos/memos/store" ) // SSEEventType represents the type of change event. @@ -24,6 +26,11 @@ type SSEEvent struct { // Name is the affected resource name (e.g., "memos/xxxx"). // For reaction events, this is the memo resource name that the reaction belongs to. Name string `json:"name"` + // Parent is the parent memo resource name when the affected resource is a comment. + Parent string `json:"parent,omitempty"` + // Visibility and CreatorID are used only for server-side delivery filtering. + Visibility store.Visibility `json:"-"` + CreatorID int32 `json:"-"` } // JSON returns the JSON representation of the event. @@ -40,6 +47,8 @@ func (e *SSEEvent) JSON() []byte { // SSEClient represents a single SSE connection. type SSEClient struct { events chan []byte + userID int32 + role store.Role } // SSEHub manages SSE client connections and broadcasts events. @@ -58,10 +67,12 @@ func NewSSEHub() *SSEHub { // Subscribe registers a new client and returns it. // The caller must call Unsubscribe when done. -func (h *SSEHub) Subscribe() *SSEClient { +func (h *SSEHub) Subscribe(userID int32, role store.Role) *SSEClient { c := &SSEClient{ // Buffer a few events so a slow client doesn't block broadcasting. events: make(chan []byte, 32), + userID: userID, + role: role, } h.mu.Lock() h.clients[c] = struct{}{} @@ -90,6 +101,9 @@ func (h *SSEHub) Broadcast(event *SSEEvent) { h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { + if !c.canReceive(event) { + continue + } select { case c.events <- data: default: @@ -97,3 +111,15 @@ func (h *SSEHub) Broadcast(event *SSEEvent) { } } } + +func (c *SSEClient) canReceive(event *SSEEvent) bool { + switch event.Visibility { + case store.Private: + return c.userID == event.CreatorID || c.role == store.RoleAdmin + case store.Public, store.Protected, "": + return true + default: + slog.Warn("SSE canReceive: unknown visibility type, denying event", "visibility", event.Visibility) + return false + } +} diff --git a/server/router/api/v1/sse_hub_test.go b/server/router/api/v1/sse_hub_test.go index d7d6d2c8a..42e01091a 100644 --- a/server/router/api/v1/sse_hub_test.go +++ b/server/router/api/v1/sse_hub_test.go @@ -6,12 +6,36 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/usememos/memos/store" ) +// helpers shared by multiple tests in this file. + +func mustReceive(t *testing.T, ch <-chan []byte, within time.Duration) []byte { + t.Helper() + select { + case data := <-ch: + return data + case <-time.After(within): + t.Fatal("timed out waiting for SSE event") + return nil + } +} + +func mustNotReceive(t *testing.T, ch <-chan []byte, within time.Duration) { + t.Helper() + select { + case data := <-ch: + t.Fatalf("unexpected SSE event received: %s", data) + case <-time.After(within): + } +} + func TestSSEHub_SubscribeUnsubscribe(t *testing.T) { hub := NewSSEHub() - client := hub.Subscribe() + client := hub.Subscribe(1, store.RoleUser) require.NotNil(t, client) require.NotNil(t, client.events) @@ -25,7 +49,7 @@ func TestSSEHub_SubscribeUnsubscribe(t *testing.T) { func TestSSEHub_Broadcast(t *testing.T) { hub := NewSSEHub() - client := hub.Subscribe() + client := hub.Subscribe(1, store.RoleUser) defer hub.Unsubscribe(client) event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"} @@ -42,9 +66,9 @@ func TestSSEHub_Broadcast(t *testing.T) { func TestSSEHub_BroadcastMultipleClients(t *testing.T) { hub := NewSSEHub() - c1 := hub.Subscribe() + c1 := hub.Subscribe(1, store.RoleUser) defer hub.Unsubscribe(c1) - c2 := hub.Subscribe() + c2 := hub.Subscribe(2, store.RoleUser) defer hub.Unsubscribe(c2) event := &SSEEvent{Type: SSEEventMemoDeleted, Name: "memos/456"} @@ -62,9 +86,144 @@ func TestSSEHub_BroadcastMultipleClients(t *testing.T) { } func TestSSEEvent_JSON(t *testing.T) { - e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789"} + e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789", Parent: "memos/123"} data := e.JSON() require.NotEmpty(t, data) assert.Contains(t, string(data), `"type":"memo.updated"`) assert.Contains(t, string(data), `"name":"memos/789"`) + assert.Contains(t, string(data), `"parent":"memos/123"`) +} + +func TestSSEHub_PrivateEventsAreScoped(t *testing.T) { + hub := NewSSEHub() + owner := hub.Subscribe(1, store.RoleUser) + defer hub.Unsubscribe(owner) + other := hub.Subscribe(2, store.RoleUser) + defer hub.Unsubscribe(other) + admin := hub.Subscribe(3, store.RoleAdmin) + defer hub.Unsubscribe(admin) + + hub.Broadcast(&SSEEvent{ + Type: SSEEventMemoUpdated, + Name: "memos/private", + Visibility: store.Private, + CreatorID: 1, + }) + + select { + case <-owner.events: + case <-time.After(time.Second): + t.Fatal("owner should receive private event") + } + + select { + case <-admin.events: + case <-time.After(time.Second): + t.Fatal("admin should receive private event") + } + + select { + case <-other.events: + t.Fatal("non-owner should not receive private event") + case <-time.After(100 * time.Millisecond): + } +} + +func TestSSEClient_CanReceive_UnknownVisibility(t *testing.T) { + hub := NewSSEHub() + client := hub.Subscribe(1, store.RoleUser) + defer hub.Unsubscribe(client) + + // An event with an unrecognised visibility value should be denied (safe default). + hub.Broadcast(&SSEEvent{ + Type: SSEEventMemoUpdated, + Name: "memos/unknown-vis", + Visibility: store.Visibility("CUSTOM"), + }) + + mustNotReceive(t, client.events, 100*time.Millisecond) +} + +func TestSSEHub_SlowClientEventsDropped(t *testing.T) { + hub := NewSSEHub() + // Subscribe but never read, so the channel fills up. + slow := hub.Subscribe(1, store.RoleUser) + defer hub.Unsubscribe(slow) + + event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/x"} + // Send more events than the buffer capacity (32). + for range 40 { + hub.Broadcast(event) // must not block + } + + // At most 32 events should have been queued; the rest were silently dropped. + assert.LessOrEqual(t, len(slow.events), 32) +} + +func TestResolveSSECreatorID(t *testing.T) { + tests := []struct { + name string + memo *store.Memo + parentMemo *store.Memo + want int32 + }{ + { + name: "nil memo returns 0", + memo: nil, parentMemo: nil, + want: 0, + }, + { + name: "memo without parent returns memo CreatorID", + memo: &store.Memo{CreatorID: 5}, + parentMemo: nil, + want: 5, + }, + { + name: "memo with parent returns parent CreatorID", + memo: &store.Memo{CreatorID: 5}, + parentMemo: &store.Memo{CreatorID: 9}, + want: 9, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, resolveSSECreatorID(tc.memo, tc.parentMemo)) + }) + } +} + +func TestBuildMemoReactionSSEEvent(t *testing.T) { + parentUID := "parent-uid" + + t.Run("top-level memo reaction", func(t *testing.T) { + memo := &store.Memo{CreatorID: 10, Visibility: store.Public} + event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", memo, nil) + assert.Equal(t, SSEEventReactionUpserted, event.Type) + assert.Equal(t, "memos/abc", event.Name) + assert.Equal(t, "", event.Parent) + assert.Equal(t, store.Public, event.Visibility) + assert.Equal(t, int32(10), event.CreatorID) + }) + + t.Run("reaction on comment is scoped to parent owner", func(t *testing.T) { + memo := &store.Memo{ + CreatorID: 10, + Visibility: store.Private, + ParentUID: &parentUID, + } + parentMemo := &store.Memo{CreatorID: 7} + event := buildMemoReactionSSEEvent(SSEEventReactionDeleted, "memos/abc", memo, parentMemo) + assert.Equal(t, SSEEventReactionDeleted, event.Type) + assert.Equal(t, MemoNamePrefix+parentUID, event.Parent) + assert.Equal(t, store.Private, event.Visibility) + assert.Equal(t, int32(7), event.CreatorID) + }) + + t.Run("nil memo produces a safe zero-value event", func(t *testing.T) { + event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", nil, nil) + assert.Equal(t, "memos/abc", event.Name) + assert.Equal(t, "", event.Parent) + assert.Equal(t, store.Visibility(""), event.Visibility) + assert.Equal(t, int32(0), event.CreatorID) + }) } diff --git a/server/router/api/v1/sse_service_test.go b/server/router/api/v1/sse_service_test.go new file mode 100644 index 000000000..93c4831dc --- /dev/null +++ b/server/router/api/v1/sse_service_test.go @@ -0,0 +1,203 @@ +package v1 + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/internal/profile" + v1pb "github.com/usememos/memos/proto/gen/api/v1" + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" + teststore "github.com/usememos/memos/store/test" +) + +// newIntegrationService builds a minimal APIV1Service backed by an in-memory +// SQLite database. The store is closed automatically via t.Cleanup. +func newIntegrationService(t *testing.T) *APIV1Service { + t.Helper() + ctx := context.Background() + st := teststore.NewTestingStore(ctx, t) + t.Cleanup(func() { st.Close() }) + p := &profile.Profile{Demo: true, Data: t.TempDir(), Driver: "sqlite", DSN: ":memory:"} + return NewAPIV1Service("test-secret", p, st) +} + +// userCtx returns a context that authenticates as the given user. +func userCtx(ctx context.Context, userID int32) context.Context { + return context.WithValue(ctx, auth.UserIDContextKey, userID) +} + +// drainEvents reads all events currently buffered in the channel and returns +// them as a string slice. It stops as soon as the channel is empty (non-blocking). +func drainEvents(ch <-chan []byte) []string { + var out []string + for { + select { + case data := <-ch: + out = append(out, string(data)) + default: + return out + } + } +} + +// collectEventsFor reads events from ch for the given duration and returns them. +func collectEventsFor(ch <-chan []byte, d time.Duration) []string { + var out []string + deadline := time.After(d) + for { + select { + case data := <-ch: + out = append(out, string(data)) + case <-deadline: + return out + } + } +} + +// ---- context suppression ---- + +func TestSuppressSSEContext(t *testing.T) { + ctx := context.Background() + + t.Run("default context is not suppressed", func(t *testing.T) { + assert.False(t, isSSESuppressed(ctx)) + }) + + t.Run("withSuppressSSE marks context as suppressed", func(t *testing.T) { + assert.True(t, isSSESuppressed(withSuppressSSE(ctx))) + }) + + t.Run("suppression does not bleed into parent context", func(t *testing.T) { + suppressed := withSuppressSSE(ctx) + _ = suppressed + assert.False(t, isSSESuppressed(ctx)) + }) +} + +// ---- CreateMemoComment double-broadcast fix ---- + +func TestCreateMemoComment_NoDuplicateSSEBroadcast(t *testing.T) { + ctx := context.Background() + svc := newIntegrationService(t) + + // Create an admin so the store is initialised, then a regular commenter. + author, err := svc.Store.CreateUser(ctx, &store.User{ + Username: "author", Role: store.RoleAdmin, Email: "author@example.com", + }) + require.NoError(t, err) + commenter, err := svc.Store.CreateUser(ctx, &store.User{ + Username: "commenter", Role: store.RoleUser, Email: "commenter@example.com", + }) + require.NoError(t, err) + + authorCtx := userCtx(ctx, author.ID) + commenterCtx := userCtx(ctx, commenter.ID) + + // Create a public memo so the commenter can react. + parent, err := svc.CreateMemo(authorCtx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{Content: "parent memo", Visibility: v1pb.Visibility_PUBLIC}, + }) + require.NoError(t, err) + + // Subscribe after the parent memo is created so the memo.created event + // for the parent does not pollute the assertion window. + client := svc.SSEHub.Subscribe(author.ID, store.RoleAdmin) + defer svc.SSEHub.Unsubscribe(client) + + // Create a comment. Before the fix, this fired both memo.created (for the + // comment memo) and memo.comment.created (for the parent). + _, err = svc.CreateMemoComment(commenterCtx, &v1pb.CreateMemoCommentRequest{ + Name: parent.Name, + Comment: &v1pb.Memo{Content: "a comment", Visibility: v1pb.Visibility_PUBLIC}, + }) + require.NoError(t, err) + + // Give the synchronous broadcast a moment to land in the buffer, then + // collect everything that arrived. + events := collectEventsFor(client.events, 150*time.Millisecond) + + require.Len(t, events, 1, "expected exactly one SSE event for a comment creation, got: %v", events) + assert.True(t, strings.Contains(events[0], `"memo.comment.created"`), + "expected memo.comment.created, got: %s", events[0]) +} + +// ---- Reaction SSE events carry correct visibility / parent fields ---- + +func TestUpsertMemoReaction_SSEEvent(t *testing.T) { + ctx := context.Background() + svc := newIntegrationService(t) + + user, err := svc.Store.CreateUser(ctx, &store.User{ + Username: "user", Role: store.RoleAdmin, Email: "user@example.com", + }) + require.NoError(t, err) + uctx := userCtx(ctx, user.ID) + + memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC}, + }) + require.NoError(t, err) + + client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin) + defer svc.SSEHub.Unsubscribe(client) + + _, err = svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{ + Name: memo.Name, + Reaction: &v1pb.Reaction{ + ContentId: memo.Name, + ReactionType: "👍", + }, + }) + require.NoError(t, err) + + data := mustReceive(t, client.events, time.Second) + payload := string(data) + assert.Contains(t, payload, `"reaction.upserted"`) + assert.Contains(t, payload, memo.Name) + mustNotReceive(t, client.events, 100*time.Millisecond) +} + +func TestDeleteMemoReaction_SSEEvent(t *testing.T) { + ctx := context.Background() + svc := newIntegrationService(t) + + user, err := svc.Store.CreateUser(ctx, &store.User{ + Username: "user", Role: store.RoleAdmin, Email: "user@example.com", + }) + require.NoError(t, err) + uctx := userCtx(ctx, user.ID) + + memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC}, + }) + require.NoError(t, err) + + reaction, err := svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{ + Name: memo.Name, + Reaction: &v1pb.Reaction{ + ContentId: memo.Name, + ReactionType: "❤️", + }, + }) + require.NoError(t, err) + + client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin) + defer svc.SSEHub.Unsubscribe(client) + + _, err = svc.DeleteMemoReaction(uctx, &v1pb.DeleteMemoReactionRequest{ + Name: reaction.Name, + }) + require.NoError(t, err) + + data := mustReceive(t, client.events, time.Second) + payload := string(data) + assert.Contains(t, payload, `"reaction.deleted"`) + assert.Contains(t, payload, memo.Name) + mustNotReceive(t, client.events, 100*time.Millisecond) +} diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index 4d5e5c329..cb0f0a289 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -114,9 +114,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech AllowOrigins: []string{"*"}, })) // Register SSE endpoint with same CORS as rest of /api/v1. - gwGroup.GET("/api/v1/sse", func(c *echo.Context) error { - return handleSSE(c, s.SSEHub, auth.NewAuthenticator(s.Store, s.Secret)) - }) + RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret) handler := echo.WrapHandler(gwMux) gwGroup.Any("/api/v1/*", handler) diff --git a/web/src/components/MemoView/MemoView.tsx b/web/src/components/MemoView/MemoView.tsx index f177657ea..4010ba8da 100644 --- a/web/src/components/MemoView/MemoView.tsx +++ b/web/src/components/MemoView/MemoView.tsx @@ -77,6 +77,7 @@ const MemoView: React.FC = (props: MemoViewProps) => { className="mb-2" cacheKey={`inline-memo-editor-${memoData.name}`} memo={memoData} + parentMemoName={memoData.parent || undefined} onConfirm={closeEditor} onCancel={closeEditor} /> diff --git a/web/src/hooks/useLiveMemoRefresh.ts b/web/src/hooks/useLiveMemoRefresh.ts index 3c0d8bfdb..22fdfcd9e 100644 --- a/web/src/hooks/useLiveMemoRefresh.ts +++ b/web/src/hooks/useLiveMemoRefresh.ts @@ -12,6 +12,15 @@ const INITIAL_RETRY_DELAY_MS = 1000; const MAX_RETRY_DELAY_MS = 30000; const RETRY_BACKOFF_MULTIPLIER = 2; +const SSE_EVENT_TYPES = { + memoCreated: "memo.created", + memoUpdated: "memo.updated", + memoDeleted: "memo.deleted", + memoCommentCreated: "memo.comment.created", + reactionUpserted: "reaction.upserted", + reactionDeleted: "reaction.deleted", +} as const; + // --------------------------------------------------------------------------- // Shared connection status store (singleton) // --------------------------------------------------------------------------- @@ -63,6 +72,7 @@ export function useLiveMemoRefresh() { const { currentUser } = useAuth(); const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS); const abortControllerRef = useRef(null); + const hasConnectedOnceRef = useRef(false); const currentUserName = currentUser?.name; const handleEvent = useCallback((event: SSEChangeEvent) => handleSSEEvent(event, queryClient), [queryClient]); @@ -101,6 +111,13 @@ export function useLiveMemoRefresh() { // Successfully connected - reset retry delay. retryDelayRef.current = INITIAL_RETRY_DELAY_MS; setSSEStatus("connected"); + if (hasConnectedOnceRef.current) { + // Resync active collaborative views after reconnect because the server may have + // dropped events while the client was disconnected or backpressured. + queryClient.invalidateQueries({ queryKey: memoKeys.all, refetchType: "active" }); + queryClient.invalidateQueries({ queryKey: userKeys.stats(), refetchType: "active" }); + } + hasConnectedOnceRef.current = true; const reader = response.body.getReader(); const decoder = new TextDecoder(); @@ -175,37 +192,44 @@ export function useLiveMemoRefresh() { // --------------------------------------------------------------------------- interface SSEChangeEvent { - type: string; + type: (typeof SSE_EVENT_TYPES)[keyof typeof SSE_EVENT_TYPES]; name: string; + parent?: string; } function handleSSEEvent(event: SSEChangeEvent, queryClient: ReturnType) { switch (event.type) { - case "memo.created": + case SSE_EVENT_TYPES.memoCreated: queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); queryClient.invalidateQueries({ queryKey: userKeys.stats() }); break; - case "memo.updated": + case SSE_EVENT_TYPES.memoUpdated: queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) }); queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + if (event.parent) { + queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) }); + } break; - case "memo.deleted": + case SSE_EVENT_TYPES.memoDeleted: queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) }); queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); queryClient.invalidateQueries({ queryKey: userKeys.stats() }); break; - case "memo.comment.created": + case SSE_EVENT_TYPES.memoCommentCreated: queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.name) }); queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) }); break; - case "reaction.upserted": - case "reaction.deleted": + case SSE_EVENT_TYPES.reactionUpserted: + case SSE_EVENT_TYPES.reactionDeleted: queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) }); queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + if (event.parent) { + queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) }); + } break; } } diff --git a/web/src/hooks/useMemoQueries.ts b/web/src/hooks/useMemoQueries.ts index e2dbd6ff1..458c5c6b6 100644 --- a/web/src/hooks/useMemoQueries.ts +++ b/web/src/hooks/useMemoQueries.ts @@ -117,6 +117,9 @@ export function useUpdateMemo() { queryClient.setQueryData(memoKeys.detail(updatedMemo.name), updatedMemo); // Invalidate lists to refresh queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + if (updatedMemo.parent) { + queryClient.invalidateQueries({ queryKey: memoKeys.comments(updatedMemo.parent) }); + } // Invalidate user stats queryClient.invalidateQueries({ queryKey: userKeys.stats() }); },