package v1 import ( "context" "net/http" "connectrpc.com/connect" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "golang.org/x/sync/semaphore" "github.com/usememos/memos/internal/profile" "github.com/usememos/memos/plugin/markdown" v1pb "github.com/usememos/memos/proto/gen/api/v1" "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) type APIV1Service struct { v1pb.UnimplementedInstanceServiceServer v1pb.UnimplementedAuthServiceServer v1pb.UnimplementedUserServiceServer v1pb.UnimplementedMemoServiceServer v1pb.UnimplementedAttachmentServiceServer v1pb.UnimplementedShortcutServiceServer v1pb.UnimplementedActivityServiceServer v1pb.UnimplementedIdentityProviderServiceServer Secret string Profile *profile.Profile Store *store.Store MarkdownService markdown.Service // thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion thumbnailSemaphore *semaphore.Weighted } func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service { markdownService := markdown.NewService( markdown.WithTagExtension(), ) return &APIV1Service{ Secret: secret, Profile: profile, Store: store, MarkdownService: markdownService, thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations } } // RegisterGateway registers the gRPC-Gateway and Connect handlers with the given Echo instance. func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error { // Auth middleware for gRPC-Gateway - runs after routing, has access to method name. // Uses the same PublicMethods config as the Connect AuthInterceptor. authenticator := auth.NewAuthenticator(s.Store, s.Secret) gatewayAuthMiddleware := func(next runtime.HandlerFunc) runtime.HandlerFunc { return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { ctx := r.Context() // Get the RPC method name from context (set by grpc-gateway after routing) rpcMethod, _ := runtime.RPCMethod(ctx) // Extract credentials from HTTP headers var sessionCookie string if cookie, err := r.Cookie("user_session"); err == nil { sessionCookie = cookie.Value } authHeader := r.Header.Get("Authorization") result := authenticator.Authenticate(ctx, sessionCookie, authHeader) // Enforce authentication for non-public methods if result == nil && !IsPublicMethod(rpcMethod) { http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized) return } // Set user in context (may be nil for public endpoints) if result != nil { ctx = auth.SetUserInContext(ctx, result.User, result.SessionID, result.AccessToken) r = r.WithContext(ctx) } next(w, r, pathParams) } } // Create gRPC-Gateway mux with auth middleware. gwMux := runtime.NewServeMux( runtime.WithMiddlewares(gatewayAuthMiddleware), ) if err := v1pb.RegisterInstanceServiceHandlerServer(ctx, gwMux, s); err != nil { return err } if err := v1pb.RegisterAuthServiceHandlerServer(ctx, gwMux, s); err != nil { return err } if err := v1pb.RegisterUserServiceHandlerServer(ctx, gwMux, s); err != nil { return err } if err := v1pb.RegisterMemoServiceHandlerServer(ctx, gwMux, s); err != nil { return err } if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil { return err } if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil { return err } if err := v1pb.RegisterActivityServiceHandlerServer(ctx, gwMux, s); err != nil { return err } if err := v1pb.RegisterIdentityProviderServiceHandlerServer(ctx, gwMux, s); err != nil { return err } gwGroup := echoServer.Group("") gwGroup.Use(middleware.CORS()) handler := echo.WrapHandler(gwMux) gwGroup.Any("/api/v1/*", handler) gwGroup.Any("/file/*", handler) // Connect handlers for browser clients (replaces grpc-web). logStacktraces := s.Profile.IsDev() connectInterceptors := connect.WithInterceptors( NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first NewLoggingInterceptor(logStacktraces), NewRecoveryInterceptor(logStacktraces), NewAuthInterceptor(s.Store, s.Secret), ) connectMux := http.NewServeMux() connectHandler := NewConnectServiceHandler(s) connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors) // Wrap with CORS for browser access corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{ AllowOriginFunc: func(_ string) (bool, error) { return true, nil }, AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, AllowHeaders: []string{"*"}, AllowCredentials: true, }) connectGroup := echoServer.Group("", corsHandler) connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(connectMux)) return nil }