mirror of https://github.com/usememos/memos.git
147 lines
3.0 KiB
Go
147 lines
3.0 KiB
Go
package scheduler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync/atomic"
|
|
"testing"
|
|
)
|
|
|
|
func TestMiddlewareChaining(t *testing.T) {
|
|
var order []string
|
|
|
|
mw1 := func(next JobHandler) JobHandler {
|
|
return func(ctx context.Context) error {
|
|
order = append(order, "before-1")
|
|
err := next(ctx)
|
|
order = append(order, "after-1")
|
|
return err
|
|
}
|
|
}
|
|
|
|
mw2 := func(next JobHandler) JobHandler {
|
|
return func(ctx context.Context) error {
|
|
order = append(order, "before-2")
|
|
err := next(ctx)
|
|
order = append(order, "after-2")
|
|
return err
|
|
}
|
|
}
|
|
|
|
handler := func(_ context.Context) error {
|
|
order = append(order, "handler")
|
|
return nil
|
|
}
|
|
|
|
chain := Chain(mw1, mw2)
|
|
wrapped := chain(handler)
|
|
|
|
if err := wrapped(context.Background()); err != nil {
|
|
t.Fatalf("wrapped handler failed: %v", err)
|
|
}
|
|
|
|
expected := []string{"before-1", "before-2", "handler", "after-2", "after-1"}
|
|
if len(order) != len(expected) {
|
|
t.Fatalf("expected %d calls, got %d", len(expected), len(order))
|
|
}
|
|
|
|
for i, want := range expected {
|
|
if order[i] != want {
|
|
t.Errorf("order[%d] = %q, want %q", i, order[i], want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRecoveryMiddleware(t *testing.T) {
|
|
var panicRecovered atomic.Bool
|
|
|
|
onPanic := func(_ string, _ interface{}) {
|
|
panicRecovered.Store(true)
|
|
}
|
|
|
|
handler := func(_ context.Context) error {
|
|
panic("simulated panic")
|
|
}
|
|
|
|
recovery := Recovery(onPanic)
|
|
wrapped := recovery(handler)
|
|
|
|
// Should not panic, error should be returned
|
|
err := wrapped(withJobName(context.Background(), "test-job"))
|
|
if err == nil {
|
|
t.Error("expected error from recovered panic")
|
|
}
|
|
|
|
if !panicRecovered.Load() {
|
|
t.Error("panic handler was not called")
|
|
}
|
|
}
|
|
|
|
func TestLoggingMiddleware(t *testing.T) {
|
|
var loggedStart, loggedEnd atomic.Bool
|
|
var loggedError atomic.Bool
|
|
|
|
logger := &testLogger{
|
|
onInfo: func(msg string, _ ...interface{}) {
|
|
if msg == "Job started" {
|
|
loggedStart.Store(true)
|
|
} else if msg == "Job completed" {
|
|
loggedEnd.Store(true)
|
|
}
|
|
},
|
|
onError: func(msg string, _ ...interface{}) {
|
|
if msg == "Job failed" {
|
|
loggedError.Store(true)
|
|
}
|
|
},
|
|
}
|
|
|
|
// Test successful execution
|
|
handler := func(_ context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
logging := Logging(logger)
|
|
wrapped := logging(handler)
|
|
|
|
if err := wrapped(withJobName(context.Background(), "test-job")); err != nil {
|
|
t.Fatalf("handler failed: %v", err)
|
|
}
|
|
|
|
if !loggedStart.Load() {
|
|
t.Error("start was not logged")
|
|
}
|
|
if !loggedEnd.Load() {
|
|
t.Error("end was not logged")
|
|
}
|
|
|
|
// Test error handling
|
|
handlerErr := func(_ context.Context) error {
|
|
return errors.New("job error")
|
|
}
|
|
|
|
wrappedErr := logging(handlerErr)
|
|
_ = wrappedErr(withJobName(context.Background(), "test-job-error"))
|
|
|
|
if !loggedError.Load() {
|
|
t.Error("error was not logged")
|
|
}
|
|
}
|
|
|
|
// testLogger is a test double for the logger passed to Logging. Each level
// forwards to an optional callback so a test can observe which messages were
// emitted; leaving a callback nil turns that level into a no-op.
type testLogger struct {
	onInfo  func(msg string, args ...interface{})
	onError func(msg string, args ...interface{})
}

// Info forwards msg and args to onInfo, if one is installed.
func (l *testLogger) Info(msg string, args ...interface{}) {
	if sink := l.onInfo; sink != nil {
		sink(msg, args...)
	}
}

// Error forwards msg and args to onError, if one is installed.
func (l *testLogger) Error(msg string, args ...interface{}) {
	if sink := l.onError; sink != nil {
		sink(msg, args...)
	}
}
|