memos/plugin/webhook/webhook.go

173 lines
4.8 KiB
Go

package webhook
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/json"
"encoding/hex"
"fmt"
"io"
"log/slog"
"os"
"net/http"
"net"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
var (
// timeout is the timeout for webhook request. Default to 30 seconds.
timeout = 30 * time.Second
)
type WebhookRequestPayload struct {
// The target URL for the webhook request.
URL string `json:"url"`
// The type of activity that triggered this webhook.
ActivityType string `json:"activityType"`
// The resource name of the creator. Format: users/{user}
Creator string `json:"creator"`
// The memo that triggered this webhook (if applicable).
Memo *v1pb.Memo `json:"memo"`
}
// Post posts the message to webhook endpoint.
func Post(requestPayload *WebhookRequestPayload) error {
// 基础 SSRF 防护:仅允许 http/https 且禁止回环/内网等目标。
if err := validateOutboundURL(requestPayload.URL); err != nil {
return errors.Wrapf(err, "invalid webhook target: %s", requestPayload.URL)
}
body, err := json.Marshal(requestPayload)
if err != nil {
return errors.Wrapf(err, "failed to marshal webhook request to %s", requestPayload.URL)
}
req, err := http.NewRequest("POST", requestPayload.URL, bytes.NewBuffer(body))
if err != nil {
return errors.Wrapf(err, "failed to construct webhook request to %s", requestPayload.URL)
}
req.Header.Set("Content-Type", "application/json")
// 可选 HMAC 签名:设置 MEMOS_OUTBOUND_WEBHOOK_HMAC_SECRET 即可启用。
if secret := strings.TrimSpace(os.Getenv("MEMOS_OUTBOUND_WEBHOOK_HMAC_SECRET")); secret != "" {
ts := time.Now().Unix()
msg := append([]byte(fmt.Sprintf("%d.", ts)), body...)
h := hmac.New(sha256.New, []byte(secret))
h.Write(msg)
sig := hex.EncodeToString(h.Sum(nil))
req.Header.Set("X-Memos-Signature", fmt.Sprintf("t=%d,v1=%s", ts, sig))
req.Header.Set("X-Memos-Source", "memos")
}
client := &http.Client{
Timeout: timeout,
}
resp, err := client.Do(req)
if err != nil {
return errors.Wrapf(err, "failed to post webhook to %s", requestPayload.URL)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return errors.Errorf("failed to post webhook %s, status code: %d, response body: %s", requestPayload.URL, resp.StatusCode, b)
}
response := &struct {
Code int `json:"code"`
Message string `json:"message"`
}{}
if err := json.Unmarshal(b, response); err != nil {
return errors.Wrapf(err, "failed to unmarshal webhook response from %s", requestPayload.URL)
}
if response.Code != 0 {
return errors.Errorf("receive error code sent by webhook server, code %d, msg: %s", response.Code, response.Message)
}
return nil
}
// PostAsync posts the message to webhook endpoint asynchronously.
// It spawns a new goroutine to handle the request and does not wait for the response.
func PostAsync(requestPayload *WebhookRequestPayload) {
go func() {
if err := Post(requestPayload); err != nil {
// Since we're in a goroutine, we can only log the error
slog.Warn("Failed to dispatch webhook asynchronously",
slog.String("url", requestPayload.URL),
slog.String("activityType", requestPayload.ActivityType),
slog.Any("err", err))
}
}()
}
// validateOutboundURL 基础 SSRF 防护(与 server/notification 略重复,保持插件自包含)。
func validateOutboundURL(raw string) error {
u, err := url.Parse(raw)
if err != nil {
return err
}
scheme := strings.ToLower(u.Scheme)
if scheme != "http" && scheme != "https" {
return errors.Errorf("unsupported scheme: %s", scheme)
}
host := u.Hostname()
if host == "" {
return errors.Errorf("empty host")
}
ips, err := net.LookupIP(host)
if err != nil {
return errors.Wrap(err, "dns lookup failed")
}
for _, ip := range ips {
if isDisallowedIP(ip) {
return errors.Errorf("disallowed target ip: %s", ip.String())
}
}
return nil
}
func isDisallowedIP(ip net.IP) bool {
if ip.IsLoopback() {
return true
}
privateCIDRs := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"169.254.0.0/16",
"127.0.0.0/8",
"169.254.169.254/32",
}
for _, cidr := range privateCIDRs {
_, block, _ := net.ParseCIDR(cidr)
if block.Contains(ip) {
return true
}
}
if ip.To4() == nil {
v6Blocks := []string{
"::1/128",
"fc00::/7",
"fe80::/10",
}
for _, c := range v6Blocks {
_, block, _ := net.ParseCIDR(c)
if block.Contains(ip) {
return true
}
}
}
return false
}