mirror of https://github.com/usememos/memos.git
feat: add live refresh via Server-Sent Events (SSE) with visual indicator (#5638)
Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: milvasic <milvasic@users.noreply.github.com>
This commit is contained in:
parent
a69e405c95
commit
ea0892a8b2
|
|
@ -141,6 +141,12 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
|||
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoCreated,
|
||||
Name: memoMessage.Name,
|
||||
})
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
|
|
@ -471,6 +477,12 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: memoMessage.Name,
|
||||
})
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
|
|
@ -539,6 +551,12 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
|||
return nil, status.Errorf(codes.Internal, "failed to delete memo")
|
||||
}
|
||||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoDeleted,
|
||||
Name: request.Name,
|
||||
})
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -97,6 +97,12 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
|
|||
|
||||
reactionMessage := convertReactionFromStore(reaction)
|
||||
|
||||
// Broadcast live refresh event (reaction belongs to a memo).
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventReactionUpserted,
|
||||
Name: request.Reaction.ContentId,
|
||||
})
|
||||
|
||||
return reactionMessage, nil
|
||||
}
|
||||
|
||||
|
|
@ -136,6 +142,12 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del
|
|||
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
|
||||
}
|
||||
|
||||
// Broadcast live refresh event (reaction belongs to a memo).
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventReactionDeleted,
|
||||
Name: reaction.ContentID,
|
||||
})
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// sseHeartbeatInterval is the interval between heartbeat pings to keep the connection alive.
|
||||
sseHeartbeatInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
// RegisterSSERoutes registers the SSE endpoint on the given Echo instance.
|
||||
func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) {
|
||||
authenticator := auth.NewAuthenticator(storeInstance, secret)
|
||||
echoServer.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
return handleSSE(c, hub, authenticator)
|
||||
})
|
||||
}
|
||||
|
||||
// handleSSE handles the SSE connection for live memo refresh.
|
||||
// Authentication is done via Bearer token in the Authorization header,
|
||||
// or via the "token" query parameter (for EventSource which cannot set headers).
|
||||
func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator) error {
|
||||
// Authenticate the request.
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// Fall back to query parameter for native EventSource support.
|
||||
if token := c.QueryParam("token"); token != "" {
|
||||
authHeader = "Bearer " + token
|
||||
}
|
||||
}
|
||||
|
||||
result := authenticator.Authenticate(c.Request().Context(), authHeader)
|
||||
if result == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
|
||||
// Set SSE headers.
|
||||
w := c.Response()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no") // Disable nginx buffering
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Flush headers immediately.
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// Subscribe to the hub.
|
||||
client := hub.Subscribe()
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
// Create a ticker for heartbeat pings.
|
||||
heartbeat := time.NewTicker(sseHeartbeatInterval)
|
||||
defer heartbeat.Stop()
|
||||
|
||||
ctx := c.Request().Context()
|
||||
|
||||
slog.Debug("SSE client connected")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Client disconnected.
|
||||
slog.Debug("SSE client disconnected")
|
||||
return nil
|
||||
|
||||
case data, ok := <-client.events:
|
||||
if !ok {
|
||||
// Channel closed, client was unsubscribed.
|
||||
return nil
|
||||
}
|
||||
// Write SSE event.
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
|
||||
return nil
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
case <-heartbeat.C:
|
||||
// Send a heartbeat comment to keep the connection alive.
|
||||
if _, err := fmt.Fprint(w, ": heartbeat\n\n"); err != nil {
|
||||
return nil
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SSEEventType represents the type of change event.
|
||||
type SSEEventType string
|
||||
|
||||
const (
|
||||
SSEEventMemoCreated SSEEventType = "memo.created"
|
||||
SSEEventMemoUpdated SSEEventType = "memo.updated"
|
||||
SSEEventMemoDeleted SSEEventType = "memo.deleted"
|
||||
SSEEventReactionUpserted SSEEventType = "reaction.upserted"
|
||||
SSEEventReactionDeleted SSEEventType = "reaction.deleted"
|
||||
)
|
||||
|
||||
// SSEEvent represents a change event sent to SSE clients.
|
||||
type SSEEvent struct {
|
||||
Type SSEEventType `json:"type"`
|
||||
// Name is the affected resource name (e.g., "memos/xxxx").
|
||||
// For reaction events, this is the memo resource name that the reaction belongs to.
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// JSON returns the JSON representation of the event.
|
||||
// Returns nil if marshaling fails (error is logged).
|
||||
func (e *SSEEvent) JSON() []byte {
|
||||
data, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal SSE event", "err", err, "event", e)
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// SSEClient represents a single SSE connection.
|
||||
type SSEClient struct {
|
||||
events chan []byte
|
||||
}
|
||||
|
||||
// SSEHub manages SSE client connections and broadcasts events.
|
||||
// It is safe for concurrent use.
|
||||
type SSEHub struct {
|
||||
mu sync.RWMutex
|
||||
clients map[*SSEClient]struct{}
|
||||
}
|
||||
|
||||
// NewSSEHub creates a new SSE hub.
|
||||
func NewSSEHub() *SSEHub {
|
||||
return &SSEHub{
|
||||
clients: make(map[*SSEClient]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers a new client and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (h *SSEHub) Subscribe() *SSEClient {
|
||||
c := &SSEClient{
|
||||
// Buffer a few events so a slow client doesn't block broadcasting.
|
||||
events: make(chan []byte, 32),
|
||||
}
|
||||
h.mu.Lock()
|
||||
h.clients[c] = struct{}{}
|
||||
h.mu.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// Unsubscribe removes a client and closes its channel.
|
||||
func (h *SSEHub) Unsubscribe(c *SSEClient) {
|
||||
h.mu.Lock()
|
||||
if _, ok := h.clients[c]; ok {
|
||||
delete(h.clients, c)
|
||||
close(c.events)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// Broadcast sends an event to all connected clients.
|
||||
// Slow clients that have a full buffer will have the event dropped
|
||||
// to avoid blocking the broadcaster.
|
||||
func (h *SSEHub) Broadcast(event *SSEEvent) {
|
||||
data := event.JSON()
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
for c := range h.clients {
|
||||
select {
|
||||
case c.events <- data:
|
||||
default:
|
||||
// Drop event for slow client to avoid blocking.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
|
||||
client := hub.Subscribe()
|
||||
require.NotNil(t, client)
|
||||
require.NotNil(t, client.events)
|
||||
|
||||
// Unsubscribe removes the client and closes the channel.
|
||||
hub.Unsubscribe(client)
|
||||
|
||||
// Channel should be closed.
|
||||
_, ok := <-client.events
|
||||
assert.False(t, ok, "channel should be closed after Unsubscribe")
|
||||
}
|
||||
|
||||
func TestSSEHub_Broadcast(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
client := hub.Subscribe()
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"}
|
||||
hub.Broadcast(event)
|
||||
|
||||
select {
|
||||
case data := <-client.events:
|
||||
assert.Contains(t, string(data), `"type":"memo.created"`)
|
||||
assert.Contains(t, string(data), `"name":"memos/123"`)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected to receive event within 1s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
c1 := hub.Subscribe()
|
||||
defer hub.Unsubscribe(c1)
|
||||
c2 := hub.Subscribe()
|
||||
defer hub.Unsubscribe(c2)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoDeleted, Name: "memos/456"}
|
||||
hub.Broadcast(event)
|
||||
|
||||
for _, ch := range []chan []byte{c1.events, c2.events} {
|
||||
select {
|
||||
case data := <-ch:
|
||||
assert.Contains(t, string(data), "memo.deleted")
|
||||
assert.Contains(t, string(data), "memos/456")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected to receive event within 1s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEEvent_JSON(t *testing.T) {
|
||||
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789"}
|
||||
data := e.JSON()
|
||||
require.NotEmpty(t, data)
|
||||
assert.Contains(t, string(data), `"type":"memo.updated"`)
|
||||
assert.Contains(t, string(data), `"name":"memos/789"`)
|
||||
}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
)
|
||||
|
||||
func TestSSEHandler_Authentication(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "sse-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
token, _, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte(ts.Secret),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
e := echo.New()
|
||||
apiv1.RegisterSSERoutes(e, ts.Service.SSEHub, ts.Store, ts.Secret)
|
||||
|
||||
t.Run("no token returns 401", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
})
|
||||
|
||||
t.Run("invalid token returns 401", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
})
|
||||
|
||||
t.Run("valid token returns 200 and stream", func(t *testing.T) {
|
||||
// Use a cancellable context so we can close the SSE connection after
|
||||
// confirming the headers, preventing the handler's event loop from
|
||||
// blocking the test indefinitely.
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil).WithContext(reqCtx)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rec := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
e.ServeHTTP(rec, req)
|
||||
}()
|
||||
// Cancel the context to signal client disconnect, which exits the SSE loop.
|
||||
cancel()
|
||||
<-done
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "text/event-stream", rec.Header().Get("Content-Type"))
|
||||
})
|
||||
|
||||
t.Run("token in query param returns 200", func(t *testing.T) {
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse?token="+token, nil).WithContext(reqCtx)
|
||||
rec := httptest.NewRecorder()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
e.ServeHTTP(rec, req)
|
||||
}()
|
||||
cancel()
|
||||
<-done
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
})
|
||||
}
|
||||
|
|
@ -46,6 +46,7 @@ func NewTestService(t *testing.T) *TestService {
|
|||
Profile: testProfile,
|
||||
Store: testStore,
|
||||
MarkdownService: markdownService,
|
||||
SSEHub: apiv1.NewSSEHub(),
|
||||
}
|
||||
|
||||
return &TestService{
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ type APIV1Service struct {
|
|||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
MarkdownService markdown.Service
|
||||
SSEHub *SSEHub
|
||||
|
||||
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
|
||||
thumbnailSemaphore *semaphore.Weighted
|
||||
|
|
@ -45,6 +46,7 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
|
|||
Profile: profile,
|
||||
Store: store,
|
||||
MarkdownService: markdownService,
|
||||
SSEHub: NewSSEHub(),
|
||||
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
|
||||
}
|
||||
}
|
||||
|
|
@ -115,6 +117,10 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
|
|||
gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
}))
|
||||
// Register SSE endpoint with same CORS as rest of /api/v1.
|
||||
gwGroup.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
return handleSSE(c, s.SSEHub, auth.NewAuthenticator(s.Store, s.Secret))
|
||||
})
|
||||
handler := echo.WrapHandler(gwMux)
|
||||
|
||||
gwGroup.Any("/api/v1/*", handler)
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ Parses data URI to extract MIME type and base64 data.
|
|||
## Dependencies
|
||||
|
||||
### External Packages
|
||||
- `github.com/labstack/echo/v4` - HTTP router and middleware
|
||||
- `github.com/labstack/echo/v5` - HTTP router and middleware
|
||||
- `github.com/golang-jwt/jwt/v5` - JWT parsing and validation
|
||||
- `github.com/disintegration/imaging` - Image thumbnail generation
|
||||
- `golang.org/x/sync/semaphore` - Concurrency control for thumbnails
|
||||
|
|
|
|||
|
|
@ -75,7 +75,8 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
|
|||
|
||||
// Create and register RSS routes (needs markdown service from apiV1Service).
|
||||
rss.NewRSSService(s.Profile, s.Store, apiV1Service.MarkdownService).RegisterRoutes(rootGroup)
|
||||
// Register gRPC gateway as api v1.
|
||||
|
||||
// Register gRPC gateway as api v1 (includes SSE endpoint on CORS-enabled group).
|
||||
if err := apiV1Service.RegisterGateway(ctx, echoServer); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to register gRPC gateway")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import { Routes } from "@/router";
|
|||
import { UserNotification_Status } from "@/types/proto/api/v1/user_service_pb";
|
||||
import { useTranslate } from "@/utils/i18n";
|
||||
import MemosLogo from "./MemosLogo";
|
||||
import SSEStatusIndicator from "./SSEStatusIndicator";
|
||||
import UserMenu from "./UserMenu";
|
||||
|
||||
interface NavLinkItem {
|
||||
|
|
@ -114,7 +115,10 @@ const Navigation = (props: Props) => {
|
|||
))}
|
||||
</div>
|
||||
{currentUser && (
|
||||
<div className={cn("w-full flex flex-col justify-end", props.collapsed ? "items-center" : "items-start pl-3")}>
|
||||
<div className={cn("w-full flex flex-col justify-end gap-1", props.collapsed ? "items-center" : "items-start pl-3")}>
|
||||
<div className={cn("flex items-center", props.collapsed ? "justify-center" : "pl-1")}>
|
||||
<SSEStatusIndicator />
|
||||
</div>
|
||||
<UserMenu collapsed={collapsed} />
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
import { useSSEConnectionStatus } from "@/hooks/useLiveMemoRefresh";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "./ui/tooltip";
|
||||
|
||||
/**
|
||||
* A small colored dot that indicates the SSE live-refresh connection status.
|
||||
* - Green = connected (live updates active)
|
||||
* - Yellow/pulsing = connecting
|
||||
* - Red = disconnected (updates not live)
|
||||
*/
|
||||
const SSEStatusIndicator = () => {
|
||||
const status = useSSEConnectionStatus();
|
||||
|
||||
const label =
|
||||
status === "connected" ? "Live updates active" : status === "connecting" ? "Connecting to live updates..." : "Live updates unavailable";
|
||||
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="inline-flex items-center justify-center size-5 cursor-default" aria-label={label}>
|
||||
<span
|
||||
className={cn(
|
||||
"block size-2 rounded-full transition-colors",
|
||||
status === "connected" && "bg-green-500",
|
||||
status === "connecting" && "bg-yellow-500 animate-pulse",
|
||||
status === "disconnected" && "bg-red-500",
|
||||
)}
|
||||
/>
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{label}</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default SSEStatusIndicator;
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useRef, useSyncExternalStore } from "react";
|
||||
import { getAccessToken } from "@/auth-state";
|
||||
import { useAuth } from "@/contexts/AuthContext";
|
||||
import { memoKeys } from "@/hooks/useMemoQueries";
|
||||
import { userKeys } from "@/hooks/useUserQueries";
|
||||
|
||||
/**
|
||||
* Reconnection parameters for SSE connection.
|
||||
*/
|
||||
const INITIAL_RETRY_DELAY_MS = 1000;
|
||||
const MAX_RETRY_DELAY_MS = 30000;
|
||||
const RETRY_BACKOFF_MULTIPLIER = 2;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared connection status store (singleton)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export type SSEConnectionStatus = "connected" | "disconnected" | "connecting";
|
||||
|
||||
type Listener = () => void;
|
||||
|
||||
let _status: SSEConnectionStatus = "disconnected";
|
||||
const _listeners = new Set<Listener>();
|
||||
|
||||
function getSSEStatus(): SSEConnectionStatus {
|
||||
return _status;
|
||||
}
|
||||
|
||||
function setSSEStatus(s: SSEConnectionStatus) {
|
||||
if (_status !== s) {
|
||||
_status = s;
|
||||
_listeners.forEach((l) => l());
|
||||
}
|
||||
}
|
||||
|
||||
function subscribeSSEStatus(listener: Listener): () => void {
|
||||
_listeners.add(listener);
|
||||
return () => _listeners.delete(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* React hook that returns the current SSE connection status.
|
||||
* Re-renders the component whenever the status changes.
|
||||
*/
|
||||
export function useSSEConnectionStatus(): SSEConnectionStatus {
|
||||
return useSyncExternalStore(subscribeSSEStatus, getSSEStatus, getSSEStatus);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main hook
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* useLiveMemoRefresh connects to the server's SSE endpoint and
|
||||
* invalidates relevant React Query caches when change events
|
||||
* (memos, reactions) are received.
|
||||
*
|
||||
* This enables real-time updates across all open instances of the app.
|
||||
*/
|
||||
export function useLiveMemoRefresh() {
|
||||
const queryClient = useQueryClient();
|
||||
const { currentUser } = useAuth();
|
||||
const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
const handleEvent = useCallback((event: SSEChangeEvent) => handleSSEEvent(event, queryClient), [queryClient]);
|
||||
|
||||
useEffect(() => {
|
||||
let mounted = true;
|
||||
let retryTimeout: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
const connect = async () => {
|
||||
if (!mounted) return;
|
||||
|
||||
const token = getAccessToken();
|
||||
if (!token) {
|
||||
setSSEStatus("disconnected");
|
||||
// Not logged in; do not retry. Effect will re-run when currentUser is set (login).
|
||||
return;
|
||||
}
|
||||
|
||||
setSSEStatus("connecting");
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/v1/sse", {
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
signal: abortController.signal,
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(`SSE connection failed: ${response.status}`);
|
||||
}
|
||||
|
||||
// Successfully connected - reset retry delay.
|
||||
retryDelayRef.current = INITIAL_RETRY_DELAY_MS;
|
||||
setSSEStatus("connected");
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (mounted) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
// Process complete SSE messages (separated by double newlines).
|
||||
const messages = buffer.split("\n\n");
|
||||
// Keep the last incomplete chunk in the buffer.
|
||||
buffer = messages.pop() || "";
|
||||
|
||||
for (const message of messages) {
|
||||
if (!message.trim()) continue;
|
||||
|
||||
// Parse SSE format: lines starting with "data: " contain JSON payload.
|
||||
// Lines starting with ":" are comments (heartbeats).
|
||||
for (const line of message.split("\n")) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const jsonStr = line.slice(6);
|
||||
try {
|
||||
const event = JSON.parse(jsonStr) as SSEChangeEvent;
|
||||
handleEvent(event);
|
||||
} catch {
|
||||
// Ignore malformed JSON.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
if (err instanceof DOMException && err.name === "AbortError") {
|
||||
// Intentional abort, don't reconnect.
|
||||
setSSEStatus("disconnected");
|
||||
return;
|
||||
}
|
||||
// Connection lost or failed - reconnect with backoff.
|
||||
}
|
||||
|
||||
setSSEStatus("disconnected");
|
||||
|
||||
// Reconnect with exponential backoff.
|
||||
if (mounted) {
|
||||
const delay = retryDelayRef.current;
|
||||
retryDelayRef.current = Math.min(delay * RETRY_BACKOFF_MULTIPLIER, MAX_RETRY_DELAY_MS);
|
||||
retryTimeout = setTimeout(connect, delay);
|
||||
}
|
||||
};
|
||||
|
||||
connect();
|
||||
|
||||
return () => {
|
||||
mounted = false;
|
||||
setSSEStatus("disconnected");
|
||||
if (retryTimeout) {
|
||||
clearTimeout(retryTimeout);
|
||||
}
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
}
|
||||
};
|
||||
}, [handleEvent, currentUser]);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Event handling
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface SSEChangeEvent {
|
||||
type: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
function handleSSEEvent(event: SSEChangeEvent, queryClient: ReturnType<typeof useQueryClient>) {
|
||||
switch (event.type) {
|
||||
case "memo.created":
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
|
||||
break;
|
||||
|
||||
case "memo.updated":
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
break;
|
||||
|
||||
case "memo.deleted":
|
||||
queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
|
||||
break;
|
||||
|
||||
case "reaction.upserted":
|
||||
case "reaction.deleted":
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -12,6 +12,7 @@ import { refreshAccessToken } from "@/connect";
|
|||
import { AuthProvider, useAuth } from "@/contexts/AuthContext";
|
||||
import { InstanceProvider, useInstance } from "@/contexts/InstanceContext";
|
||||
import { ViewProvider } from "@/contexts/ViewContext";
|
||||
import { useLiveMemoRefresh } from "@/hooks/useLiveMemoRefresh";
|
||||
import { useTokenRefreshOnFocus } from "@/hooks/useTokenRefreshOnFocus";
|
||||
import { queryClient } from "@/lib/query-client";
|
||||
import router from "./router";
|
||||
|
|
@ -46,6 +47,9 @@ function AppInitializer({ children }: { children: React.ReactNode }) {
|
|||
// Related: https://github.com/usememos/memos/issues/5589
|
||||
useTokenRefreshOnFocus(refreshAccessToken, !!currentUser);
|
||||
|
||||
// Live refresh: listen for memo changes via SSE and invalidate caches.
|
||||
useLiveMemoRefresh();
|
||||
|
||||
if (!authInitialized || !instanceInitialized) {
|
||||
return null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,12 @@ export default defineConfig({
|
|||
host: "0.0.0.0",
|
||||
port: 3001,
|
||||
proxy: {
|
||||
"^/api/v1/sse": {
|
||||
target: devProxyServer,
|
||||
xfwd: true,
|
||||
// SSE requires no response buffering and longer timeout.
|
||||
timeout: 0,
|
||||
},
|
||||
"^/api": {
|
||||
target: devProxyServer,
|
||||
xfwd: true,
|
||||
|
|
|
|||
Loading…
Reference in New Issue