diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index f7f8e7622..3c9aa1ecd 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -141,6 +141,12 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err)) } + // Broadcast live refresh event. + s.SSEHub.Broadcast(&SSEEvent{ + Type: SSEEventMemoCreated, + Name: memoMessage.Name, + }) + return memoMessage, nil } @@ -471,6 +477,12 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err)) } + // Broadcast live refresh event. + s.SSEHub.Broadcast(&SSEEvent{ + Type: SSEEventMemoUpdated, + Name: memoMessage.Name, + }) + return memoMessage, nil } @@ -539,6 +551,12 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR return nil, status.Errorf(codes.Internal, "failed to delete memo") } + // Broadcast live refresh event. + s.SSEHub.Broadcast(&SSEEvent{ + Type: SSEEventMemoDeleted, + Name: request.Name, + }) + return &emptypb.Empty{}, nil } diff --git a/server/router/api/v1/reaction_service.go b/server/router/api/v1/reaction_service.go index a7c7cc3bd..a4c521fe8 100644 --- a/server/router/api/v1/reaction_service.go +++ b/server/router/api/v1/reaction_service.go @@ -97,6 +97,12 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups reactionMessage := convertReactionFromStore(reaction) + // Broadcast live refresh event (reaction belongs to a memo). + s.SSEHub.Broadcast(&SSEEvent{ + Type: SSEEventReactionUpserted, + Name: request.Reaction.ContentId, + }) + return reactionMessage, nil } @@ -136,6 +142,12 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del return nil, status.Errorf(codes.Internal, "failed to delete reaction") } + // Broadcast live refresh event (reaction belongs to a memo). + s.SSEHub.Broadcast(&SSEEvent{ + Type: SSEEventReactionDeleted, + Name: reaction.ContentID, + }) + return &emptypb.Empty{}, nil } diff --git a/server/router/api/v1/sse_handler.go b/server/router/api/v1/sse_handler.go new file mode 100644 index 000000000..09379804c --- /dev/null +++ b/server/router/api/v1/sse_handler.go @@ -0,0 +1,101 @@ +package v1 + +import ( + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/labstack/echo/v5" + + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" +) + +const ( + // sseHeartbeatInterval is the interval between heartbeat pings to keep the connection alive. + sseHeartbeatInterval = 30 * time.Second +) + +// RegisterSSERoutes registers the SSE endpoint on the given Echo instance. +func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) { + authenticator := auth.NewAuthenticator(storeInstance, secret) + echoServer.GET("/api/v1/sse", func(c *echo.Context) error { + return handleSSE(c, hub, authenticator) + }) +} + +// handleSSE handles the SSE connection for live memo refresh. +// Authentication is done via Bearer token in the Authorization header, +// or via the "token" query parameter (for EventSource which cannot set headers). +func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator) error { + // Authenticate the request. + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + // Fall back to query parameter for native EventSource support. + if token := c.QueryParam("token"); token != "" { + authHeader = "Bearer " + token + } + } + + result := authenticator.Authenticate(c.Request().Context(), authHeader) + if result == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"}) + } + + // Set SSE headers. + w := c.Response() + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") // Disable nginx buffering + w.WriteHeader(http.StatusOK) + + // Flush headers immediately. + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + // Subscribe to the hub. + client := hub.Subscribe() + defer hub.Unsubscribe(client) + + // Create a ticker for heartbeat pings. + heartbeat := time.NewTicker(sseHeartbeatInterval) + defer heartbeat.Stop() + + ctx := c.Request().Context() + + slog.Debug("SSE client connected") + + for { + select { + case <-ctx.Done(): + // Client disconnected. + slog.Debug("SSE client disconnected") + return nil + + case data, ok := <-client.events: + if !ok { + // Channel closed, client was unsubscribed. + return nil + } + // Write SSE event. + if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { + return nil + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + case <-heartbeat.C: + // Send a heartbeat comment to keep the connection alive. + if _, err := fmt.Fprint(w, ": heartbeat\n\n"); err != nil { + return nil + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + } +} diff --git a/server/router/api/v1/sse_hub.go b/server/router/api/v1/sse_hub.go new file mode 100644 index 000000000..245ac7b56 --- /dev/null +++ b/server/router/api/v1/sse_hub.go @@ -0,0 +1,98 @@ +package v1 + +import ( + "encoding/json" + "log/slog" + "sync" +) + +// SSEEventType represents the type of change event. +type SSEEventType string + +const ( + SSEEventMemoCreated SSEEventType = "memo.created" + SSEEventMemoUpdated SSEEventType = "memo.updated" + SSEEventMemoDeleted SSEEventType = "memo.deleted" + SSEEventReactionUpserted SSEEventType = "reaction.upserted" + SSEEventReactionDeleted SSEEventType = "reaction.deleted" +) + +// SSEEvent represents a change event sent to SSE clients. +type SSEEvent struct { + Type SSEEventType `json:"type"` + // Name is the affected resource name (e.g., "memos/xxxx"). + // For reaction events, this is the memo resource name that the reaction belongs to. + Name string `json:"name"` +} + +// JSON returns the JSON representation of the event. +// Returns nil if marshaling fails (error is logged). +func (e *SSEEvent) JSON() []byte { + data, err := json.Marshal(e) + if err != nil { + slog.Error("failed to marshal SSE event", "err", err, "event", e) + return nil + } + return data +} + +// SSEClient represents a single SSE connection. +type SSEClient struct { + events chan []byte +} + +// SSEHub manages SSE client connections and broadcasts events. +// It is safe for concurrent use. +type SSEHub struct { + mu sync.RWMutex + clients map[*SSEClient]struct{} +} + +// NewSSEHub creates a new SSE hub. +func NewSSEHub() *SSEHub { + return &SSEHub{ + clients: make(map[*SSEClient]struct{}), + } +} + +// Subscribe registers a new client and returns it. +// The caller must call Unsubscribe when done. +func (h *SSEHub) Subscribe() *SSEClient { + c := &SSEClient{ + // Buffer a few events so a slow client doesn't block broadcasting. + events: make(chan []byte, 32), + } + h.mu.Lock() + h.clients[c] = struct{}{} + h.mu.Unlock() + return c +} + +// Unsubscribe removes a client and closes its channel. +func (h *SSEHub) Unsubscribe(c *SSEClient) { + h.mu.Lock() + if _, ok := h.clients[c]; ok { + delete(h.clients, c) + close(c.events) + } + h.mu.Unlock() +} + +// Broadcast sends an event to all connected clients. +// Slow clients that have a full buffer will have the event dropped +// to avoid blocking the broadcaster. +func (h *SSEHub) Broadcast(event *SSEEvent) { + data := event.JSON() + if len(data) == 0 { + return + } + h.mu.RLock() + defer h.mu.RUnlock() + for c := range h.clients { + select { + case c.events <- data: + default: + // Drop event for slow client to avoid blocking. + } + } +} diff --git a/server/router/api/v1/sse_hub_test.go b/server/router/api/v1/sse_hub_test.go new file mode 100644 index 000000000..d7d6d2c8a --- /dev/null +++ b/server/router/api/v1/sse_hub_test.go @@ -0,0 +1,70 @@ +package v1 + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSEHub_SubscribeUnsubscribe(t *testing.T) { + hub := NewSSEHub() + + client := hub.Subscribe() + require.NotNil(t, client) + require.NotNil(t, client.events) + + // Unsubscribe removes the client and closes the channel. + hub.Unsubscribe(client) + + // Channel should be closed. + _, ok := <-client.events + assert.False(t, ok, "channel should be closed after Unsubscribe") +} + +func TestSSEHub_Broadcast(t *testing.T) { + hub := NewSSEHub() + client := hub.Subscribe() + defer hub.Unsubscribe(client) + + event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"} + hub.Broadcast(event) + + select { + case data := <-client.events: + assert.Contains(t, string(data), `"type":"memo.created"`) + assert.Contains(t, string(data), `"name":"memos/123"`) + case <-time.After(time.Second): + t.Fatal("expected to receive event within 1s") + } +} + +func TestSSEHub_BroadcastMultipleClients(t *testing.T) { + hub := NewSSEHub() + c1 := hub.Subscribe() + defer hub.Unsubscribe(c1) + c2 := hub.Subscribe() + defer hub.Unsubscribe(c2) + + event := &SSEEvent{Type: SSEEventMemoDeleted, Name: "memos/456"} + hub.Broadcast(event) + + for _, ch := range []chan []byte{c1.events, c2.events} { + select { + case data := <-ch: + assert.Contains(t, string(data), "memo.deleted") + assert.Contains(t, string(data), "memos/456") + case <-time.After(time.Second): + t.Fatal("expected to receive event within 1s") + } + } +} + +func TestSSEEvent_JSON(t *testing.T) { + e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789"} + data := e.JSON() + require.NotEmpty(t, data) + assert.Contains(t, string(data), `"type":"memo.updated"`) + assert.Contains(t, string(data), `"name":"memos/789"`) +} diff --git a/server/router/api/v1/test/sse_handler_test.go b/server/router/api/v1/test/sse_handler_test.go new file mode 100644 index 000000000..d182b279a --- /dev/null +++ b/server/router/api/v1/test/sse_handler_test.go @@ -0,0 +1,86 @@ +package test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/server/auth" + apiv1 "github.com/usememos/memos/server/router/api/v1" +) + +func TestSSEHandler_Authentication(t *testing.T) { + ctx := context.Background() + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "sse-user") + require.NoError(t, err) + + token, _, err := auth.GenerateAccessTokenV2( + user.ID, + user.Username, + string(user.Role), + string(user.RowStatus), + []byte(ts.Secret), + ) + require.NoError(t, err) + + e := echo.New() + apiv1.RegisterSSERoutes(e, ts.Service.SSEHub, ts.Store, ts.Secret) + + t.Run("no token returns 401", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + }) + + t.Run("invalid token returns 401", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + }) + + t.Run("valid token returns 200 and stream", func(t *testing.T) { + // Use a cancellable context so we can close the SSE connection after + // confirming the headers, preventing the handler's event loop from + // blocking the test indefinitely. + reqCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil).WithContext(reqCtx) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + done := make(chan struct{}) + go func() { + defer close(done) + e.ServeHTTP(rec, req) + }() + // Cancel the context to signal client disconnect, which exits the SSE loop. + cancel() + <-done + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "text/event-stream", rec.Header().Get("Content-Type")) + }) + + t.Run("token in query param returns 200", func(t *testing.T) { + reqCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + req := httptest.NewRequest(http.MethodGet, "/api/v1/sse?token="+token, nil).WithContext(reqCtx) + rec := httptest.NewRecorder() + done := make(chan struct{}) + go func() { + defer close(done) + e.ServeHTTP(rec, req) + }() + cancel() + <-done + require.Equal(t, http.StatusOK, rec.Code) + }) +} diff --git a/server/router/api/v1/test/test_helper.go b/server/router/api/v1/test/test_helper.go index 779ad2eea..c3afdb38b 100644 --- a/server/router/api/v1/test/test_helper.go +++ b/server/router/api/v1/test/test_helper.go @@ -46,6 +46,7 @@ func NewTestService(t *testing.T) *TestService { Profile: testProfile, Store: testStore, MarkdownService: markdownService, + SSEHub: apiv1.NewSSEHub(), } return &TestService{ diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index cab335c81..79c140f07 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -31,6 +31,7 @@ type APIV1Service struct { Profile *profile.Profile Store *store.Store MarkdownService markdown.Service + SSEHub *SSEHub // thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion thumbnailSemaphore *semaphore.Weighted @@ -45,6 +46,7 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store Profile: profile, Store: store, MarkdownService: markdownService, + SSEHub: NewSSEHub(), thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations } } @@ -115,6 +117,10 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{"*"}, })) + // Register SSE endpoint with same CORS as rest of /api/v1. + gwGroup.GET("/api/v1/sse", func(c *echo.Context) error { + return handleSSE(c, s.SSEHub, auth.NewAuthenticator(s.Store, s.Secret)) + }) handler := echo.WrapHandler(gwMux) gwGroup.Any("/api/v1/*", handler) diff --git a/server/router/fileserver/README.md b/server/router/fileserver/README.md index 6e8a4e3f0..984d41eaa 100644 --- a/server/router/fileserver/README.md +++ b/server/router/fileserver/README.md @@ -184,7 +184,7 @@ Parses data URI to extract MIME type and base64 data. ## Dependencies ### External Packages -- `github.com/labstack/echo/v4` - HTTP router and middleware +- `github.com/labstack/echo/v5` - HTTP router and middleware - `github.com/golang-jwt/jwt/v5` - JWT parsing and validation - `github.com/disintegration/imaging` - Image thumbnail generation - `golang.org/x/sync/semaphore` - Concurrency control for thumbnails diff --git a/server/server.go b/server/server.go index 629ed01b8..593b73832 100644 --- a/server/server.go +++ b/server/server.go @@ -75,7 +75,8 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store // Create and register RSS routes (needs markdown service from apiV1Service). rss.NewRSSService(s.Profile, s.Store, apiV1Service.MarkdownService).RegisterRoutes(rootGroup) - // Register gRPC gateway as api v1. + + // Register gRPC gateway as api v1 (includes SSE endpoint on CORS-enabled group). if err := apiV1Service.RegisterGateway(ctx, echoServer); err != nil { return nil, errors.Wrap(err, "failed to register gRPC gateway") } diff --git a/web/src/components/Navigation.tsx b/web/src/components/Navigation.tsx index 484f0837b..faf07c55b 100644 --- a/web/src/components/Navigation.tsx +++ b/web/src/components/Navigation.tsx @@ -8,6 +8,7 @@ import { Routes } from "@/router"; import { UserNotification_Status } from "@/types/proto/api/v1/user_service_pb"; import { useTranslate } from "@/utils/i18n"; import MemosLogo from "./MemosLogo"; +import SSEStatusIndicator from "./SSEStatusIndicator"; import UserMenu from "./UserMenu"; interface NavLinkItem { @@ -114,7 +115,10 @@ const Navigation = (props: Props) => { ))} {currentUser && ( -
+
+
+ +
)} diff --git a/web/src/components/SSEStatusIndicator.tsx b/web/src/components/SSEStatusIndicator.tsx new file mode 100644 index 000000000..ca4df407d --- /dev/null +++ b/web/src/components/SSEStatusIndicator.tsx @@ -0,0 +1,36 @@ +import { useSSEConnectionStatus } from "@/hooks/useLiveMemoRefresh"; +import { cn } from "@/lib/utils"; +import { Tooltip, TooltipContent, TooltipTrigger } from "./ui/tooltip"; + +/** + * A small colored dot that indicates the SSE live-refresh connection status. + * - Green = connected (live updates active) + * - Yellow/pulsing = connecting + * - Red = disconnected (updates not live) + */ +const SSEStatusIndicator = () => { + const status = useSSEConnectionStatus(); + + const label = + status === "connected" ? "Live updates active" : status === "connecting" ? "Connecting to live updates..." : "Live updates unavailable"; + + return ( + + + + + + + {label} + + ); +}; + +export default SSEStatusIndicator; diff --git a/web/src/hooks/useLiveMemoRefresh.ts b/web/src/hooks/useLiveMemoRefresh.ts new file mode 100644 index 000000000..6b57bce96 --- /dev/null +++ b/web/src/hooks/useLiveMemoRefresh.ts @@ -0,0 +1,204 @@ +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback, useEffect, useRef, useSyncExternalStore } from "react"; +import { getAccessToken } from "@/auth-state"; +import { useAuth } from "@/contexts/AuthContext"; +import { memoKeys } from "@/hooks/useMemoQueries"; +import { userKeys } from "@/hooks/useUserQueries"; + +/** + * Reconnection parameters for SSE connection. + */ +const INITIAL_RETRY_DELAY_MS = 1000; +const MAX_RETRY_DELAY_MS = 30000; +const RETRY_BACKOFF_MULTIPLIER = 2; + +// --------------------------------------------------------------------------- +// Shared connection status store (singleton) +// --------------------------------------------------------------------------- + +export type SSEConnectionStatus = "connected" | "disconnected" | "connecting"; + +type Listener = () => void; + +let _status: SSEConnectionStatus = "disconnected"; +const _listeners = new Set(); + +function getSSEStatus(): SSEConnectionStatus { + return _status; +} + +function setSSEStatus(s: SSEConnectionStatus) { + if (_status !== s) { + _status = s; + _listeners.forEach((l) => l()); + } +} + +function subscribeSSEStatus(listener: Listener): () => void { + _listeners.add(listener); + return () => _listeners.delete(listener); +} + +/** + * React hook that returns the current SSE connection status. + * Re-renders the component whenever the status changes. + */ +export function useSSEConnectionStatus(): SSEConnectionStatus { + return useSyncExternalStore(subscribeSSEStatus, getSSEStatus, getSSEStatus); +} + +// --------------------------------------------------------------------------- +// Main hook +// --------------------------------------------------------------------------- + +/** + * useLiveMemoRefresh connects to the server's SSE endpoint and + * invalidates relevant React Query caches when change events + * (memos, reactions) are received. + * + * This enables real-time updates across all open instances of the app. + */ +export function useLiveMemoRefresh() { + const queryClient = useQueryClient(); + const { currentUser } = useAuth(); + const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS); + const abortControllerRef = useRef(null); + + const handleEvent = useCallback((event: SSEChangeEvent) => handleSSEEvent(event, queryClient), [queryClient]); + + useEffect(() => { + let mounted = true; + let retryTimeout: ReturnType | null = null; + + const connect = async () => { + if (!mounted) return; + + const token = getAccessToken(); + if (!token) { + setSSEStatus("disconnected"); + // Not logged in; do not retry. Effect will re-run when currentUser is set (login). + return; + } + + setSSEStatus("connecting"); + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + try { + const response = await fetch("/api/v1/sse", { + headers: { + Authorization: `Bearer ${token}`, + }, + signal: abortController.signal, + credentials: "include", + }); + + if (!response.ok || !response.body) { + throw new Error(`SSE connection failed: ${response.status}`); + } + + // Successfully connected - reset retry delay. + retryDelayRef.current = INITIAL_RETRY_DELAY_MS; + setSSEStatus("connected"); + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + while (mounted) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + // Process complete SSE messages (separated by double newlines). + const messages = buffer.split("\n\n"); + // Keep the last incomplete chunk in the buffer. + buffer = messages.pop() || ""; + + for (const message of messages) { + if (!message.trim()) continue; + + // Parse SSE format: lines starting with "data: " contain JSON payload. + // Lines starting with ":" are comments (heartbeats). + for (const line of message.split("\n")) { + if (line.startsWith("data: ")) { + const jsonStr = line.slice(6); + try { + const event = JSON.parse(jsonStr) as SSEChangeEvent; + handleEvent(event); + } catch { + // Ignore malformed JSON. + } + } + } + } + } + } catch (err: unknown) { + if (err instanceof DOMException && err.name === "AbortError") { + // Intentional abort, don't reconnect. + setSSEStatus("disconnected"); + return; + } + // Connection lost or failed - reconnect with backoff. + } + + setSSEStatus("disconnected"); + + // Reconnect with exponential backoff. + if (mounted) { + const delay = retryDelayRef.current; + retryDelayRef.current = Math.min(delay * RETRY_BACKOFF_MULTIPLIER, MAX_RETRY_DELAY_MS); + retryTimeout = setTimeout(connect, delay); + } + }; + + connect(); + + return () => { + mounted = false; + setSSEStatus("disconnected"); + if (retryTimeout) { + clearTimeout(retryTimeout); + } + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + }; + }, [handleEvent, currentUser]); +} + +// --------------------------------------------------------------------------- +// Event handling +// --------------------------------------------------------------------------- + +interface SSEChangeEvent { + type: string; + name: string; +} + +function handleSSEEvent(event: SSEChangeEvent, queryClient: ReturnType) { + switch (event.type) { + case "memo.created": + queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + queryClient.invalidateQueries({ queryKey: userKeys.stats() }); + break; + + case "memo.updated": + queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) }); + queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + break; + + case "memo.deleted": + queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) }); + queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + queryClient.invalidateQueries({ queryKey: userKeys.stats() }); + break; + + case "reaction.upserted": + case "reaction.deleted": + queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) }); + queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + break; + } +} diff --git a/web/src/main.tsx b/web/src/main.tsx index 5d617b210..a87b1a14b 100644 --- a/web/src/main.tsx +++ b/web/src/main.tsx @@ -12,6 +12,7 @@ import { refreshAccessToken } from "@/connect"; import { AuthProvider, useAuth } from "@/contexts/AuthContext"; import { InstanceProvider, useInstance } from "@/contexts/InstanceContext"; import { ViewProvider } from "@/contexts/ViewContext"; +import { useLiveMemoRefresh } from "@/hooks/useLiveMemoRefresh"; import { useTokenRefreshOnFocus } from "@/hooks/useTokenRefreshOnFocus"; import { queryClient } from "@/lib/query-client"; import router from "./router"; @@ -46,6 +47,9 @@ function AppInitializer({ children }: { children: React.ReactNode }) { // Related: https://github.com/usememos/memos/issues/5589 useTokenRefreshOnFocus(refreshAccessToken, !!currentUser); + // Live refresh: listen for memo changes via SSE and invalidate caches. + useLiveMemoRefresh(); + if (!authInitialized || !instanceInitialized) { return null; } diff --git a/web/vite.config.mts b/web/vite.config.mts index 5b63cb382..1433c39f2 100644 --- a/web/vite.config.mts +++ b/web/vite.config.mts @@ -16,6 +16,12 @@ export default defineConfig({ host: "0.0.0.0", port: 3001, proxy: { + "^/api/v1/sse": { + target: devProxyServer, + xfwd: true, + // SSE requires no response buffering and longer timeout. + timeout: 0, + }, "^/api": { target: devProxyServer, xfwd: true,