memos/plugin/scheduler/middleware_test.go

147 lines
3.0 KiB
Go

package scheduler
import (
"context"
"errors"
"sync/atomic"
"testing"
)
func TestMiddlewareChaining(t *testing.T) {
var order []string
mw1 := func(next JobHandler) JobHandler {
return func(ctx context.Context) error {
order = append(order, "before-1")
err := next(ctx)
order = append(order, "after-1")
return err
}
}
mw2 := func(next JobHandler) JobHandler {
return func(ctx context.Context) error {
order = append(order, "before-2")
err := next(ctx)
order = append(order, "after-2")
return err
}
}
handler := func(_ context.Context) error {
order = append(order, "handler")
return nil
}
chain := Chain(mw1, mw2)
wrapped := chain(handler)
if err := wrapped(context.Background()); err != nil {
t.Fatalf("wrapped handler failed: %v", err)
}
expected := []string{"before-1", "before-2", "handler", "after-2", "after-1"}
if len(order) != len(expected) {
t.Fatalf("expected %d calls, got %d", len(expected), len(order))
}
for i, want := range expected {
if order[i] != want {
t.Errorf("order[%d] = %q, want %q", i, order[i], want)
}
}
}
func TestRecoveryMiddleware(t *testing.T) {
var panicRecovered atomic.Bool
onPanic := func(_ string, _ interface{}) {
panicRecovered.Store(true)
}
handler := func(_ context.Context) error {
panic("simulated panic")
}
recovery := Recovery(onPanic)
wrapped := recovery(handler)
// Should not panic, error should be returned
err := wrapped(withJobName(context.Background(), "test-job"))
if err == nil {
t.Error("expected error from recovered panic")
}
if !panicRecovered.Load() {
t.Error("panic handler was not called")
}
}
func TestLoggingMiddleware(t *testing.T) {
var loggedStart, loggedEnd atomic.Bool
var loggedError atomic.Bool
logger := &testLogger{
onInfo: func(msg string, _ ...interface{}) {
if msg == "Job started" {
loggedStart.Store(true)
} else if msg == "Job completed" {
loggedEnd.Store(true)
}
},
onError: func(msg string, _ ...interface{}) {
if msg == "Job failed" {
loggedError.Store(true)
}
},
}
// Test successful execution
handler := func(_ context.Context) error {
return nil
}
logging := Logging(logger)
wrapped := logging(handler)
if err := wrapped(withJobName(context.Background(), "test-job")); err != nil {
t.Fatalf("handler failed: %v", err)
}
if !loggedStart.Load() {
t.Error("start was not logged")
}
if !loggedEnd.Load() {
t.Error("end was not logged")
}
// Test error handling
handlerErr := func(_ context.Context) error {
return errors.New("job error")
}
wrappedErr := logging(handlerErr)
_ = wrappedErr(withJobName(context.Background(), "test-job-error"))
if !loggedError.Load() {
t.Error("error was not logged")
}
}
type testLogger struct {
onInfo func(msg string, args ...interface{})
onError func(msg string, args ...interface{})
}
func (l *testLogger) Info(msg string, args ...interface{}) {
if l.onInfo != nil {
l.onInfo(msg, args...)
}
}
func (l *testLogger) Error(msg string, args ...interface{}) {
if l.onError != nil {
l.onError(msg, args...)
}
}