memos/server/router/api/v1/v1.go

150 lines
4.9 KiB
Go

package v1
import (
"context"
"net/http"
"connectrpc.com/connect"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"golang.org/x/sync/semaphore"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/plugin/markdown"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// APIV1Service bundles the dependencies of the v1 API. The same value is
// registered as the server implementation for every v1 service in
// RegisterGateway, which is why it embeds the generated Unimplemented*Server
// type of each of those services.
type APIV1Service struct {
v1pb.UnimplementedInstanceServiceServer
v1pb.UnimplementedAuthServiceServer
v1pb.UnimplementedUserServiceServer
v1pb.UnimplementedMemoServiceServer
v1pb.UnimplementedAttachmentServiceServer
v1pb.UnimplementedShortcutServiceServer
v1pb.UnimplementedActivityServiceServer
v1pb.UnimplementedIdentityProviderServiceServer
// Secret is handed to the authenticator and the Connect auth interceptor.
Secret string
// Profile is the runtime profile; its IsDev flag toggles stack-trace logging
// for the Connect interceptors.
Profile *profile.Profile
// Store is the data store shared with the authentication layer.
Store *store.Store
// MarkdownService is the markdown service built by NewAPIV1Service.
MarkdownService markdown.Service
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
thumbnailSemaphore *semaphore.Weighted
}
// NewAPIV1Service builds an APIV1Service from the given secret, profile and
// store. It also creates the markdown service (with the tag extension enabled)
// and the semaphore that bounds concurrent thumbnail generation.
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service {
	tagAwareMarkdown := markdown.NewService(markdown.WithTagExtension())

	svc := new(APIV1Service)
	svc.Secret = secret
	svc.Profile = profile
	svc.Store = store
	svc.MarkdownService = tagAwareMarkdown
	// Limit to 3 concurrent thumbnail generations.
	svc.thumbnailSemaphore = semaphore.NewWeighted(3)
	return svc
}
// RegisterGateway registers the gRPC-Gateway and Connect handlers with the given Echo instance.
//
// Two transports are mounted on echoServer:
//   - the gRPC-Gateway mux, served under /api/v1/* and /file/*;
//   - the Connect handlers, served under /memos.api.v1.*.
//
// ctx is forwarded to the generated Register*HandlerServer functions. The first
// registration error is returned unchanged; nil is returned once everything is
// mounted.
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
	// Auth middleware for gRPC-Gateway - runs after routing, has access to method name.
	// IsPublicMethod decides which RPC methods may be called without credentials.
	authenticator := auth.NewAuthenticator(s.Store, s.Secret)
	gatewayAuthMiddleware := func(next runtime.HandlerFunc) runtime.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
			// Named reqCtx so it does not shadow the registration ctx above.
			reqCtx := r.Context()

			// Get the RPC method name from context (set by grpc-gateway after routing).
			// When it is absent rpcMethod stays empty and the decision is left to
			// IsPublicMethod.
			rpcMethod, _ := runtime.RPCMethod(reqCtx)

			// Extract credentials from HTTP headers. A missing cookie is not an
			// error here: the Authorization header may carry the credentials.
			var sessionCookie string
			if cookie, err := r.Cookie("user_session"); err == nil {
				sessionCookie = cookie.Value
			}
			authHeader := r.Header.Get("Authorization")
			result := authenticator.Authenticate(reqCtx, sessionCookie, authHeader)

			// Enforce authentication for non-public methods.
			if result == nil && !IsPublicMethod(rpcMethod) {
				// The body is JSON, so it must not go through http.Error, which
				// forces Content-Type to text/plain. Status, body bytes and the
				// nosniff header are the same as http.Error would have sent.
				header := w.Header()
				header.Set("Content-Type", "application/json")
				header.Set("X-Content-Type-Options", "nosniff")
				w.WriteHeader(http.StatusUnauthorized)
				// Nothing useful can be done if the client is already gone.
				_, _ = w.Write([]byte(`{"code": 16, "message": "authentication required"}` + "\n"))
				return
			}

			// Set user in context (result may be nil for public endpoints).
			if result != nil {
				reqCtx = auth.SetUserInContext(reqCtx, result.User, result.SessionID, result.AccessToken)
				r = r.WithContext(reqCtx)
			}
			next(w, r, pathParams)
		}
	}

	// Create gRPC-Gateway mux with auth middleware.
	gwMux := runtime.NewServeMux(
		runtime.WithMiddlewares(gatewayAuthMiddleware),
	)
	if err := v1pb.RegisterInstanceServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}
	if err := v1pb.RegisterAuthServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}
	if err := v1pb.RegisterUserServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}
	if err := v1pb.RegisterMemoServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}
	if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}
	if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}
	if err := v1pb.RegisterActivityServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}
	if err := v1pb.RegisterIdentityProviderServiceHandlerServer(ctx, gwMux, s); err != nil {
		return err
	}

	// Both gateway routes share one handler and the default CORS middleware.
	gwGroup := echoServer.Group("")
	gwGroup.Use(middleware.CORS())
	handler := echo.WrapHandler(gwMux)
	gwGroup.Any("/api/v1/*", handler)
	gwGroup.Any("/file/*", handler)

	// Connect handlers for browser clients (replaces grpc-web).
	logStacktraces := s.Profile.IsDev()
	connectInterceptors := connect.WithInterceptors(
		NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first
		NewLoggingInterceptor(logStacktraces),
		NewRecoveryInterceptor(logStacktraces),
		NewAuthInterceptor(s.Store, s.Secret),
	)
	connectMux := http.NewServeMux()
	connectHandler := NewConnectServiceHandler(s)
	connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors)

	// Wrap with CORS for browser access.
	// NOTE(review): every origin is accepted while AllowCredentials is true, so
	// any site may issue credentialed requests to these endpoints - confirm
	// this is intended.
	corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{
		AllowOriginFunc: func(_ string) (bool, error) {
			return true, nil
		},
		AllowMethods:     []string{http.MethodGet, http.MethodPost, http.MethodOptions},
		AllowHeaders:     []string{"*"},
		AllowCredentials: true,
	})
	connectGroup := echoServer.Group("", corsHandler)
	connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(connectMux))
	return nil
}