mirror of https://github.com/usememos/memos.git
fix(backend): implement protocol-agnostic header setting for dual gRPC/Connect-RPC support
Problem: The codebase supports both native gRPC and Connect-RPC protocols, but auth service was using grpc.SetHeader() which only works for native gRPC. This caused "failed to set grpc header" errors when using Connect-RPC clients (browsers using nice-grpc-web). Solution: - Created HeaderCarrier pattern for protocol-agnostic header setting - HeaderCarrier stores headers in context for Connect-RPC requests - Falls back to grpc.SetHeader for native gRPC requests - Updated auth service to use SetResponseHeader() instead of grpc.SetHeader() - Refactored Connect wrappers to use withHeaderCarrier() helper to eliminate code duplication Additional fixes: - Allow public methods when gRPC metadata is missing in ACL interceptor - Properly handle ParseSessionCookieValue errors instead of ignoring them - Fix buildSessionCookie to gracefully handle missing metadata Files changed: - server/router/api/v1/header_carrier.go: New protocol-agnostic header carrier - server/router/api/v1/auth_service.go: Use SetResponseHeader, handle missing metadata - server/router/api/v1/connect_services.go: Use withHeaderCarrier helper - server/router/api/v1/acl.go: Allow public methods without metadata - server/router/api/v1/connect_interceptors.go: Handle ParseSessionCookieValue errors 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
8a7e00886d
commit
3d893a7394
|
|
@ -53,6 +53,10 @@ func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthIntercep
|
||||||
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
// If metadata is missing, only allow public methods
|
||||||
|
if IsPublicMethod(serverInfo.FullMethod) {
|
||||||
|
return handler(ctx, request)
|
||||||
|
}
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
|
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -60,7 +64,12 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
|
||||||
if sessionCookie := extractSessionCookieFromMetadata(md); sessionCookie != "" {
|
if sessionCookie := extractSessionCookieFromMetadata(md); sessionCookie != "" {
|
||||||
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
|
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
|
||||||
if err == nil && user != nil {
|
if err == nil && user != nil {
|
||||||
_, sessionID, _ := auth.ParseSessionCookieValue(sessionCookie)
|
_, sessionID, err := auth.ParseSessionCookieValue(sessionCookie)
|
||||||
|
if err != nil {
|
||||||
|
// This should not happen since AuthenticateBySession already validated the cookie
|
||||||
|
// but handle it gracefully anyway
|
||||||
|
sessionID = ""
|
||||||
|
}
|
||||||
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, serverInfo.FullMethod, user, sessionID, "", IsAdminOnlyMethod)
|
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, serverInfo.FullMethod, user, sessionID, "", IsAdminOnlyMethod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, toGRPCError(err, codes.PermissionDenied)
|
return nil, toGRPCError(err, codes.PermissionDenied)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
@ -237,10 +236,8 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
|
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
|
||||||
}
|
}
|
||||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
if err := SetResponseHeader(ctx, "Set-Cookie", sessionCookie); err != nil {
|
||||||
"Set-Cookie": sessionCookie,
|
return status.Errorf(codes.Internal, "failed to set response header, error: %v", err)
|
||||||
})); err != nil {
|
|
||||||
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -284,11 +281,9 @@ func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
|
||||||
return errors.Wrap(err, "failed to build session cookie")
|
return errors.Wrap(err, "failed to build session cookie")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set both cookies in the response
|
// Set cookie in the response
|
||||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
if err := SetResponseHeader(ctx, "Set-Cookie", sessionCookie); err != nil {
|
||||||
"Set-Cookie": sessionCookie,
|
return errors.Wrap(err, "failed to set response header")
|
||||||
})); err != nil {
|
|
||||||
return errors.Wrap(err, "failed to set grpc header")
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -305,15 +300,18 @@ func (*APIV1Service) buildSessionCookie(ctx context.Context, sessionCookieValue
|
||||||
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
|
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
|
||||||
}
|
}
|
||||||
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
// Try to determine if the request is HTTPS by checking the origin header
|
||||||
if !ok {
|
// Default to non-HTTPS (Strict SameSite) if metadata is not available
|
||||||
return "", errors.New("failed to get metadata from context")
|
isHTTPS := false
|
||||||
|
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||||
|
for _, v := range md.Get("origin") {
|
||||||
|
if strings.HasPrefix(v, "https://") {
|
||||||
|
isHTTPS = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var origin string
|
|
||||||
for _, v := range md.Get("origin") {
|
|
||||||
origin = v
|
|
||||||
}
|
|
||||||
isHTTPS := strings.HasPrefix(origin, "https://")
|
|
||||||
if isHTTPS {
|
if isHTTPS {
|
||||||
attrs = append(attrs, "SameSite=None")
|
attrs = append(attrs, "SameSite=None")
|
||||||
attrs = append(attrs, "Secure")
|
attrs = append(attrs, "Secure")
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,12 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
||||||
if sessionCookie := auth.ExtractSessionCookieFromHeader(header.Get("Cookie")); sessionCookie != "" {
|
if sessionCookie := auth.ExtractSessionCookieFromHeader(header.Get("Cookie")); sessionCookie != "" {
|
||||||
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
|
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
|
||||||
if err == nil && user != nil {
|
if err == nil && user != nil {
|
||||||
_, sessionID, _ := auth.ParseSessionCookieValue(sessionCookie)
|
_, sessionID, err := auth.ParseSessionCookieValue(sessionCookie)
|
||||||
|
if err != nil {
|
||||||
|
// This should not happen since AuthenticateBySession already validated the cookie
|
||||||
|
// but handle it gracefully anyway
|
||||||
|
sessionID = ""
|
||||||
|
}
|
||||||
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, procedure, user, sessionID, "", IsAdminOnlyMethod)
|
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, procedure, user, sessionID, "", IsAdminOnlyMethod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, convertAuthError(err)
|
return nil, convertAuthError(err)
|
||||||
|
|
|
||||||
|
|
@ -40,29 +40,27 @@ func (s *ConnectServiceHandler) UpdateInstanceSetting(ctx context.Context, req *
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthService
|
// AuthService
|
||||||
|
//
|
||||||
|
// Auth service methods need special handling for response headers (cookies).
|
||||||
|
// We use withHeaderCarrier helper to inject a header carrier into the context,
|
||||||
|
// which allows the service to set headers in a protocol-agnostic way.
|
||||||
|
|
||||||
func (s *ConnectServiceHandler) GetCurrentSession(ctx context.Context, req *connect.Request[v1pb.GetCurrentSessionRequest]) (*connect.Response[v1pb.GetCurrentSessionResponse], error) {
|
func (s *ConnectServiceHandler) GetCurrentSession(ctx context.Context, req *connect.Request[v1pb.GetCurrentSessionRequest]) (*connect.Response[v1pb.GetCurrentSessionResponse], error) {
|
||||||
resp, err := s.APIV1Service.GetCurrentSession(ctx, req.Msg)
|
return withHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.GetCurrentSessionResponse, error) {
|
||||||
if err != nil {
|
return s.APIV1Service.GetCurrentSession(ctx, req.Msg)
|
||||||
return nil, convertGRPCError(err)
|
})
|
||||||
}
|
|
||||||
return connect.NewResponse(resp), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[v1pb.CreateSessionRequest]) (*connect.Response[v1pb.CreateSessionResponse], error) {
|
func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[v1pb.CreateSessionRequest]) (*connect.Response[v1pb.CreateSessionResponse], error) {
|
||||||
resp, err := s.APIV1Service.CreateSession(ctx, req.Msg)
|
return withHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.CreateSessionResponse, error) {
|
||||||
if err != nil {
|
return s.APIV1Service.CreateSession(ctx, req.Msg)
|
||||||
return nil, convertGRPCError(err)
|
})
|
||||||
}
|
|
||||||
return connect.NewResponse(resp), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ConnectServiceHandler) DeleteSession(ctx context.Context, req *connect.Request[v1pb.DeleteSessionRequest]) (*connect.Response[emptypb.Empty], error) {
|
func (s *ConnectServiceHandler) DeleteSession(ctx context.Context, req *connect.Request[v1pb.DeleteSessionRequest]) (*connect.Response[emptypb.Empty], error) {
|
||||||
resp, err := s.APIV1Service.DeleteSession(ctx, req.Msg)
|
return withHeaderCarrier(ctx, func(ctx context.Context) (*emptypb.Empty, error) {
|
||||||
if err != nil {
|
return s.APIV1Service.DeleteSession(ctx, req.Msg)
|
||||||
return nil, convertGRPCError(err)
|
})
|
||||||
}
|
|
||||||
return connect.NewResponse(resp), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserService
|
// UserService
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
package v1
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"connectrpc.com/connect"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// headerCarrierKey is the context key for storing headers to be set in the response.
|
||||||
|
type headerCarrierKey struct{}
|
||||||
|
|
||||||
|
// HeaderCarrier stores headers that need to be set in the response.
|
||||||
|
//
|
||||||
|
// Problem: The codebase supports two protocols simultaneously:
|
||||||
|
// - Native gRPC: Uses grpc.SetHeader() to set response headers
|
||||||
|
// - Connect-RPC: Uses connect.Response.Header().Set() to set response headers
|
||||||
|
//
|
||||||
|
// Solution: HeaderCarrier provides a protocol-agnostic way to set headers.
|
||||||
|
// - Service methods call SetResponseHeader() regardless of protocol
|
||||||
|
// - For gRPC requests: SetResponseHeader uses grpc.SetHeader directly
|
||||||
|
// - For Connect requests: SetResponseHeader stores headers in HeaderCarrier
|
||||||
|
// - Connect wrappers extract headers from HeaderCarrier and apply to response
|
||||||
|
//
|
||||||
|
// This allows service methods to work with both protocols without knowing which one is being used.
|
||||||
|
type HeaderCarrier struct {
|
||||||
|
headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHeaderCarrier creates a new header carrier.
|
||||||
|
func newHeaderCarrier() *HeaderCarrier {
|
||||||
|
return &HeaderCarrier{
|
||||||
|
headers: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set adds a header to the carrier.
|
||||||
|
func (h *HeaderCarrier) Set(key, value string) {
|
||||||
|
h.headers[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a header from the carrier.
|
||||||
|
func (h *HeaderCarrier) Get(key string) string {
|
||||||
|
return h.headers[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
// All returns all headers.
|
||||||
|
func (h *HeaderCarrier) All() map[string]string {
|
||||||
|
return h.headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeaderCarrier adds a header carrier to the context.
|
||||||
|
func WithHeaderCarrier(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, headerCarrierKey{}, newHeaderCarrier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHeaderCarrier retrieves the header carrier from the context.
|
||||||
|
// Returns nil if no carrier is present.
|
||||||
|
func GetHeaderCarrier(ctx context.Context) *HeaderCarrier {
|
||||||
|
if carrier, ok := ctx.Value(headerCarrierKey{}).(*HeaderCarrier); ok {
|
||||||
|
return carrier
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResponseHeader sets a header in the response.
|
||||||
|
//
|
||||||
|
// This function works for both gRPC and Connect protocols:
|
||||||
|
// - For gRPC: Uses grpc.SetHeader to set headers in gRPC metadata
|
||||||
|
// - For Connect: Stores in HeaderCarrier for Connect wrapper to apply later
|
||||||
|
//
|
||||||
|
// The protocol is automatically detected based on whether a HeaderCarrier
|
||||||
|
// exists in the context (injected by Connect wrappers).
|
||||||
|
func SetResponseHeader(ctx context.Context, key, value string) error {
|
||||||
|
// Try Connect first (check if we have a header carrier)
|
||||||
|
if carrier := GetHeaderCarrier(ctx); carrier != nil {
|
||||||
|
carrier.Set(key, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to gRPC
|
||||||
|
return grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||||
|
key: value,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// withHeaderCarrier is a helper for Connect service wrappers that need to set response headers.
|
||||||
|
//
|
||||||
|
// It injects a HeaderCarrier into the context, calls the service method,
|
||||||
|
// and applies any headers from the carrier to the Connect response.
|
||||||
|
//
|
||||||
|
// Usage in Connect wrappers:
|
||||||
|
//
|
||||||
|
// func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[...]) (*connect.Response[...], error) {
|
||||||
|
// return withHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.CreateSessionResponse, error) {
|
||||||
|
// return s.APIV1Service.CreateSession(ctx, req.Msg)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
func withHeaderCarrier[T proto.Message](ctx context.Context, fn func(context.Context) (T, error)) (*connect.Response[T], error) {
|
||||||
|
// Inject header carrier for Connect protocol
|
||||||
|
ctx = WithHeaderCarrier(ctx)
|
||||||
|
|
||||||
|
// Call the service method
|
||||||
|
resp, err := fn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, convertGRPCError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Connect response
|
||||||
|
connectResp := connect.NewResponse(resp)
|
||||||
|
|
||||||
|
// Apply any headers set via the header carrier
|
||||||
|
if carrier := GetHeaderCarrier(ctx); carrier != nil {
|
||||||
|
for key, value := range carrier.All() {
|
||||||
|
connectResp.Header().Set(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return connectResp, nil
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue