mirror of https://github.com/usememos/memos.git
207 lines
6.1 KiB
Go
207 lines
6.1 KiB
Go
package notification
|
|
|
|
// Notification service: central dispatch for memo-related webhooks (RAW/WeCom/Bark).
|
|
import (
	"context"
	"fmt"
	"log/slog"
	"math/rand"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/usememos/memos/plugin/webhook"
	v1pb "github.com/usememos/memos/proto/gen/api/v1"
	storepb "github.com/usememos/memos/proto/gen/store"
	"github.com/usememos/memos/store"
)
|
|
|
|
type Service struct {
|
|
store *store.Store
|
|
}
|
|
|
|
func NewService(store *store.Store) *Service {
|
|
return &Service{store: store}
|
|
}
|
|
|
|
// DispatchMemoWebhooks sends notifications based on user webhooks.
|
|
func (s *Service) DispatchMemoWebhooks(ctx context.Context, memo *v1pb.Memo, activityType string) error {
|
|
creatorID, err := ExtractUserIDFromName(memo.GetCreator())
|
|
if err != nil {
|
|
return fmt.Errorf("invalid memo creator: %w", err)
|
|
}
|
|
|
|
hooks, err := s.store.GetUserWebhooks(ctx, creatorID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, h := range hooks {
|
|
typ, target := classifyWebhook(h)
|
|
hostKey := hostKeyFor(target)
|
|
release := acquire(hostKey)
|
|
go func(typ webhookType, target string, hostKey string, release func()) {
|
|
defer release()
|
|
start := time.Now()
|
|
var err error
|
|
switch typ {
|
|
case webhookTypeWeCom:
|
|
err = sendWithRetry(ctx, hostKey, func() error { return sendWeCom(ctx, target, memo, activityType) })
|
|
case webhookTypeBark:
|
|
err = sendWithRetry(ctx, hostKey, func() error { return sendBark(ctx, target, memo, activityType) })
|
|
default:
|
|
payload, perr := convertMemoToWebhookPayload(memo)
|
|
if perr != nil {
|
|
slog.Warn("convert payload failed", slog.Any("err", perr))
|
|
return
|
|
}
|
|
payload.ActivityType = activityType
|
|
payload.URL = target
|
|
err = sendWithRetry(ctx, hostKey, func() error { return webhook.Post(payload) })
|
|
}
|
|
duration := time.Since(start)
|
|
if err != nil {
|
|
slog.Warn("Webhook dispatch failed", slog.String("type", string(typ)), slog.String("url", target), slog.Duration("latency", duration), slog.Any("err", err))
|
|
} else {
|
|
slog.Info("Webhook dispatched", slog.String("type", string(typ)), slog.String("url", target), slog.Duration("latency", duration))
|
|
}
|
|
}(typ, target, hostKey, release)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func classifyWebhook(h *storepb.WebhooksUserSetting_Webhook) (webhookType, string) {
|
|
raw := strings.TrimSpace(h.GetUrl())
|
|
if raw == "" {
|
|
return webhookTypeRAW, raw
|
|
}
|
|
if strings.HasPrefix(raw, "wecom://") {
|
|
return webhookTypeWeCom, strings.TrimPrefix(raw, "wecom://")
|
|
}
|
|
if strings.HasPrefix(raw, "bark://") {
|
|
return webhookTypeBark, strings.TrimPrefix(raw, "bark://")
|
|
}
|
|
if u, err := url.Parse(raw); err == nil {
|
|
host := strings.ToLower(u.Host)
|
|
if strings.Contains(host, "qyapi.weixin.qq.com") {
|
|
return webhookTypeWeCom, raw
|
|
}
|
|
if strings.Contains(host, "api.day.app") {
|
|
return webhookTypeBark, raw
|
|
}
|
|
}
|
|
return webhookTypeRAW, raw
|
|
}
|
|
|
|
// ExtractUserIDFromName parses a user resource name of the form "users/{id}"
// and returns the numeric id.
//
// The id segment must consist entirely of a base-10 integer that fits in an
// int32. Trailing characters ("users/12abc") and out-of-range values
// ("users/4294967297") are rejected rather than being silently accepted or
// truncated.
func ExtractUserIDFromName(name string) (int32, error) {
	parts := strings.Split(name, "/")
	if len(parts) != 2 || parts[0] != "users" {
		return 0, fmt.Errorf("invalid user resource name: %s", name)
	}

	// bitSize 32 makes ParseInt report overflow instead of letting the
	// int32 conversion below wrap around.
	id, err := strconv.ParseInt(parts[1], 10, 32)
	if err != nil {
		return 0, fmt.Errorf("invalid user id: %s", parts[1])
	}
	return int32(id), nil
}
|
|
|
|
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
|
|
creatorID, err := ExtractUserIDFromName(memo.GetCreator())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid memo creator: %w", err)
|
|
}
|
|
return &webhook.WebhookRequestPayload{
|
|
Creator: fmt.Sprintf("users/%d", creatorID),
|
|
Memo: memo,
|
|
}, nil
|
|
}
|
|
|
|
// --- limiter, retry, circuit breaker ---

// Package-level state shared by every Service, keyed by the host key that
// DispatchMemoWebhooks derives with hostKeyFor. Nothing in this file removes
// entries once they are created.
var (
	// maxConcurrentPerHost is the buffer size of each per-host limiter channel,
	// i.e. how many deliveries to one host key may run at the same time.
	maxConcurrentPerHost = 2

	// limiterMap maps a host key to its semaphore, a buffered chan struct{}.
	limiterMap sync.Map

	// cbMap maps a host key to the *cbState of its circuit breaker.
	cbMap sync.Map
)
|
|
|
|
// cbState is the circuit-breaker state for one host key.
type cbState struct {
	mu sync.Mutex // guards the fields below

	// FailCount is the number of failures recorded since the last success or
	// since the breaker last opened.
	FailCount int

	// OpenUntil is the instant until which the breaker stays open; the zero
	// Time means the breaker is closed.
	OpenUntil time.Time
}
|
|
|
|
// hostKeyFor returns the key that groups deliveries for per-host concurrency
// limiting and circuit breaking.
//
// For an absolute URL the key is its lower-cased host (including any port).
// Targets without a host fall back to the target string itself. This covers
// the scheme-stripped remainder of "wecom://" / "bark://" webhooks and strings
// that url.Parse rejects. Without the fallback, every such target would share
// the empty key, and therefore one limiter and one breaker.
func hostKeyFor(target string) string {
	if u, err := url.Parse(target); err == nil && u.Host != "" {
		return strings.ToLower(u.Host)
	}
	return target
}
|
|
|
|
func acquire(key string) func() {
|
|
chAny, _ := limiterMap.LoadOrStore(key, make(chan struct{}, maxConcurrentPerHost))
|
|
ch := chAny.(chan struct{})
|
|
ch <- struct{}{}
|
|
return func() { <-ch }
|
|
}
|
|
|
|
func sendWithRetry(ctx context.Context, key string, fn func() error) error {
|
|
if isOpen(key) {
|
|
return fmt.Errorf("circuit open for %s", key)
|
|
}
|
|
var err error
|
|
backoffs := []time.Duration{500 * time.Millisecond, 1 * time.Second, 2 * time.Second}
|
|
for i := 0; i < len(backoffs)+1; i++ {
|
|
err = fn()
|
|
if err == nil {
|
|
recordSuccess(key)
|
|
return nil
|
|
}
|
|
recordFailure(key)
|
|
if i == len(backoffs) {
|
|
break
|
|
}
|
|
d := backoffs[i]
|
|
jitter := time.Duration(rand.Int63n(int64(d / 2)))
|
|
select {
|
|
case <-time.After(d + jitter):
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func isOpen(key string) bool {
|
|
v, _ := cbMap.LoadOrStore(key, &cbState{})
|
|
s := v.(*cbState)
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
return time.Now().Before(s.OpenUntil)
|
|
}
|
|
|
|
func recordFailure(key string) {
|
|
v, _ := cbMap.LoadOrStore(key, &cbState{})
|
|
s := v.(*cbState)
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.FailCount++
|
|
if s.FailCount >= 3 {
|
|
s.OpenUntil = time.Now().Add(1 * time.Minute)
|
|
s.FailCount = 0
|
|
}
|
|
}
|
|
|
|
func recordSuccess(key string) {
|
|
v, _ := cbMap.LoadOrStore(key, &cbState{})
|
|
s := v.(*cbState)
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.FailCount = 0
|
|
s.OpenUntil = time.Time{}
|
|
}
|