mirror of https://github.com/usememos/memos.git
204 lines
6.3 KiB
Go
204 lines
6.3 KiB
Go
package v1
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"runtime/debug"
|
|
|
|
"connectrpc.com/connect"
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/usememos/memos/server/auth"
|
|
"github.com/usememos/memos/store"
|
|
)
|
|
|
|
// LoggingInterceptor is a Connect interceptor that writes one log record per
// unary RPC and picks the log level from the outcome of the call.
//
// Levels:
//   - INFO: successful calls, and errors the client is expected to cause
//     (not found, permission denied, and similar codes).
//   - ERROR: every other failure (internal, unavailable, non-Connect errors).
type LoggingInterceptor struct {
	// logStacktrace, when true, adds a "stacktrace" attribute (the "%+v"
	// rendering of the error) to the record of every failing call.
	logStacktrace bool
}
|
|
|
|
// NewLoggingInterceptor creates a new logging interceptor.
|
|
func NewLoggingInterceptor(logStacktrace bool) *LoggingInterceptor {
|
|
return &LoggingInterceptor{logStacktrace: logStacktrace}
|
|
}
|
|
|
|
func (in *LoggingInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
|
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
|
resp, err := next(ctx, req)
|
|
in.log(req.Spec().Procedure, err)
|
|
return resp, err
|
|
}
|
|
}
|
|
|
|
func (*LoggingInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
|
return next // No-op for server-side interceptor
|
|
}
|
|
|
|
func (*LoggingInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
|
return next // Streaming not used in this service
|
|
}
|
|
|
|
func (in *LoggingInterceptor) log(procedure string, err error) {
|
|
level, msg := in.classifyError(err)
|
|
attrs := []slog.Attr{slog.String("method", procedure)}
|
|
if err != nil {
|
|
attrs = append(attrs, slog.String("error", err.Error()))
|
|
if in.logStacktrace {
|
|
attrs = append(attrs, slog.String("stacktrace", fmt.Sprintf("%+v", err)))
|
|
}
|
|
}
|
|
slog.LogAttrs(context.Background(), level, msg, attrs...)
|
|
}
|
|
|
|
func (*LoggingInterceptor) classifyError(err error) (slog.Level, string) {
|
|
if err == nil {
|
|
return slog.LevelInfo, "OK"
|
|
}
|
|
|
|
var connectErr *connect.Error
|
|
if !errors.As(err, &connectErr) {
|
|
return slog.LevelError, "unknown error"
|
|
}
|
|
|
|
// Client errors (expected, log at INFO)
|
|
switch connectErr.Code() {
|
|
case connect.CodeCanceled,
|
|
connect.CodeInvalidArgument,
|
|
connect.CodeNotFound,
|
|
connect.CodeAlreadyExists,
|
|
connect.CodePermissionDenied,
|
|
connect.CodeUnauthenticated,
|
|
connect.CodeResourceExhausted,
|
|
connect.CodeFailedPrecondition,
|
|
connect.CodeAborted,
|
|
connect.CodeOutOfRange:
|
|
return slog.LevelInfo, "client error"
|
|
default:
|
|
// Server errors
|
|
return slog.LevelError, "server error"
|
|
}
|
|
}
|
|
|
|
// RecoveryInterceptor is a Connect interceptor that catches panics raised by
// unary handlers, logs them, and reports an internal error to the caller
// instead.
type RecoveryInterceptor struct {
	// logStacktrace, when true, adds the stack captured at recovery time to
	// the panic log record under "stacktrace".
	logStacktrace bool
}
|
|
|
|
// NewRecoveryInterceptor creates a new recovery interceptor.
|
|
func NewRecoveryInterceptor(logStacktrace bool) *RecoveryInterceptor {
|
|
return &RecoveryInterceptor{logStacktrace: logStacktrace}
|
|
}
|
|
|
|
func (in *RecoveryInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
|
return func(ctx context.Context, req connect.AnyRequest) (resp connect.AnyResponse, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
in.logPanic(req.Spec().Procedure, r)
|
|
err = connect.NewError(connect.CodeInternal, errors.New("internal server error"))
|
|
}
|
|
}()
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
|
|
func (*RecoveryInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
|
return next
|
|
}
|
|
|
|
func (*RecoveryInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
|
return next
|
|
}
|
|
|
|
func (in *RecoveryInterceptor) logPanic(procedure string, panicValue any) {
|
|
attrs := []slog.Attr{
|
|
slog.String("method", procedure),
|
|
slog.Any("panic", panicValue),
|
|
}
|
|
if in.logStacktrace {
|
|
attrs = append(attrs, slog.String("stacktrace", string(debug.Stack())))
|
|
}
|
|
slog.LogAttrs(context.Background(), slog.LevelError, "panic recovered in Connect handler", attrs...)
|
|
}
|
|
|
|
// AuthInterceptor handles authentication for Connect handlers.
|
|
//
|
|
// It reuses the same authentication logic as GRPCAuthInterceptor by delegating
|
|
// to a shared Authenticator instance. This ensures consistent authentication
|
|
// behavior across both gRPC and Connect protocols.
|
|
type AuthInterceptor struct {
|
|
authenticator *auth.Authenticator
|
|
}
|
|
|
|
// NewAuthInterceptor creates a new auth interceptor.
|
|
func NewAuthInterceptor(store *store.Store, secret string) *AuthInterceptor {
|
|
return &AuthInterceptor{
|
|
authenticator: auth.NewAuthenticator(store, secret),
|
|
}
|
|
}
|
|
|
|
func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
|
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
|
procedure := req.Spec().Procedure
|
|
header := req.Header()
|
|
|
|
// Try session cookie authentication first
|
|
if sessionCookie := auth.ExtractSessionCookieFromHeader(header.Get("Cookie")); sessionCookie != "" {
|
|
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
|
|
if err == nil && user != nil {
|
|
_, sessionID, _ := auth.ParseSessionCookieValue(sessionCookie)
|
|
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, procedure, user, sessionID, "", IsAdminOnlyMethod)
|
|
if err != nil {
|
|
return nil, convertAuthError(err)
|
|
}
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
|
|
// Try JWT token authentication
|
|
if accessToken := auth.ExtractBearerToken(header.Get("Authorization")); accessToken != "" {
|
|
user, err := in.authenticator.AuthenticateByJWT(ctx, accessToken)
|
|
if err == nil && user != nil {
|
|
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, procedure, user, "", accessToken, IsAdminOnlyMethod)
|
|
if err != nil {
|
|
return nil, convertAuthError(err)
|
|
}
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
|
|
// Allow public methods without authentication
|
|
if IsPublicMethod(procedure) {
|
|
return next(ctx, req)
|
|
}
|
|
|
|
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
|
|
}
|
|
}
|
|
|
|
func (*AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
|
return next
|
|
}
|
|
|
|
func (*AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
|
return next
|
|
}
|
|
|
|
// convertAuthError converts authentication/authorization errors to Connect errors.
|
|
func convertAuthError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
// Check if it's already a Connect error
|
|
var connectErr *connect.Error
|
|
if errors.As(err, &connectErr) {
|
|
return err
|
|
}
|
|
// Default to permission denied for auth errors
|
|
return connect.NewError(connect.CodePermissionDenied, err)
|
|
}
|