feat: add live refresh via Server-Sent Events (SSE) with visual indicator (#5638)

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: milvasic <milvasic@users.noreply.github.com>
This commit is contained in:
milvasic 2026-03-03 15:56:12 +01:00 committed by GitHub
parent a69e405c95
commit ea0892a8b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 650 additions and 3 deletions

View File

@ -141,6 +141,12 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
}
// Broadcast live refresh event.
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoCreated,
Name: memoMessage.Name,
})
return memoMessage, nil
}
@ -471,6 +477,12 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
}
// Broadcast live refresh event.
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoUpdated,
Name: memoMessage.Name,
})
return memoMessage, nil
}
@ -539,6 +551,12 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
// Broadcast live refresh event.
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventMemoDeleted,
Name: request.Name,
})
return &emptypb.Empty{}, nil
}

View File

@ -97,6 +97,12 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
reactionMessage := convertReactionFromStore(reaction)
// Broadcast live refresh event (reaction belongs to a memo).
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventReactionUpserted,
Name: request.Reaction.ContentId,
})
return reactionMessage, nil
}
@ -136,6 +142,12 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
}
// Broadcast live refresh event (reaction belongs to a memo).
s.SSEHub.Broadcast(&SSEEvent{
Type: SSEEventReactionDeleted,
Name: reaction.ContentID,
})
return &emptypb.Empty{}, nil
}

View File

@ -0,0 +1,101 @@
package v1
import (
"fmt"
"log/slog"
"net/http"
"time"
"github.com/labstack/echo/v5"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
const (
// sseHeartbeatInterval is the interval between heartbeat pings to keep the connection alive.
sseHeartbeatInterval = 30 * time.Second
)
// RegisterSSERoutes registers the SSE endpoint on the given Echo instance.
func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) {
authenticator := auth.NewAuthenticator(storeInstance, secret)
echoServer.GET("/api/v1/sse", func(c *echo.Context) error {
return handleSSE(c, hub, authenticator)
})
}
// handleSSE handles the SSE connection for live memo refresh.
// Authentication is done via Bearer token in the Authorization header,
// or via the "token" query parameter (for EventSource which cannot set headers).
func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator) error {
// Authenticate the request.
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
// Fall back to query parameter for native EventSource support.
if token := c.QueryParam("token"); token != "" {
authHeader = "Bearer " + token
}
}
result := authenticator.Authenticate(c.Request().Context(), authHeader)
if result == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
}
// Set SSE headers.
w := c.Response()
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no") // Disable nginx buffering
w.WriteHeader(http.StatusOK)
// Flush headers immediately.
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Subscribe to the hub.
client := hub.Subscribe()
defer hub.Unsubscribe(client)
// Create a ticker for heartbeat pings.
heartbeat := time.NewTicker(sseHeartbeatInterval)
defer heartbeat.Stop()
ctx := c.Request().Context()
slog.Debug("SSE client connected")
for {
select {
case <-ctx.Done():
// Client disconnected.
slog.Debug("SSE client disconnected")
return nil
case data, ok := <-client.events:
if !ok {
// Channel closed, client was unsubscribed.
return nil
}
// Write SSE event.
if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
return nil
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
case <-heartbeat.C:
// Send a heartbeat comment to keep the connection alive.
if _, err := fmt.Fprint(w, ": heartbeat\n\n"); err != nil {
return nil
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
}
}

View File

@ -0,0 +1,98 @@
package v1
import (
"encoding/json"
"log/slog"
"sync"
)
// SSEEventType represents the type of change event.
type SSEEventType string
const (
SSEEventMemoCreated SSEEventType = "memo.created"
SSEEventMemoUpdated SSEEventType = "memo.updated"
SSEEventMemoDeleted SSEEventType = "memo.deleted"
SSEEventReactionUpserted SSEEventType = "reaction.upserted"
SSEEventReactionDeleted SSEEventType = "reaction.deleted"
)
// SSEEvent represents a change event sent to SSE clients.
type SSEEvent struct {
Type SSEEventType `json:"type"`
// Name is the affected resource name (e.g., "memos/xxxx").
// For reaction events, this is the memo resource name that the reaction belongs to.
Name string `json:"name"`
}
// JSON returns the JSON representation of the event.
// Returns nil if marshaling fails (error is logged).
func (e *SSEEvent) JSON() []byte {
data, err := json.Marshal(e)
if err != nil {
slog.Error("failed to marshal SSE event", "err", err, "event", e)
return nil
}
return data
}
// SSEClient represents a single SSE connection.
type SSEClient struct {
events chan []byte
}
// SSEHub manages SSE client connections and broadcasts events.
// It is safe for concurrent use.
type SSEHub struct {
mu sync.RWMutex
clients map[*SSEClient]struct{}
}
// NewSSEHub creates a new SSE hub.
func NewSSEHub() *SSEHub {
return &SSEHub{
clients: make(map[*SSEClient]struct{}),
}
}
// Subscribe registers a new client and returns it.
// The caller must call Unsubscribe when done.
func (h *SSEHub) Subscribe() *SSEClient {
c := &SSEClient{
// Buffer a few events so a slow client doesn't block broadcasting.
events: make(chan []byte, 32),
}
h.mu.Lock()
h.clients[c] = struct{}{}
h.mu.Unlock()
return c
}
// Unsubscribe removes a client and closes its channel.
func (h *SSEHub) Unsubscribe(c *SSEClient) {
h.mu.Lock()
if _, ok := h.clients[c]; ok {
delete(h.clients, c)
close(c.events)
}
h.mu.Unlock()
}
// Broadcast sends an event to all connected clients.
// Slow clients that have a full buffer will have the event dropped
// to avoid blocking the broadcaster.
func (h *SSEHub) Broadcast(event *SSEEvent) {
data := event.JSON()
if len(data) == 0 {
return
}
h.mu.RLock()
defer h.mu.RUnlock()
for c := range h.clients {
select {
case c.events <- data:
default:
// Drop event for slow client to avoid blocking.
}
}
}

View File

@ -0,0 +1,70 @@
package v1
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
hub := NewSSEHub()
client := hub.Subscribe()
require.NotNil(t, client)
require.NotNil(t, client.events)
// Unsubscribe removes the client and closes the channel.
hub.Unsubscribe(client)
// Channel should be closed.
_, ok := <-client.events
assert.False(t, ok, "channel should be closed after Unsubscribe")
}
func TestSSEHub_Broadcast(t *testing.T) {
hub := NewSSEHub()
client := hub.Subscribe()
defer hub.Unsubscribe(client)
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"}
hub.Broadcast(event)
select {
case data := <-client.events:
assert.Contains(t, string(data), `"type":"memo.created"`)
assert.Contains(t, string(data), `"name":"memos/123"`)
case <-time.After(time.Second):
t.Fatal("expected to receive event within 1s")
}
}
func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
hub := NewSSEHub()
c1 := hub.Subscribe()
defer hub.Unsubscribe(c1)
c2 := hub.Subscribe()
defer hub.Unsubscribe(c2)
event := &SSEEvent{Type: SSEEventMemoDeleted, Name: "memos/456"}
hub.Broadcast(event)
for _, ch := range []chan []byte{c1.events, c2.events} {
select {
case data := <-ch:
assert.Contains(t, string(data), "memo.deleted")
assert.Contains(t, string(data), "memos/456")
case <-time.After(time.Second):
t.Fatal("expected to receive event within 1s")
}
}
}
func TestSSEEvent_JSON(t *testing.T) {
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789"}
data := e.JSON()
require.NotEmpty(t, data)
assert.Contains(t, string(data), `"type":"memo.updated"`)
assert.Contains(t, string(data), `"name":"memos/789"`)
}

View File

@ -0,0 +1,86 @@
package test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
)
func TestSSEHandler_Authentication(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "sse-user")
require.NoError(t, err)
token, _, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(ts.Secret),
)
require.NoError(t, err)
e := echo.New()
apiv1.RegisterSSERoutes(e, ts.Service.SSEHub, ts.Store, ts.Secret)
t.Run("no token returns 401", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
})
t.Run("invalid token returns 401", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil)
req.Header.Set("Authorization", "Bearer invalid-token")
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
})
t.Run("valid token returns 200 and stream", func(t *testing.T) {
// Use a cancellable context so we can close the SSE connection after
// confirming the headers, preventing the handler's event loop from
// blocking the test indefinitely.
reqCtx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse", nil).WithContext(reqCtx)
req.Header.Set("Authorization", "Bearer "+token)
rec := httptest.NewRecorder()
done := make(chan struct{})
go func() {
defer close(done)
e.ServeHTTP(rec, req)
}()
// Cancel the context to signal client disconnect, which exits the SSE loop.
cancel()
<-done
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "text/event-stream", rec.Header().Get("Content-Type"))
})
t.Run("token in query param returns 200", func(t *testing.T) {
reqCtx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest(http.MethodGet, "/api/v1/sse?token="+token, nil).WithContext(reqCtx)
rec := httptest.NewRecorder()
done := make(chan struct{})
go func() {
defer close(done)
e.ServeHTTP(rec, req)
}()
cancel()
<-done
require.Equal(t, http.StatusOK, rec.Code)
})
}

View File

@ -46,6 +46,7 @@ func NewTestService(t *testing.T) *TestService {
Profile: testProfile,
Store: testStore,
MarkdownService: markdownService,
SSEHub: apiv1.NewSSEHub(),
}
return &TestService{

View File

@ -31,6 +31,7 @@ type APIV1Service struct {
Profile *profile.Profile
Store *store.Store
MarkdownService markdown.Service
SSEHub *SSEHub
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
thumbnailSemaphore *semaphore.Weighted
@ -45,6 +46,7 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
Profile: profile,
Store: store,
MarkdownService: markdownService,
SSEHub: NewSSEHub(),
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
}
}
@ -115,6 +117,10 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
gwGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
// Register SSE endpoint with same CORS as rest of /api/v1.
gwGroup.GET("/api/v1/sse", func(c *echo.Context) error {
return handleSSE(c, s.SSEHub, auth.NewAuthenticator(s.Store, s.Secret))
})
handler := echo.WrapHandler(gwMux)
gwGroup.Any("/api/v1/*", handler)

View File

@ -184,7 +184,7 @@ Parses data URI to extract MIME type and base64 data.
## Dependencies
### External Packages
- `github.com/labstack/echo/v4` - HTTP router and middleware
- `github.com/labstack/echo/v5` - HTTP router and middleware
- `github.com/golang-jwt/jwt/v5` - JWT parsing and validation
- `github.com/disintegration/imaging` - Image thumbnail generation
- `golang.org/x/sync/semaphore` - Concurrency control for thumbnails

View File

@ -75,7 +75,8 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
// Create and register RSS routes (needs markdown service from apiV1Service).
rss.NewRSSService(s.Profile, s.Store, apiV1Service.MarkdownService).RegisterRoutes(rootGroup)
// Register gRPC gateway as api v1.
// Register gRPC gateway as api v1 (includes SSE endpoint on CORS-enabled group).
if err := apiV1Service.RegisterGateway(ctx, echoServer); err != nil {
return nil, errors.Wrap(err, "failed to register gRPC gateway")
}

View File

@ -8,6 +8,7 @@ import { Routes } from "@/router";
import { UserNotification_Status } from "@/types/proto/api/v1/user_service_pb";
import { useTranslate } from "@/utils/i18n";
import MemosLogo from "./MemosLogo";
import SSEStatusIndicator from "./SSEStatusIndicator";
import UserMenu from "./UserMenu";
interface NavLinkItem {
@ -114,7 +115,10 @@ const Navigation = (props: Props) => {
))}
</div>
{currentUser && (
<div className={cn("w-full flex flex-col justify-end", props.collapsed ? "items-center" : "items-start pl-3")}>
<div className={cn("w-full flex flex-col justify-end gap-1", props.collapsed ? "items-center" : "items-start pl-3")}>
<div className={cn("flex items-center", props.collapsed ? "justify-center" : "pl-1")}>
<SSEStatusIndicator />
</div>
<UserMenu collapsed={collapsed} />
</div>
)}

View File

@ -0,0 +1,36 @@
import { useSSEConnectionStatus } from "@/hooks/useLiveMemoRefresh";
import { cn } from "@/lib/utils";
import { Tooltip, TooltipContent, TooltipTrigger } from "./ui/tooltip";
/**
* A small colored dot that indicates the SSE live-refresh connection status.
* - Green = connected (live updates active)
* - Yellow/pulsing = connecting
* - Red = disconnected (updates not live)
*/
const SSEStatusIndicator = () => {
const status = useSSEConnectionStatus();
const label =
status === "connected" ? "Live updates active" : status === "connecting" ? "Connecting to live updates..." : "Live updates unavailable";
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="inline-flex items-center justify-center size-5 cursor-default" aria-label={label}>
<span
className={cn(
"block size-2 rounded-full transition-colors",
status === "connected" && "bg-green-500",
status === "connecting" && "bg-yellow-500 animate-pulse",
status === "disconnected" && "bg-red-500",
)}
/>
</span>
</TooltipTrigger>
<TooltipContent>{label}</TooltipContent>
</Tooltip>
);
};
export default SSEStatusIndicator;

View File

@ -0,0 +1,204 @@
import { useQueryClient } from "@tanstack/react-query";
import { useCallback, useEffect, useRef, useSyncExternalStore } from "react";
import { getAccessToken } from "@/auth-state";
import { useAuth } from "@/contexts/AuthContext";
import { memoKeys } from "@/hooks/useMemoQueries";
import { userKeys } from "@/hooks/useUserQueries";
/**
* Reconnection parameters for SSE connection.
*/
const INITIAL_RETRY_DELAY_MS = 1000;
const MAX_RETRY_DELAY_MS = 30000;
const RETRY_BACKOFF_MULTIPLIER = 2;
// ---------------------------------------------------------------------------
// Shared connection status store (singleton)
// ---------------------------------------------------------------------------
export type SSEConnectionStatus = "connected" | "disconnected" | "connecting";
type Listener = () => void;
let _status: SSEConnectionStatus = "disconnected";
const _listeners = new Set<Listener>();
function getSSEStatus(): SSEConnectionStatus {
return _status;
}
function setSSEStatus(s: SSEConnectionStatus) {
if (_status !== s) {
_status = s;
_listeners.forEach((l) => l());
}
}
function subscribeSSEStatus(listener: Listener): () => void {
_listeners.add(listener);
return () => _listeners.delete(listener);
}
/**
* React hook that returns the current SSE connection status.
* Re-renders the component whenever the status changes.
*/
export function useSSEConnectionStatus(): SSEConnectionStatus {
return useSyncExternalStore(subscribeSSEStatus, getSSEStatus, getSSEStatus);
}
// ---------------------------------------------------------------------------
// Main hook
// ---------------------------------------------------------------------------
/**
* useLiveMemoRefresh connects to the server's SSE endpoint and
* invalidates relevant React Query caches when change events
* (memos, reactions) are received.
*
* This enables real-time updates across all open instances of the app.
*/
export function useLiveMemoRefresh() {
const queryClient = useQueryClient();
const { currentUser } = useAuth();
const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS);
const abortControllerRef = useRef<AbortController | null>(null);
const handleEvent = useCallback((event: SSEChangeEvent) => handleSSEEvent(event, queryClient), [queryClient]);
useEffect(() => {
let mounted = true;
let retryTimeout: ReturnType<typeof setTimeout> | null = null;
const connect = async () => {
if (!mounted) return;
const token = getAccessToken();
if (!token) {
setSSEStatus("disconnected");
// Not logged in; do not retry. Effect will re-run when currentUser is set (login).
return;
}
setSSEStatus("connecting");
const abortController = new AbortController();
abortControllerRef.current = abortController;
try {
const response = await fetch("/api/v1/sse", {
headers: {
Authorization: `Bearer ${token}`,
},
signal: abortController.signal,
credentials: "include",
});
if (!response.ok || !response.body) {
throw new Error(`SSE connection failed: ${response.status}`);
}
// Successfully connected - reset retry delay.
retryDelayRef.current = INITIAL_RETRY_DELAY_MS;
setSSEStatus("connected");
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (mounted) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
// Process complete SSE messages (separated by double newlines).
const messages = buffer.split("\n\n");
// Keep the last incomplete chunk in the buffer.
buffer = messages.pop() || "";
for (const message of messages) {
if (!message.trim()) continue;
// Parse SSE format: lines starting with "data: " contain JSON payload.
// Lines starting with ":" are comments (heartbeats).
for (const line of message.split("\n")) {
if (line.startsWith("data: ")) {
const jsonStr = line.slice(6);
try {
const event = JSON.parse(jsonStr) as SSEChangeEvent;
handleEvent(event);
} catch {
// Ignore malformed JSON.
}
}
}
}
}
} catch (err: unknown) {
if (err instanceof DOMException && err.name === "AbortError") {
// Intentional abort, don't reconnect.
setSSEStatus("disconnected");
return;
}
// Connection lost or failed - reconnect with backoff.
}
setSSEStatus("disconnected");
// Reconnect with exponential backoff.
if (mounted) {
const delay = retryDelayRef.current;
retryDelayRef.current = Math.min(delay * RETRY_BACKOFF_MULTIPLIER, MAX_RETRY_DELAY_MS);
retryTimeout = setTimeout(connect, delay);
}
};
connect();
return () => {
mounted = false;
setSSEStatus("disconnected");
if (retryTimeout) {
clearTimeout(retryTimeout);
}
if (abortControllerRef.current) {
abortControllerRef.current.abort();
}
};
}, [handleEvent, currentUser]);
}
// ---------------------------------------------------------------------------
// Event handling
// ---------------------------------------------------------------------------
interface SSEChangeEvent {
type: string;
name: string;
}
function handleSSEEvent(event: SSEChangeEvent, queryClient: ReturnType<typeof useQueryClient>) {
switch (event.type) {
case "memo.created":
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
break;
case "memo.updated":
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
break;
case "memo.deleted":
queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) });
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
break;
case "reaction.upserted":
case "reaction.deleted":
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
break;
}
}

View File

@ -12,6 +12,7 @@ import { refreshAccessToken } from "@/connect";
import { AuthProvider, useAuth } from "@/contexts/AuthContext";
import { InstanceProvider, useInstance } from "@/contexts/InstanceContext";
import { ViewProvider } from "@/contexts/ViewContext";
import { useLiveMemoRefresh } from "@/hooks/useLiveMemoRefresh";
import { useTokenRefreshOnFocus } from "@/hooks/useTokenRefreshOnFocus";
import { queryClient } from "@/lib/query-client";
import router from "./router";
@ -46,6 +47,9 @@ function AppInitializer({ children }: { children: React.ReactNode }) {
// Related: https://github.com/usememos/memos/issues/5589
useTokenRefreshOnFocus(refreshAccessToken, !!currentUser);
// Live refresh: listen for memo changes via SSE and invalidate caches.
useLiveMemoRefresh();
if (!authInitialized || !instanceInitialized) {
return null;
}

View File

@ -16,6 +16,12 @@ export default defineConfig({
host: "0.0.0.0",
port: 3001,
proxy: {
"^/api/v1/sse": {
target: devProxyServer,
xfwd: true,
// SSE requires no response buffering and longer timeout.
timeout: 0,
},
"^/api": {
target: devProxyServer,
xfwd: true,