diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index f5d250a16..47f3a9c36 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -141,6 +141,12 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err)) } + // Broadcast live refresh event. + s.SSEHub.Broadcast(&MemoEvent{ + Type: MemoEventCreated, + Name: memoMessage.Name, + }) + return memoMessage, nil } @@ -471,6 +477,12 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err)) } + // Broadcast live refresh event. + s.SSEHub.Broadcast(&MemoEvent{ + Type: MemoEventUpdated, + Name: memoMessage.Name, + }) + return memoMessage, nil } @@ -539,6 +551,12 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR return nil, status.Errorf(codes.Internal, "failed to delete memo") } + // Broadcast live refresh event. + s.SSEHub.Broadcast(&MemoEvent{ + Type: MemoEventDeleted, + Name: request.Name, + }) + return &emptypb.Empty{}, nil } diff --git a/server/router/api/v1/sse_handler.go b/server/router/api/v1/sse_handler.go new file mode 100644 index 000000000..af9e2389f --- /dev/null +++ b/server/router/api/v1/sse_handler.go @@ -0,0 +1,101 @@ +package v1 + +import ( + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/labstack/echo/v4" + + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" +) + +const ( + // sseHeartbeatInterval is the interval between heartbeat pings to keep the connection alive. + sseHeartbeatInterval = 30 * time.Second +) + +// RegisterSSERoutes registers the SSE endpoint on the given Echo instance. +func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) { + authenticator := auth.NewAuthenticator(storeInstance, secret) + echoServer.GET("/api/v1/sse", func(c echo.Context) error { + return handleSSE(c, hub, authenticator) + }) +} + +// handleSSE handles the SSE connection for live memo refresh. +// Authentication is done via Bearer token in the Authorization header, +// or via the "token" query parameter (for EventSource which cannot set headers). +func handleSSE(c echo.Context, hub *SSEHub, authenticator *auth.Authenticator) error { + // Authenticate the request. + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + // Fall back to query parameter for native EventSource support. + if token := c.QueryParam("token"); token != "" { + authHeader = "Bearer " + token + } + } + + result := authenticator.Authenticate(c.Request().Context(), authHeader) + if result == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"}) + } + + // Set SSE headers. + w := c.Response() + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") // Disable nginx buffering + w.WriteHeader(http.StatusOK) + + // Flush headers immediately. + if f, ok := w.Writer.(http.Flusher); ok { + f.Flush() + } + + // Subscribe to the hub. + client := hub.Subscribe() + defer hub.Unsubscribe(client) + + // Create a ticker for heartbeat pings. + heartbeat := time.NewTicker(sseHeartbeatInterval) + defer heartbeat.Stop() + + ctx := c.Request().Context() + + slog.Debug("SSE client connected") + + for { + select { + case <-ctx.Done(): + // Client disconnected. + slog.Debug("SSE client disconnected") + return nil + + case data, ok := <-client.events: + if !ok { + // Channel closed, client was unsubscribed. + return nil + } + // Write SSE event. + if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { + return nil + } + if f, ok := w.Writer.(http.Flusher); ok { + f.Flush() + } + + case <-heartbeat.C: + // Send a heartbeat comment to keep the connection alive. + if _, err := fmt.Fprint(w, ": heartbeat\n\n"); err != nil { + return nil + } + if f, ok := w.Writer.(http.Flusher); ok { + f.Flush() + } + } + } +} diff --git a/server/router/api/v1/sse_hub.go b/server/router/api/v1/sse_hub.go new file mode 100644 index 000000000..accebe97a --- /dev/null +++ b/server/router/api/v1/sse_hub.go @@ -0,0 +1,86 @@ +package v1 + +import ( + "encoding/json" + "sync" +) + +// MemoEventType represents the type of memo change event. +type MemoEventType string + +const ( + MemoEventCreated MemoEventType = "memo.created" + MemoEventUpdated MemoEventType = "memo.updated" + MemoEventDeleted MemoEventType = "memo.deleted" +) + +// MemoEvent represents a memo change event sent to SSE clients. +type MemoEvent struct { + Type MemoEventType `json:"type"` + // Name is the memo resource name (e.g., "memos/xxxx"). + Name string `json:"name"` +} + +// JSON returns the JSON representation of the event. +func (e *MemoEvent) JSON() []byte { + data, _ := json.Marshal(e) + return data +} + +// sseClient represents a single SSE connection. +type sseClient struct { + events chan []byte +} + +// SSEHub manages SSE client connections and broadcasts events. +// It is safe for concurrent use. +type SSEHub struct { + mu sync.RWMutex + clients map[*sseClient]struct{} +} + +// NewSSEHub creates a new SSE hub. +func NewSSEHub() *SSEHub { + return &SSEHub{ + clients: make(map[*sseClient]struct{}), + } +} + +// Subscribe registers a new client and returns it. +// The caller must call Unsubscribe when done. +func (h *SSEHub) Subscribe() *sseClient { + c := &sseClient{ + // Buffer a few events so a slow client doesn't block broadcasting. + events: make(chan []byte, 32), + } + h.mu.Lock() + h.clients[c] = struct{}{} + h.mu.Unlock() + return c +} + +// Unsubscribe removes a client and closes its channel. +func (h *SSEHub) Unsubscribe(c *sseClient) { + h.mu.Lock() + if _, ok := h.clients[c]; ok { + delete(h.clients, c) + close(c.events) + } + h.mu.Unlock() +} + +// Broadcast sends an event to all connected clients. +// Slow clients that have a full buffer will have the event dropped +// to avoid blocking the broadcaster. +func (h *SSEHub) Broadcast(event *MemoEvent) { + data := event.JSON() + h.mu.RLock() + defer h.mu.RUnlock() + for c := range h.clients { + select { + case c.events <- data: + default: + // Drop event for slow client to avoid blocking. + } + } +} diff --git a/server/router/api/v1/test/test_helper.go b/server/router/api/v1/test/test_helper.go index 779ad2eea..c3afdb38b 100644 --- a/server/router/api/v1/test/test_helper.go +++ b/server/router/api/v1/test/test_helper.go @@ -46,6 +46,7 @@ func NewTestService(t *testing.T) *TestService { Profile: testProfile, Store: testStore, MarkdownService: markdownService, + SSEHub: apiv1.NewSSEHub(), } return &TestService{ diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index 694b5bbc3..b9cb6fb40 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -31,6 +31,7 @@ type APIV1Service struct { Profile *profile.Profile Store *store.Store MarkdownService markdown.Service + SSEHub *SSEHub // thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion thumbnailSemaphore *semaphore.Weighted @@ -45,6 +46,7 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store Profile: profile, Store: store, MarkdownService: markdownService, + SSEHub: NewSSEHub(), thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations } } diff --git a/server/server.go b/server/server.go index af09c4bcd..6ddbdd6fc 100644 --- a/server/server.go +++ b/server/server.go @@ -76,6 +76,10 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store // Create and register RSS routes (needs markdown service from apiV1Service). rss.NewRSSService(s.Profile, s.Store, apiV1Service.MarkdownService).RegisterRoutes(rootGroup) + + // Register SSE endpoint for live memo refresh. + apiv1.RegisterSSERoutes(echoServer, apiV1Service.SSEHub, s.Store, s.Secret) + // Register gRPC gateway as api v1. if err := apiV1Service.RegisterGateway(ctx, echoServer); err != nil { return nil, errors.Wrap(err, "failed to register gRPC gateway") diff --git a/web/src/hooks/useLiveMemoRefresh.ts b/web/src/hooks/useLiveMemoRefresh.ts new file mode 100644 index 000000000..e1a220d40 --- /dev/null +++ b/web/src/hooks/useLiveMemoRefresh.ts @@ -0,0 +1,152 @@ +import { useQueryClient } from "@tanstack/react-query"; +import { useEffect, useRef } from "react"; +import { getAccessToken } from "@/auth-state"; +import { memoKeys } from "@/hooks/useMemoQueries"; +import { userKeys } from "@/hooks/useUserQueries"; + +/** + * Reconnection parameters for SSE connection. + */ +const INITIAL_RETRY_DELAY_MS = 1000; +const MAX_RETRY_DELAY_MS = 30000; +const RETRY_BACKOFF_MULTIPLIER = 2; + +/** + * useLiveMemoRefresh connects to the server's SSE endpoint and + * invalidates relevant React Query caches when memo change events + * (created, updated, deleted) are received. + * + * This enables real-time updates across all open instances of the app. + */ +export function useLiveMemoRefresh() { + const queryClient = useQueryClient(); + const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS); + const abortControllerRef = useRef(null); + + useEffect(() => { + let mounted = true; + let retryTimeout: ReturnType | null = null; + + const connect = async () => { + if (!mounted) return; + + const token = getAccessToken(); + if (!token) { + // Not logged in; retry after a delay in case the user logs in. + retryTimeout = setTimeout(connect, 5000); + return; + } + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + try { + const response = await fetch("/api/v1/sse", { + headers: { + Authorization: `Bearer ${token}`, + }, + signal: abortController.signal, + credentials: "include", + }); + + if (!response.ok || !response.body) { + throw new Error(`SSE connection failed: ${response.status}`); + } + + // Successfully connected - reset retry delay. + retryDelayRef.current = INITIAL_RETRY_DELAY_MS; + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + while (mounted) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + // Process complete SSE messages (separated by double newlines). + const messages = buffer.split("\n\n"); + // Keep the last incomplete chunk in the buffer. + buffer = messages.pop() || ""; + + for (const message of messages) { + if (!message.trim()) continue; + + // Parse SSE format: lines starting with "data: " contain JSON payload. + // Lines starting with ":" are comments (heartbeats). + for (const line of message.split("\n")) { + if (line.startsWith("data: ")) { + const jsonStr = line.slice(6); + try { + const event = JSON.parse(jsonStr) as { type: string; name: string }; + handleMemoEvent(event, queryClient); + } catch { + // Ignore malformed JSON. + } + } + } + } + } + } catch (err: unknown) { + if (err instanceof DOMException && err.name === "AbortError") { + // Intentional abort, don't reconnect. + return; + } + // Connection lost or failed - reconnect with backoff. + } + + // Reconnect with exponential backoff. + if (mounted) { + const delay = retryDelayRef.current; + retryDelayRef.current = Math.min(delay * RETRY_BACKOFF_MULTIPLIER, MAX_RETRY_DELAY_MS); + retryTimeout = setTimeout(connect, delay); + } + }; + + connect(); + + return () => { + mounted = false; + if (retryTimeout) { + clearTimeout(retryTimeout); + } + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + }; + }, [queryClient]); +} + +interface MemoChangeEvent { + type: string; + name: string; +} + +function handleMemoEvent(event: MemoChangeEvent, queryClient: ReturnType) { + switch (event.type) { + case "memo.created": + // Invalidate memo lists so new memos appear. + queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + // Invalidate user stats (memo count changed). + queryClient.invalidateQueries({ queryKey: userKeys.stats() }); + break; + + case "memo.updated": + // Invalidate the specific memo detail cache. + queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) }); + // Invalidate memo lists to reflect updated content/ordering. + queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + break; + + case "memo.deleted": + // Remove the specific memo from cache. + queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) }); + // Invalidate memo lists. + queryClient.invalidateQueries({ queryKey: memoKeys.lists() }); + // Invalidate user stats (memo count changed). + queryClient.invalidateQueries({ queryKey: userKeys.stats() }); + break; + } +} diff --git a/web/src/main.tsx b/web/src/main.tsx index 5d617b210..a87b1a14b 100644 --- a/web/src/main.tsx +++ b/web/src/main.tsx @@ -12,6 +12,7 @@ import { refreshAccessToken } from "@/connect"; import { AuthProvider, useAuth } from "@/contexts/AuthContext"; import { InstanceProvider, useInstance } from "@/contexts/InstanceContext"; import { ViewProvider } from "@/contexts/ViewContext"; +import { useLiveMemoRefresh } from "@/hooks/useLiveMemoRefresh"; import { useTokenRefreshOnFocus } from "@/hooks/useTokenRefreshOnFocus"; import { queryClient } from "@/lib/query-client"; import router from "./router"; @@ -46,6 +47,9 @@ function AppInitializer({ children }: { children: React.ReactNode }) { // Related: https://github.com/usememos/memos/issues/5589 useTokenRefreshOnFocus(refreshAccessToken, !!currentUser); + // Live refresh: listen for memo changes via SSE and invalidate caches. + useLiveMemoRefresh(); + if (!authInitialized || !instanceInitialized) { return null; } diff --git a/web/vite.config.mts b/web/vite.config.mts index 5b63cb382..1433c39f2 100644 --- a/web/vite.config.mts +++ b/web/vite.config.mts @@ -16,6 +16,12 @@ export default defineConfig({ host: "0.0.0.0", port: 3001, proxy: { + "^/api/v1/sse": { + target: devProxyServer, + xfwd: true, + // SSE requires no response buffering and longer timeout. + timeout: 0, + }, "^/api": { target: devProxyServer, xfwd: true,