fix(api): improve SSE hub design and fix double-broadcast on comments

- Fix duplicate SSE event on comment creation: CreateMemoComment now
  suppresses the redundant memo.created broadcast from the inner
  CreateMemo call, emitting only memo.comment.created
- Extract reaction event-building IIFEs into buildMemoReactionSSEEvent
  helper, removing duplicated inline DB-fetch logic
- Promote resolveSSEAudienceCreatorID from method to free function
  (resolveSSECreatorID) since it never used the receiver
- Add userID to SSE connect/disconnect log lines for traceability
- Change canReceive default from permissive (return true) to
  deny-with-warning for unknown visibility types
- Add comprehensive tests covering all new helpers, visibility edge
  cases, slow-client drop behavior, and the double-broadcast fix

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
memoclaw 2026-03-29 07:33:40 +08:00
parent d720efb6e6
commit c53677fcba
11 changed files with 558 additions and 43 deletions

View File

@ -19,6 +19,19 @@ import (
"github.com/usememos/memos/store"
)
// suppressSSEKey is a context key used to suppress the SSE broadcast from
// CreateMemo when it is called internally (e.g., from CreateMemoComment).
type suppressSSEKey struct{}
func withSuppressSSE(ctx context.Context) context.Context {
return context.WithValue(ctx, suppressSSEKey{}, true)
}
func isSSESuppressed(ctx context.Context) bool {
v, _ := ctx.Value(suppressSSEKey{}).(bool)
return v
}
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
@ -136,11 +149,15 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
}
// Broadcast live refresh event.
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoCreated,
Name: memoMessage.Name,
})
// Broadcast live refresh event (skipped when called from CreateMemoComment).
if !isSSESuppressed(ctx) {
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoCreated,
Name: memoMessage.Name,
Visibility: memo.Visibility,
CreatorID: resolveSSECreatorID(memo, nil),
})
}
return memoMessage, nil
}
@ -501,6 +518,10 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
var parentMemo *store.Memo
if memo.ParentUID != nil {
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
}
// Try to dispatch webhook when memo is updated.
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
@ -508,8 +529,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
// Broadcast live refresh event.
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoUpdated,
Name: memoMessage.Name,
Type: SSEEventMemoUpdated,
Name: memoMessage.Name,
Parent: memoMessage.GetParent(),
Visibility: memo.Visibility,
CreatorID: resolveSSECreatorID(memo, parentMemo),
})
return memoMessage, nil
@ -583,8 +607,10 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
// Broadcast live refresh event.
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoDeleted,
Name: request.Name,
Type: SSEEventMemoDeleted,
Name: request.Name,
Visibility: memo.Visibility,
CreatorID: resolveSSECreatorID(memo, nil),
})
return &emptypb.Empty{}, nil
@ -615,8 +641,9 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Create the memo comment first.
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{
// Create the memo comment first; suppress the generic memo.created SSE event
// since CreateMemoComment broadcasts memo.comment.created for the parent instead.
memoComment, err := s.CreateMemo(withSuppressSSE(ctx), &v1pb.CreateMemoRequest{
Memo: request.Comment,
MemoId: request.CommentId,
})
@ -674,8 +701,10 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
// Broadcast live refresh event for the parent memo so subscribers see the new comment.
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoCommentCreated,
Name: request.Name,
Type: SSEEventMemoCommentCreated,
Name: request.Name,
Visibility: relatedMemo.Visibility,
CreatorID: relatedMemo.CreatorID,
})
return memoComment, nil

View File

@ -104,10 +104,11 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
}
// Broadcast live refresh event (reaction belongs to a memo).
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventReactionUpserted,
Name: request.Reaction.ContentId,
})
var parentMemo *store.Memo
if memo.ParentUID != nil {
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
}
s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionUpserted, request.Reaction.ContentId, memo, parentMemo))
return reactionMessage, nil
}
@ -148,11 +149,21 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
}
memoUID, err := ExtractMemoUIDFromName(reaction.ContentID)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
// Broadcast live refresh event (reaction belongs to a memo).
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventReactionDeleted,
Name: reaction.ContentID,
})
var parentMemo *store.Memo
if memo != nil && memo.ParentUID != nil {
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
}
s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionDeleted, reaction.ContentID, memo, parentMemo))
return &emptypb.Empty{}, nil
}

View File

@ -0,0 +1,40 @@
package v1
import "github.com/usememos/memos/store"
func buildMemoName(uid string) string {
return MemoNamePrefix + uid
}
// resolveSSECreatorID returns the CreatorID used for SSE delivery filtering.
// For a comment memo, it returns the parent memo's CreatorID so that private
// parent-memo events are scoped to the parent owner.
func resolveSSECreatorID(memo *store.Memo, parentMemo *store.Memo) int32 {
if memo == nil {
return 0
}
if parentMemo != nil {
return parentMemo.CreatorID
}
return memo.CreatorID
}
// buildMemoReactionSSEEvent constructs an SSEEvent for a reaction on a memo.
// Pass parentMemo when the memo is a comment (memo.ParentUID != nil).
func buildMemoReactionSSEEvent(eventType SSEEventType, contentID string, memo *store.Memo, parentMemo *store.Memo) *SSEEvent {
parent := ""
if memo != nil && memo.ParentUID != nil {
parent = buildMemoName(*memo.ParentUID)
}
visibility := store.Visibility("")
if memo != nil {
visibility = memo.Visibility
}
return &SSEEvent{
Type: eventType,
Name: contentID,
Parent: parent,
Visibility: visibility,
CreatorID: resolveSSECreatorID(memo, parentMemo),
}
}

View File

@ -17,10 +17,14 @@ const (
sseHeartbeatInterval = 30 * time.Second
)
// RegisterSSERoutes registers the SSE endpoint on the given Echo instance.
func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) {
type sseRouteRegistrar interface {
GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) echo.RouteInfo
}
// RegisterSSERoutes registers the SSE endpoint on the given Echo router.
func RegisterSSERoutes(router sseRouteRegistrar, hub *SSEHub, storeInstance *store.Store, secret string) {
authenticator := auth.NewAuthenticator(storeInstance, secret)
echoServer.GET("/api/v1/sse", func(c *echo.Context) error {
router.GET("/api/v1/sse", func(c *echo.Context) error {
return handleSSE(c, hub, authenticator)
})
}
@ -34,6 +38,10 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
if result == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
userID, role := getSSEClientIdentity(result)
if userID == 0 {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
// Set SSE headers.
w := c.Response()
@ -49,7 +57,7 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
}
// Subscribe to the hub.
client := hub.Subscribe()
client := hub.Subscribe(userID, role)
defer hub.Unsubscribe(client)
// Create a ticker for heartbeat pings.
@ -58,13 +66,13 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
ctx := c.Request().Context()
slog.Debug("SSE client connected")
slog.Debug("SSE client connected", "userID", userID)
for {
select {
case <-ctx.Done():
// Client disconnected.
slog.Debug("SSE client disconnected")
slog.Debug("SSE client disconnected", "userID", userID)
return nil
case data, ok := <-client.events:
@ -91,3 +99,16 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
}
}
}
func getSSEClientIdentity(result *auth.AuthResult) (int32, store.Role) {
if result == nil {
return 0, store.RoleUser
}
if result.Claims != nil {
return result.Claims.UserID, store.Role(result.Claims.Role)
}
if result.User != nil {
return result.User.ID, result.User.Role
}
return 0, store.RoleUser
}

View File

@ -4,6 +4,8 @@ import (
"encoding/json"
"log/slog"
"sync"
"github.com/usememos/memos/store"
)
// SSEEventType represents the type of change event.
@ -24,6 +26,11 @@ type SSEEvent struct {
// Name is the affected resource name (e.g., "memos/xxxx").
// For reaction events, this is the memo resource name that the reaction belongs to.
Name string `json:"name"`
// Parent is the parent memo resource name when the affected resource is a comment.
Parent string `json:"parent,omitempty"`
// Visibility and CreatorID are used only for server-side delivery filtering.
Visibility store.Visibility `json:"-"`
CreatorID int32 `json:"-"`
}
// JSON returns the JSON representation of the event.
@ -40,6 +47,8 @@ func (e *SSEEvent) JSON() []byte {
// SSEClient represents a single SSE connection.
type SSEClient struct {
events chan []byte
userID int32
role store.Role
}
// SSEHub manages SSE client connections and broadcasts events.
@ -58,10 +67,12 @@ func NewSSEHub() *SSEHub {
// Subscribe registers a new client and returns it.
// The caller must call Unsubscribe when done.
func (h *SSEHub) Subscribe() *SSEClient {
func (h *SSEHub) Subscribe(userID int32, role store.Role) *SSEClient {
c := &SSEClient{
// Buffer a few events so a slow client doesn't block broadcasting.
events: make(chan []byte, 32),
userID: userID,
role: role,
}
h.mu.Lock()
h.clients[c] = struct{}{}
@ -90,6 +101,9 @@ func (h *SSEHub) Broadcast(event *SSEEvent) {
h.mu.RLock()
defer h.mu.RUnlock()
for c := range h.clients {
if !c.canReceive(event) {
continue
}
select {
case c.events <- data:
default:
@ -97,3 +111,15 @@ func (h *SSEHub) Broadcast(event *SSEEvent) {
}
}
}
func (c *SSEClient) canReceive(event *SSEEvent) bool {
switch event.Visibility {
case store.Private:
return c.userID == event.CreatorID || c.role == store.RoleAdmin
case store.Public, store.Protected, "":
return true
default:
slog.Warn("SSE canReceive: unknown visibility type, denying event", "visibility", event.Visibility)
return false
}
}

View File

@ -6,12 +6,36 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
// helpers shared by multiple tests in this file.
func mustReceive(t *testing.T, ch <-chan []byte, within time.Duration) []byte {
t.Helper()
select {
case data := <-ch:
return data
case <-time.After(within):
t.Fatal("timed out waiting for SSE event")
return nil
}
}
func mustNotReceive(t *testing.T, ch <-chan []byte, within time.Duration) {
t.Helper()
select {
case data := <-ch:
t.Fatalf("unexpected SSE event received: %s", data)
case <-time.After(within):
}
}
func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
hub := NewSSEHub()
client := hub.Subscribe()
client := hub.Subscribe(1, store.RoleUser)
require.NotNil(t, client)
require.NotNil(t, client.events)
@ -25,7 +49,7 @@ func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
func TestSSEHub_Broadcast(t *testing.T) {
hub := NewSSEHub()
client := hub.Subscribe()
client := hub.Subscribe(1, store.RoleUser)
defer hub.Unsubscribe(client)
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"}
@ -42,9 +66,9 @@ func TestSSEHub_Broadcast(t *testing.T) {
func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
hub := NewSSEHub()
c1 := hub.Subscribe()
c1 := hub.Subscribe(1, store.RoleUser)
defer hub.Unsubscribe(c1)
c2 := hub.Subscribe()
c2 := hub.Subscribe(2, store.RoleUser)
defer hub.Unsubscribe(c2)
event := &SSEEvent{Type: SSEEventMemoDeleted, Name: "memos/456"}
@ -62,9 +86,144 @@ func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
}
func TestSSEEvent_JSON(t *testing.T) {
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789"}
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789", Parent: "memos/123"}
data := e.JSON()
require.NotEmpty(t, data)
assert.Contains(t, string(data), `"type":"memo.updated"`)
assert.Contains(t, string(data), `"name":"memos/789"`)
assert.Contains(t, string(data), `"parent":"memos/123"`)
}
func TestSSEHub_PrivateEventsAreScoped(t *testing.T) {
hub := NewSSEHub()
owner := hub.Subscribe(1, store.RoleUser)
defer hub.Unsubscribe(owner)
other := hub.Subscribe(2, store.RoleUser)
defer hub.Unsubscribe(other)
admin := hub.Subscribe(3, store.RoleAdmin)
defer hub.Unsubscribe(admin)
hub.Broadcast(&SSEEvent{
Type: SSEEventMemoUpdated,
Name: "memos/private",
Visibility: store.Private,
CreatorID: 1,
})
select {
case <-owner.events:
case <-time.After(time.Second):
t.Fatal("owner should receive private event")
}
select {
case <-admin.events:
case <-time.After(time.Second):
t.Fatal("admin should receive private event")
}
select {
case <-other.events:
t.Fatal("non-owner should not receive private event")
case <-time.After(100 * time.Millisecond):
}
}
func TestSSEClient_CanReceive_UnknownVisibility(t *testing.T) {
hub := NewSSEHub()
client := hub.Subscribe(1, store.RoleUser)
defer hub.Unsubscribe(client)
// An event with an unrecognised visibility value should be denied (safe default).
hub.Broadcast(&SSEEvent{
Type: SSEEventMemoUpdated,
Name: "memos/unknown-vis",
Visibility: store.Visibility("CUSTOM"),
})
mustNotReceive(t, client.events, 100*time.Millisecond)
}
func TestSSEHub_SlowClientEventsDropped(t *testing.T) {
hub := NewSSEHub()
// Subscribe but never read, so the channel fills up.
slow := hub.Subscribe(1, store.RoleUser)
defer hub.Unsubscribe(slow)
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/x"}
// Send more events than the buffer capacity (32).
for range 40 {
hub.Broadcast(event) // must not block
}
// At most 32 events should have been queued; the rest were silently dropped.
assert.LessOrEqual(t, len(slow.events), 32)
}
func TestResolveSSECreatorID(t *testing.T) {
tests := []struct {
name string
memo *store.Memo
parentMemo *store.Memo
want int32
}{
{
name: "nil memo returns 0",
memo: nil, parentMemo: nil,
want: 0,
},
{
name: "memo without parent returns memo CreatorID",
memo: &store.Memo{CreatorID: 5},
parentMemo: nil,
want: 5,
},
{
name: "memo with parent returns parent CreatorID",
memo: &store.Memo{CreatorID: 5},
parentMemo: &store.Memo{CreatorID: 9},
want: 9,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, resolveSSECreatorID(tc.memo, tc.parentMemo))
})
}
}
func TestBuildMemoReactionSSEEvent(t *testing.T) {
parentUID := "parent-uid"
t.Run("top-level memo reaction", func(t *testing.T) {
memo := &store.Memo{CreatorID: 10, Visibility: store.Public}
event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", memo, nil)
assert.Equal(t, SSEEventReactionUpserted, event.Type)
assert.Equal(t, "memos/abc", event.Name)
assert.Equal(t, "", event.Parent)
assert.Equal(t, store.Public, event.Visibility)
assert.Equal(t, int32(10), event.CreatorID)
})
t.Run("reaction on comment is scoped to parent owner", func(t *testing.T) {
memo := &store.Memo{
CreatorID: 10,
Visibility: store.Private,
ParentUID: &parentUID,
}
parentMemo := &store.Memo{CreatorID: 7}
event := buildMemoReactionSSEEvent(SSEEventReactionDeleted, "memos/abc", memo, parentMemo)
assert.Equal(t, SSEEventReactionDeleted, event.Type)
assert.Equal(t, MemoNamePrefix+parentUID, event.Parent)
assert.Equal(t, store.Private, event.Visibility)
assert.Equal(t, int32(7), event.CreatorID)
})
t.Run("nil memo produces a safe zero-value event", func(t *testing.T) {
event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", nil, nil)
assert.Equal(t, "memos/abc", event.Name)
assert.Equal(t, "", event.Parent)
assert.Equal(t, store.Visibility(""), event.Visibility)
assert.Equal(t, int32(0), event.CreatorID)
})
}

View File

@ -0,0 +1,203 @@
package v1
import (
"context"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/internal/profile"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
teststore "github.com/usememos/memos/store/test"
)
// newIntegrationService builds a minimal APIV1Service backed by an in-memory
// SQLite database. The store is closed automatically via t.Cleanup.
func newIntegrationService(t *testing.T) *APIV1Service {
t.Helper()
ctx := context.Background()
st := teststore.NewTestingStore(ctx, t)
t.Cleanup(func() { st.Close() })
p := &profile.Profile{Demo: true, Data: t.TempDir(), Driver: "sqlite", DSN: ":memory:"}
return NewAPIV1Service("test-secret", p, st)
}
// userCtx returns a context that authenticates as the given user.
func userCtx(ctx context.Context, userID int32) context.Context {
return context.WithValue(ctx, auth.UserIDContextKey, userID)
}
// drainEvents reads all events currently buffered in the channel and returns
// them as a string slice. It stops as soon as the channel is empty (non-blocking).
func drainEvents(ch <-chan []byte) []string {
var out []string
for {
select {
case data := <-ch:
out = append(out, string(data))
default:
return out
}
}
}
// collectEventsFor reads events from ch for the given duration and returns them.
func collectEventsFor(ch <-chan []byte, d time.Duration) []string {
var out []string
deadline := time.After(d)
for {
select {
case data := <-ch:
out = append(out, string(data))
case <-deadline:
return out
}
}
}
// ---- context suppression ----
func TestSuppressSSEContext(t *testing.T) {
ctx := context.Background()
t.Run("default context is not suppressed", func(t *testing.T) {
assert.False(t, isSSESuppressed(ctx))
})
t.Run("withSuppressSSE marks context as suppressed", func(t *testing.T) {
assert.True(t, isSSESuppressed(withSuppressSSE(ctx)))
})
t.Run("suppression does not bleed into parent context", func(t *testing.T) {
suppressed := withSuppressSSE(ctx)
_ = suppressed
assert.False(t, isSSESuppressed(ctx))
})
}
// ---- CreateMemoComment double-broadcast fix ----
func TestCreateMemoComment_NoDuplicateSSEBroadcast(t *testing.T) {
ctx := context.Background()
svc := newIntegrationService(t)
// Create an admin so the store is initialised, then a regular commenter.
author, err := svc.Store.CreateUser(ctx, &store.User{
Username: "author", Role: store.RoleAdmin, Email: "author@example.com",
})
require.NoError(t, err)
commenter, err := svc.Store.CreateUser(ctx, &store.User{
Username: "commenter", Role: store.RoleUser, Email: "commenter@example.com",
})
require.NoError(t, err)
authorCtx := userCtx(ctx, author.ID)
commenterCtx := userCtx(ctx, commenter.ID)
// Create a public memo so the commenter can react.
parent, err := svc.CreateMemo(authorCtx, &v1pb.CreateMemoRequest{
Memo: &v1pb.Memo{Content: "parent memo", Visibility: v1pb.Visibility_PUBLIC},
})
require.NoError(t, err)
// Subscribe after the parent memo is created so the memo.created event
// for the parent does not pollute the assertion window.
client := svc.SSEHub.Subscribe(author.ID, store.RoleAdmin)
defer svc.SSEHub.Unsubscribe(client)
// Create a comment. Before the fix, this fired both memo.created (for the
// comment memo) and memo.comment.created (for the parent).
_, err = svc.CreateMemoComment(commenterCtx, &v1pb.CreateMemoCommentRequest{
Name: parent.Name,
Comment: &v1pb.Memo{Content: "a comment", Visibility: v1pb.Visibility_PUBLIC},
})
require.NoError(t, err)
// Give the synchronous broadcast a moment to land in the buffer, then
// collect everything that arrived.
events := collectEventsFor(client.events, 150*time.Millisecond)
require.Len(t, events, 1, "expected exactly one SSE event for a comment creation, got: %v", events)
assert.True(t, strings.Contains(events[0], `"memo.comment.created"`),
"expected memo.comment.created, got: %s", events[0])
}
// ---- Reaction SSE events carry correct visibility / parent fields ----
func TestUpsertMemoReaction_SSEEvent(t *testing.T) {
ctx := context.Background()
svc := newIntegrationService(t)
user, err := svc.Store.CreateUser(ctx, &store.User{
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
})
require.NoError(t, err)
uctx := userCtx(ctx, user.ID)
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC},
})
require.NoError(t, err)
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
defer svc.SSEHub.Unsubscribe(client)
_, err = svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &v1pb.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
data := mustReceive(t, client.events, time.Second)
payload := string(data)
assert.Contains(t, payload, `"reaction.upserted"`)
assert.Contains(t, payload, memo.Name)
mustNotReceive(t, client.events, 100*time.Millisecond)
}
func TestDeleteMemoReaction_SSEEvent(t *testing.T) {
ctx := context.Background()
svc := newIntegrationService(t)
user, err := svc.Store.CreateUser(ctx, &store.User{
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
})
require.NoError(t, err)
uctx := userCtx(ctx, user.ID)
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC},
})
require.NoError(t, err)
reaction, err := svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &v1pb.Reaction{
ContentId: memo.Name,
ReactionType: "❤️",
},
})
require.NoError(t, err)
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
defer svc.SSEHub.Unsubscribe(client)
_, err = svc.DeleteMemoReaction(uctx, &v1pb.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.NoError(t, err)
data := mustReceive(t, client.events, time.Second)
payload := string(data)
assert.Contains(t, payload, `"reaction.deleted"`)
assert.Contains(t, payload, memo.Name)
mustNotReceive(t, client.events, 100*time.Millisecond)
}

View File

@ -114,9 +114,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
AllowOrigins: []string{"*"},
}))
// Register SSE endpoint with same CORS as rest of /api/v1.
gwGroup.GET("/api/v1/sse", func(c *echo.Context) error {
return handleSSE(c, s.SSEHub, auth.NewAuthenticator(s.Store, s.Secret))
})
RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret)
handler := echo.WrapHandler(gwMux)
gwGroup.Any("/api/v1/*", handler)

View File

@ -77,6 +77,7 @@ const MemoView: React.FC<MemoViewProps> = (props: MemoViewProps) => {
className="mb-2"
cacheKey={`inline-memo-editor-${memoData.name}`}
memo={memoData}
parentMemoName={memoData.parent || undefined}
onConfirm={closeEditor}
onCancel={closeEditor}
/>

View File

@ -12,6 +12,15 @@ const INITIAL_RETRY_DELAY_MS = 1000;
const MAX_RETRY_DELAY_MS = 30000;
const RETRY_BACKOFF_MULTIPLIER = 2;
const SSE_EVENT_TYPES = {
memoCreated: "memo.created",
memoUpdated: "memo.updated",
memoDeleted: "memo.deleted",
memoCommentCreated: "memo.comment.created",
reactionUpserted: "reaction.upserted",
reactionDeleted: "reaction.deleted",
} as const;
// ---------------------------------------------------------------------------
// Shared connection status store (singleton)
// ---------------------------------------------------------------------------
@ -63,6 +72,7 @@ export function useLiveMemoRefresh() {
const { currentUser } = useAuth();
const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS);
const abortControllerRef = useRef<AbortController | null>(null);
const hasConnectedOnceRef = useRef(false);
const currentUserName = currentUser?.name;
const handleEvent = useCallback((event: SSEChangeEvent) => handleSSEEvent(event, queryClient), [queryClient]);
@ -101,6 +111,13 @@ export function useLiveMemoRefresh() {
// Successfully connected - reset retry delay.
retryDelayRef.current = INITIAL_RETRY_DELAY_MS;
setSSEStatus("connected");
if (hasConnectedOnceRef.current) {
// Resync active collaborative views after reconnect because the server may have
// dropped events while the client was disconnected or backpressured.
queryClient.invalidateQueries({ queryKey: memoKeys.all, refetchType: "active" });
queryClient.invalidateQueries({ queryKey: userKeys.stats(), refetchType: "active" });
}
hasConnectedOnceRef.current = true;
const reader = response.body.getReader();
const decoder = new TextDecoder();
@ -175,37 +192,44 @@ export function useLiveMemoRefresh() {
// ---------------------------------------------------------------------------
interface SSEChangeEvent {
type: string;
type: (typeof SSE_EVENT_TYPES)[keyof typeof SSE_EVENT_TYPES];
name: string;
parent?: string;
}
function handleSSEEvent(event: SSEChangeEvent, queryClient: ReturnType<typeof useQueryClient>) {
switch (event.type) {
case "memo.created":
case SSE_EVENT_TYPES.memoCreated:
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
break;
case "memo.updated":
case SSE_EVENT_TYPES.memoUpdated:
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
if (event.parent) {
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) });
}
break;
case "memo.deleted":
case SSE_EVENT_TYPES.memoDeleted:
queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) });
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
break;
case "memo.comment.created":
case SSE_EVENT_TYPES.memoCommentCreated:
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.name) });
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
break;
case "reaction.upserted":
case "reaction.deleted":
case SSE_EVENT_TYPES.reactionUpserted:
case SSE_EVENT_TYPES.reactionDeleted:
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
if (event.parent) {
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) });
}
break;
}
}

View File

@ -117,6 +117,9 @@ export function useUpdateMemo() {
queryClient.setQueryData(memoKeys.detail(updatedMemo.name), updatedMemo);
// Invalidate lists to refresh
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
if (updatedMemo.parent) {
queryClient.invalidateQueries({ queryKey: memoKeys.comments(updatedMemo.parent) });
}
// Invalidate user stats
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
},