memos/server/notification/service.go

207 lines
6.1 KiB
Go

package notification
// Notification service: central dispatch for memo-related webhooks (RAW/WeCom/Bark).
import (
"context"
"fmt"
"log/slog"
"math/rand"
"net/url"
"strings"
"sync"
"time"
"github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// Service fans memo activity out to the webhooks that memo creators have
// configured, reading those webhooks from store.
type Service struct {
	store *store.Store
}

// NewService builds a Service that looks up webhooks in the given store.
func NewService(store *store.Store) *Service {
	svc := new(Service)
	svc.store = store
	return svc
}
// DispatchMemoWebhooks sends notifications for memo to every webhook configured
// by the memo's creator.
//
// The creator's webhooks are loaded synchronously: an error parsing the
// creator name or reading the store is returned to the caller. Each delivery
// then runs in its own goroutine, so a nil return only means that deliveries
// were started; their outcomes are logged, not returned. activityType is passed
// to the per-type senders and, for RAW webhooks, set on the payload.
func (s *Service) DispatchMemoWebhooks(ctx context.Context, memo *v1pb.Memo, activityType string) error {
	creatorID, err := ExtractUserIDFromName(memo.GetCreator())
	if err != nil {
		return fmt.Errorf("invalid memo creator: %w", err)
	}
	hooks, err := s.store.GetUserWebhooks(ctx, creatorID)
	if err != nil {
		return err
	}
	// Deliveries outlive this call, and the caller's ctx is typically canceled
	// as soon as it returns. Keep ctx's values but drop its cancellation so
	// retries and ctx-aware senders are not aborted prematurely.
	bgCtx := context.WithoutCancel(ctx)
	for _, h := range hooks {
		typ, target := classifyWebhook(h)
		if target == "" {
			// Nothing to send to. Posting would only fail and count against the
			// circuit breaker.
			continue
		}
		// Slot acquisition happens inside the goroutine so a busy host never
		// blocks the caller.
		go deliverWebhook(bgCtx, typ, target, memo, activityType)
	}
	return nil
}

// webhookDeliveryTimeout bounds one background delivery (all attempts and
// backoff waits) once its per-host concurrency slot has been acquired.
const webhookDeliveryTimeout = time.Minute

// deliverWebhook performs a single webhook delivery of the given type to target
// and logs the outcome with its latency. It first waits for a per-host
// concurrency slot, then sends through sendWithRetry under a bounded context.
func deliverWebhook(parent context.Context, typ webhookType, target string, memo *v1pb.Memo, activityType string) {
	hostKey := hostKeyFor(target)
	release := acquire(hostKey)
	defer release()

	ctx, cancel := context.WithTimeout(parent, webhookDeliveryTimeout)
	defer cancel()

	start := time.Now()
	var err error
	switch typ {
	case webhookTypeWeCom:
		err = sendWithRetry(ctx, hostKey, func() error { return sendWeCom(ctx, target, memo, activityType) })
	case webhookTypeBark:
		err = sendWithRetry(ctx, hostKey, func() error { return sendBark(ctx, target, memo, activityType) })
	default:
		payload, perr := convertMemoToWebhookPayload(memo)
		if perr != nil {
			slog.Warn("convert payload failed", slog.Any("err", perr))
			return
		}
		payload.ActivityType = activityType
		payload.URL = target
		err = sendWithRetry(ctx, hostKey, func() error { return webhook.Post(payload) })
	}
	latency := time.Since(start)
	if err != nil {
		slog.Warn("Webhook dispatch failed", slog.String("type", string(typ)), slog.String("url", target), slog.Duration("latency", latency), slog.Any("err", err))
		return
	}
	slog.Info("Webhook dispatched", slog.String("type", string(typ)), slog.String("url", target), slog.Duration("latency", latency))
}
// classifyWebhook decides how a configured webhook is delivered and returns
// the delivery type together with the target to send to.
//
// The URL is whitespace-trimmed first. An explicit "wecom://" or "bark://"
// prefix wins and is stripped from the returned target. Otherwise the URL's
// hostname is compared against the WeCom (qyapi.weixin.qq.com) and Bark
// (api.day.app) endpoints, matching the exact host or any subdomain of it, and
// the full URL is returned unchanged. Everything else, including an empty or
// unparsable URL, is treated as a RAW webhook.
func classifyWebhook(h *storepb.WebhooksUserSetting_Webhook) (webhookType, string) {
	raw := strings.TrimSpace(h.GetUrl())
	if raw == "" {
		return webhookTypeRAW, raw
	}
	if rest, ok := strings.CutPrefix(raw, "wecom://"); ok {
		return webhookTypeWeCom, rest
	}
	if rest, ok := strings.CutPrefix(raw, "bark://"); ok {
		return webhookTypeBark, rest
	}
	u, err := url.Parse(raw)
	if err != nil {
		return webhookTypeRAW, raw
	}
	// Compare the bare hostname (port stripped) rather than doing a substring
	// search on u.Host. The substring search also matched look-alikes such as
	// "qyapi.weixin.qq.com.example.org".
	host := strings.ToLower(u.Hostname())
	switch {
	case hostIsOrUnder(host, "qyapi.weixin.qq.com"):
		return webhookTypeWeCom, raw
	case hostIsOrUnder(host, "api.day.app"):
		return webhookTypeBark, raw
	}
	return webhookTypeRAW, raw
}

// hostIsOrUnder reports whether host equals domain or is a subdomain of it.
func hostIsOrUnder(host, domain string) bool {
	return host == domain || strings.HasSuffix(host, "."+domain)
}
// ExtractUserIDFromName parses a user resource name of the form "users/{id}"
// and returns the numeric id.
//
// {id} must be a non-empty run of ASCII decimal digits that fits in an int32.
// Anything else is rejected with an error rather than silently truncated or
// wrapped: a different collection, extra path segments, a sign, trailing
// characters such as "12abc", or an out-of-range value.
func ExtractUserIDFromName(name string) (int32, error) {
	parts := strings.Split(name, "/")
	if len(parts) != 2 || parts[0] != "users" {
		return 0, fmt.Errorf("invalid user resource name: %s", name)
	}
	raw := parts[1]
	// fmt.Sscanf with "%d" stops at the first non-digit without reporting the
	// leftover input, so validate the whole segment up front.
	if raw == "" || strings.TrimLeft(raw, "0123456789") != "" {
		return 0, fmt.Errorf("invalid user id: %s", raw)
	}
	var id int32
	// Scanning straight into an int32 makes fmt report overflow, instead of
	// wrapping the value the way an int -> int32 conversion would.
	if _, err := fmt.Sscanf(raw, "%d", &id); err != nil {
		return 0, fmt.Errorf("invalid user id: %s", raw)
	}
	return id, nil
}
// convertMemoToWebhookPayload wraps memo in the generic (RAW) webhook payload.
// The creator field is rebuilt from the parsed numeric id, so the payload
// carries the "users/{id}" form even if the memo's creator string differs
// textually (e.g. leading zeros). A creator name that fails to parse is
// returned as an error.
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
	id, parseErr := ExtractUserIDFromName(memo.GetCreator())
	if parseErr != nil {
		return nil, fmt.Errorf("invalid memo creator: %w", parseErr)
	}
	payload := &webhook.WebhookRequestPayload{Memo: memo}
	payload.Creator = fmt.Sprintf("users/%d", id)
	return payload, nil
}
// --- limiter, retry, circuit breaker ---

// limiterMap maps a host key to a buffered chan struct{} that acts as a
// counting semaphore for deliveries to that host.
var limiterMap sync.Map

// cbMap maps a host key to its *cbState circuit-breaker record.
var cbMap sync.Map

// maxConcurrentPerHost is the semaphore capacity given to each new host key.
var maxConcurrentPerHost = 2
// cbState is the circuit-breaker record kept for one host key.
type cbState struct {
	mu        sync.Mutex // guards FailCount and OpenUntil
	FailCount int        // failures counted since the last success or trip
	OpenUntil time.Time  // circuit counts as open while now is before this; zero when closed
}
// hostKeyFor returns the key under which deliveries to target share a
// concurrency limiter and circuit breaker.
//
// For a URL with a host this is the lower-cased host (including any port). A
// target with no host is keyed by the target string itself. This covers
// scheme-less values such as a bare key left after stripping "wecom://" or
// "bark://", as well as unparsable strings. Without this fallback, every such
// target would collapse onto the empty key and unrelated services would share
// one limiter and breaker.
func hostKeyFor(target string) string {
	if u, err := url.Parse(target); err == nil && u.Host != "" {
		return strings.ToLower(u.Host)
	}
	return target
}
// acquire blocks until a concurrency slot for key is free and returns a
// function that gives the slot back; call it exactly once.
//
// Each key has its own counting semaphore: a buffered channel of capacity
// maxConcurrentPerHost (at least 1), created on first use and kept in
// limiterMap.
func acquire(key string) func() {
	sem, ok := limiterMap.Load(key)
	if !ok {
		// Allocate a channel only for a key not seen yet; LoadOrStore settles
		// the race if another goroutine creates one concurrently. A capacity
		// below 1 would make the send below block forever, so clamp it.
		sem, _ = limiterMap.LoadOrStore(key, make(chan struct{}, max(maxConcurrentPerHost, 1)))
	}
	ch := sem.(chan struct{})
	ch <- struct{}{}
	return func() { <-ch }
}
// sendWithRetry calls fn until it succeeds, for at most four attempts,
// reporting every outcome to the circuit breaker for key.
//
// Between attempts it waits about 0.5s, 1s, then 2s, each plus up to 50% random
// jitter. It returns nil on the first success. It returns a "circuit open" error
// instead of calling fn whenever the breaker for key is open. The breaker is
// checked before the first attempt, right after each failure (so no backoff is
// slept once it trips) and again after each wait (concurrent deliveries to the
// same key may have tripped it). If ctx is done while waiting, ctx.Err() is
// returned. Otherwise the last error from fn is returned when attempts run out.
func sendWithRetry(ctx context.Context, key string, fn func() error) error {
	if isOpen(key) {
		return fmt.Errorf("circuit open for %s", key)
	}
	backoffs := [...]time.Duration{500 * time.Millisecond, time.Second, 2 * time.Second}
	tripped := func(last error) error {
		return fmt.Errorf("circuit open for %s: %w", key, last)
	}
	for attempt := 0; ; attempt++ {
		err := fn()
		if err == nil {
			recordSuccess(key)
			return nil
		}
		recordFailure(key)
		if attempt == len(backoffs) {
			return err
		}
		// Don't sleep (and keep holding the caller's host slot) only to retry
		// a host the breaker has just decided to back off from.
		if isOpen(key) {
			return tripped(err)
		}
		d := backoffs[attempt]
		timer := time.NewTimer(d + time.Duration(rand.Int63n(int64(d/2))))
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		}
		if isOpen(key) {
			return tripped(err)
		}
	}
}
func isOpen(key string) bool {
v, _ := cbMap.LoadOrStore(key, &cbState{})
s := v.(*cbState)
s.mu.Lock()
defer s.mu.Unlock()
return time.Now().Before(s.OpenUntil)
}
func recordFailure(key string) {
v, _ := cbMap.LoadOrStore(key, &cbState{})
s := v.(*cbState)
s.mu.Lock()
defer s.mu.Unlock()
s.FailCount++
if s.FailCount >= 3 {
s.OpenUntil = time.Now().Add(1 * time.Minute)
s.FailCount = 0
}
}
func recordSuccess(key string) {
v, _ := cbMap.LoadOrStore(key, &cbState{})
s := v.(*cbState)
s.mu.Lock()
defer s.mu.Unlock()
s.FailCount = 0
s.OpenUntil = time.Time{}
}