diff --git a/server/router/api/v1/acl.go b/server/router/api/v1/acl.go index a6acb8341..9594d4795 100644 --- a/server/router/api/v1/acl.go +++ b/server/router/api/v1/acl.go @@ -53,6 +53,10 @@ func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthIntercep func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { + // If metadata is missing, only allow public methods + if IsPublicMethod(serverInfo.FullMethod) { + return handler(ctx, request) + } return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") } @@ -60,7 +64,12 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re if sessionCookie := extractSessionCookieFromMetadata(md); sessionCookie != "" { user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie) if err == nil && user != nil { - _, sessionID, _ := auth.ParseSessionCookieValue(sessionCookie) + _, sessionID, err := auth.ParseSessionCookieValue(sessionCookie) + if err != nil { + // This should not happen since AuthenticateBySession already validated the cookie + // but handle it gracefully anyway + sessionID = "" + } ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, serverInfo.FullMethod, user, sessionID, "", IsAdminOnlyMethod) if err != nil { return nil, toGRPCError(err, codes.PermissionDenied) diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 92e4dafca..3db5296b4 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -10,7 +10,6 @@ import ( "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -237,10 +236,8 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim if err != nil { return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err) } - if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ - "Set-Cookie": sessionCookie, - })); err != nil { - return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err) + if err := SetResponseHeader(ctx, "Set-Cookie", sessionCookie); err != nil { + return status.Errorf(codes.Internal, "failed to set response header, error: %v", err) } return nil @@ -284,11 +281,9 @@ func (s *APIV1Service) clearAuthCookies(ctx context.Context) error { return errors.Wrap(err, "failed to build session cookie") } - // Set both cookies in the response - if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ - "Set-Cookie": sessionCookie, - })); err != nil { - return errors.Wrap(err, "failed to set grpc header") + // Set cookie in the response + if err := SetResponseHeader(ctx, "Set-Cookie", sessionCookie); err != nil { + return errors.Wrap(err, "failed to set response header") } return nil } @@ -305,15 +300,18 @@ func (*APIV1Service) buildSessionCookie(ctx context.Context, sessionCookieValue attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123)) } - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return "", errors.New("failed to get metadata from context") + // Try to determine if the request is HTTPS by checking the origin header + // Default to non-HTTPS (Strict SameSite) if metadata is not available + isHTTPS := false + if md, ok := metadata.FromIncomingContext(ctx); ok { + for _, v := range md.Get("origin") { + if strings.HasPrefix(v, "https://") { + isHTTPS = true + break + } + } } - var origin string - for _, v := range md.Get("origin") { - origin = v - } - isHTTPS := strings.HasPrefix(origin, "https://") + if isHTTPS { attrs = append(attrs, "SameSite=None") attrs = append(attrs, "Secure") diff --git a/server/router/api/v1/connect_interceptors.go b/server/router/api/v1/connect_interceptors.go index a7acb2c24..d4fda4c57 100644 --- a/server/router/api/v1/connect_interceptors.go +++ b/server/router/api/v1/connect_interceptors.go @@ -150,7 +150,12 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { if sessionCookie := auth.ExtractSessionCookieFromHeader(header.Get("Cookie")); sessionCookie != "" { user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie) if err == nil && user != nil { - _, sessionID, _ := auth.ParseSessionCookieValue(sessionCookie) + _, sessionID, err := auth.ParseSessionCookieValue(sessionCookie) + if err != nil { + // This should not happen since AuthenticateBySession already validated the cookie + // but handle it gracefully anyway + sessionID = "" + } ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, procedure, user, sessionID, "", IsAdminOnlyMethod) if err != nil { return nil, convertAuthError(err) diff --git a/server/router/api/v1/connect_services.go b/server/router/api/v1/connect_services.go index 281b9632f..ea89603f2 100644 --- a/server/router/api/v1/connect_services.go +++ b/server/router/api/v1/connect_services.go @@ -40,29 +40,27 @@ func (s *ConnectServiceHandler) UpdateInstanceSetting(ctx context.Context, req * } // AuthService +// +// Auth service methods need special handling for response headers (cookies). +// We use withHeaderCarrier helper to inject a header carrier into the context, +// which allows the service to set headers in a protocol-agnostic way. func (s *ConnectServiceHandler) GetCurrentSession(ctx context.Context, req *connect.Request[v1pb.GetCurrentSessionRequest]) (*connect.Response[v1pb.GetCurrentSessionResponse], error) { - resp, err := s.APIV1Service.GetCurrentSession(ctx, req.Msg) - if err != nil { - return nil, convertGRPCError(err) - } - return connect.NewResponse(resp), nil + return withHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.GetCurrentSessionResponse, error) { + return s.APIV1Service.GetCurrentSession(ctx, req.Msg) + }) } func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[v1pb.CreateSessionRequest]) (*connect.Response[v1pb.CreateSessionResponse], error) { - resp, err := s.APIV1Service.CreateSession(ctx, req.Msg) - if err != nil { - return nil, convertGRPCError(err) - } - return connect.NewResponse(resp), nil + return withHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.CreateSessionResponse, error) { + return s.APIV1Service.CreateSession(ctx, req.Msg) + }) } func (s *ConnectServiceHandler) DeleteSession(ctx context.Context, req *connect.Request[v1pb.DeleteSessionRequest]) (*connect.Response[emptypb.Empty], error) { - resp, err := s.APIV1Service.DeleteSession(ctx, req.Msg) - if err != nil { - return nil, convertGRPCError(err) - } - return connect.NewResponse(resp), nil + return withHeaderCarrier(ctx, func(ctx context.Context) (*emptypb.Empty, error) { + return s.APIV1Service.DeleteSession(ctx, req.Msg) + }) } // UserService diff --git a/server/router/api/v1/header_carrier.go b/server/router/api/v1/header_carrier.go new file mode 100644 index 000000000..2cc5b1b9e --- /dev/null +++ b/server/router/api/v1/header_carrier.go @@ -0,0 +1,122 @@ +package v1 + +import ( + "context" + + "connectrpc.com/connect" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" +) + +// headerCarrierKey is the context key for storing headers to be set in the response. +type headerCarrierKey struct{} + +// HeaderCarrier stores headers that need to be set in the response. +// +// Problem: The codebase supports two protocols simultaneously: +// - Native gRPC: Uses grpc.SetHeader() to set response headers +// - Connect-RPC: Uses connect.Response.Header().Set() to set response headers +// +// Solution: HeaderCarrier provides a protocol-agnostic way to set headers. +// - Service methods call SetResponseHeader() regardless of protocol +// - For gRPC requests: SetResponseHeader uses grpc.SetHeader directly +// - For Connect requests: SetResponseHeader stores headers in HeaderCarrier +// - Connect wrappers extract headers from HeaderCarrier and apply to response +// +// This allows service methods to work with both protocols without knowing which one is being used. +type HeaderCarrier struct { + headers map[string]string +} + +// newHeaderCarrier creates a new header carrier. +func newHeaderCarrier() *HeaderCarrier { + return &HeaderCarrier{ + headers: make(map[string]string), + } +} + +// Set adds a header to the carrier. +func (h *HeaderCarrier) Set(key, value string) { + h.headers[key] = value +} + +// Get retrieves a header from the carrier. +func (h *HeaderCarrier) Get(key string) string { + return h.headers[key] +} + +// All returns all headers. +func (h *HeaderCarrier) All() map[string]string { + return h.headers +} + +// WithHeaderCarrier adds a header carrier to the context. +func WithHeaderCarrier(ctx context.Context) context.Context { + return context.WithValue(ctx, headerCarrierKey{}, newHeaderCarrier()) +} + +// GetHeaderCarrier retrieves the header carrier from the context. +// Returns nil if no carrier is present. +func GetHeaderCarrier(ctx context.Context) *HeaderCarrier { + if carrier, ok := ctx.Value(headerCarrierKey{}).(*HeaderCarrier); ok { + return carrier + } + return nil +} + +// SetResponseHeader sets a header in the response. +// +// This function works for both gRPC and Connect protocols: +// - For gRPC: Uses grpc.SetHeader to set headers in gRPC metadata +// - For Connect: Stores in HeaderCarrier for Connect wrapper to apply later +// +// The protocol is automatically detected based on whether a HeaderCarrier +// exists in the context (injected by Connect wrappers). +func SetResponseHeader(ctx context.Context, key, value string) error { + // Try Connect first (check if we have a header carrier) + if carrier := GetHeaderCarrier(ctx); carrier != nil { + carrier.Set(key, value) + return nil + } + + // Fall back to gRPC + return grpc.SetHeader(ctx, metadata.New(map[string]string{ + key: value, + })) +} + +// withHeaderCarrier is a helper for Connect service wrappers that need to set response headers. +// +// It injects a HeaderCarrier into the context, calls the service method, +// and applies any headers from the carrier to the Connect response. +// +// Usage in Connect wrappers: +// +// func (s *ConnectServiceHandler) CreateSession(ctx context.Context, req *connect.Request[...]) (*connect.Response[...], error) { +// return withHeaderCarrier(ctx, func(ctx context.Context) (*v1pb.CreateSessionResponse, error) { +// return s.APIV1Service.CreateSession(ctx, req.Msg) +// }) +// } +func withHeaderCarrier[T proto.Message](ctx context.Context, fn func(context.Context) (T, error)) (*connect.Response[T], error) { + // Inject header carrier for Connect protocol + ctx = WithHeaderCarrier(ctx) + + // Call the service method + resp, err := fn(ctx) + if err != nil { + return nil, convertGRPCError(err) + } + + // Create Connect response + connectResp := connect.NewResponse(resp) + + // Apply any headers set via the header carrier + if carrier := GetHeaderCarrier(ctx); carrier != nil { + for key, value := range carrier.All() { + connectResp.Header().Set(key, value) + } + } + + return connectResp, nil +}