mirror of https://github.com/usememos/memos.git
Merge branch 'main' into Color-Picker
Signed-off-by: Ahmed Ashraf El-Gendy <108876019+Ahmed-Elgendy25@users.noreply.github.com>
This commit is contained in:
commit
a42b4f389a
|
|
@ -24,6 +24,8 @@ jobs:
|
|||
outputs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
tag: ${{ steps.version.outputs.tag }}
|
||||
major_minor: ${{ steps.version.outputs.major_minor }}
|
||||
is_prerelease: ${{ steps.version.outputs.is_prerelease }}
|
||||
steps:
|
||||
- name: Extract version
|
||||
id: version
|
||||
|
|
@ -34,11 +36,27 @@ jobs:
|
|||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
echo "tag=" >> "$GITHUB_OUTPUT"
|
||||
echo "version=manual-${GITHUB_SHA::7}" >> "$GITHUB_OUTPUT"
|
||||
echo "major_minor=" >> "$GITHUB_OUTPUT"
|
||||
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ ! "$REF_NAME" =~ ^v([0-9]+\.[0-9]+\.[0-9]+)(-rc\.[0-9]+)?$ ]]; then
|
||||
echo "Unsupported release tag format: $REF_NAME" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
version="${BASH_REMATCH[1]}${BASH_REMATCH[2]}"
|
||||
major_minor="${BASH_REMATCH[1]%.*}"
|
||||
is_prerelease=false
|
||||
if [ -n "${BASH_REMATCH[2]}" ]; then
|
||||
is_prerelease=true
|
||||
fi
|
||||
|
||||
echo "tag=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
echo "version=${REF_NAME#v}" >> "$GITHUB_OUTPUT"
|
||||
echo "version=${version}" >> "$GITHUB_OUTPUT"
|
||||
echo "major_minor=${major_minor}" >> "$GITHUB_OUTPUT"
|
||||
echo "is_prerelease=${is_prerelease}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-frontend:
|
||||
name: Build Frontend
|
||||
|
|
@ -226,6 +244,7 @@ jobs:
|
|||
tag_name: ${{ needs.prepare.outputs.tag }}
|
||||
name: ${{ needs.prepare.outputs.tag }}
|
||||
generate_release_notes: true
|
||||
prerelease: ${{ needs.prepare.outputs.is_prerelease == 'true' }}
|
||||
files: artifacts/*
|
||||
|
||||
build-push:
|
||||
|
|
@ -301,7 +320,7 @@ jobs:
|
|||
retention-days: 1
|
||||
|
||||
merge-images:
|
||||
name: Publish Stable Image Tags
|
||||
name: Publish Release Image Tags
|
||||
needs: [prepare, build-push]
|
||||
if: github.event_name != 'workflow_dispatch'
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -336,17 +355,28 @@ jobs:
|
|||
working-directory: /tmp/digests
|
||||
run: |
|
||||
version="${{ needs.prepare.outputs.version }}"
|
||||
major_minor=$(echo "$version" | cut -d. -f1,2)
|
||||
if [ "${{ needs.prepare.outputs.is_prerelease }}" = "true" ]; then
|
||||
docker buildx imagetools create \
|
||||
-t "neosmemo/memos:${version}" \
|
||||
-t "ghcr.io/usememos/memos:${version}" \
|
||||
$(printf 'neosmemo/memos@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
|
||||
docker buildx imagetools create \
|
||||
-t "neosmemo/memos:${version}" \
|
||||
-t "neosmemo/memos:${major_minor}" \
|
||||
-t "neosmemo/memos:${{ needs.prepare.outputs.major_minor }}" \
|
||||
-t "neosmemo/memos:stable" \
|
||||
-t "ghcr.io/usememos/memos:${version}" \
|
||||
-t "ghcr.io/usememos/memos:${major_minor}" \
|
||||
-t "ghcr.io/usememos/memos:${{ needs.prepare.outputs.major_minor }}" \
|
||||
-t "ghcr.io/usememos/memos:stable" \
|
||||
$(printf 'neosmemo/memos@sha256:%s ' *)
|
||||
|
||||
- name: Inspect images
|
||||
run: |
|
||||
docker buildx imagetools inspect neosmemo/memos:${{ needs.prepare.outputs.version }}
|
||||
if [ "${{ needs.prepare.outputs.is_prerelease }}" = "true" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
docker buildx imagetools inspect neosmemo/memos:stable
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
# Memo Filter Engine
|
||||
|
||||
This package houses the memo-only filter engine that turns CEL expressions into
|
||||
SQL fragments. The engine follows a three phase pipeline inspired by systems
|
||||
This package houses the memo-only filter engine that turns standard CEL syntax
|
||||
into SQL fragments for the subset of expressions supported by the memo schema.
|
||||
The engine follows a three phase pipeline inspired by systems
|
||||
such as Calcite or Prisma:
|
||||
|
||||
1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against
|
||||
the memo-specific environment declared in `schema.go`. Only fields that
|
||||
exist in the schema can surface in the filter.
|
||||
exist in the schema can surface in the filter, and non-standard legacy
|
||||
coercions are rejected.
|
||||
2. **Normalization** – the raw CEL AST is converted into an intermediate
|
||||
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
|
||||
conditions (logical operators, comparisons, list membership, etc.). This
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package filter
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -45,8 +44,6 @@ func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
|
|||
return nil, errors.New("filter expression is empty")
|
||||
}
|
||||
|
||||
filter = normalizeLegacyFilter(filter)
|
||||
|
||||
ast, issues := e.env.Compile(filter)
|
||||
if issues != nil && issues.Err() != nil {
|
||||
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
|
||||
|
|
@ -119,73 +116,3 @@ func DefaultAttachmentEngine() (*Engine, error) {
|
|||
})
|
||||
return defaultAttachmentInst, defaultAttachmentErr
|
||||
}
|
||||
|
||||
func normalizeLegacyFilter(expr string) string {
|
||||
expr = rewriteNumericLogicalOperand(expr, "&&")
|
||||
expr = rewriteNumericLogicalOperand(expr, "||")
|
||||
return expr
|
||||
}
|
||||
|
||||
func rewriteNumericLogicalOperand(expr, op string) string {
|
||||
var builder strings.Builder
|
||||
n := len(expr)
|
||||
i := 0
|
||||
var inQuote rune
|
||||
|
||||
for i < n {
|
||||
ch := expr[i]
|
||||
|
||||
if inQuote != 0 {
|
||||
builder.WriteByte(ch)
|
||||
if ch == '\\' && i+1 < n {
|
||||
builder.WriteByte(expr[i+1])
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if ch == byte(inQuote) {
|
||||
inQuote = 0
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\'' || ch == '"' {
|
||||
inQuote = rune(ch)
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr[i:], op) {
|
||||
builder.WriteString(op)
|
||||
i += len(op)
|
||||
|
||||
// Preserve whitespace following the operator.
|
||||
wsStart := i
|
||||
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
|
||||
i++
|
||||
}
|
||||
builder.WriteString(expr[wsStart:i])
|
||||
|
||||
signStart := i
|
||||
if i < n && (expr[i] == '+' || expr[i] == '-') {
|
||||
i++
|
||||
}
|
||||
for i < n && expr[i] >= '0' && expr[i] <= '9' {
|
||||
i++
|
||||
}
|
||||
if i > signStart {
|
||||
numLiteral := expr[signStart:i]
|
||||
fmt.Fprintf(&builder, "(%s != 0)", numLiteral)
|
||||
} else {
|
||||
builder.WriteString(expr[signStart:i])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompileAcceptsStandardTagEqualityPredicate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `tags.exists(t, t == "1231")`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCompileRejectsLegacyNumericLogicalOperand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `pinned && 1`)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to compile filter")
|
||||
}
|
||||
|
||||
func TestCompileRejectsNonBooleanTopLevelConstant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `1`)
|
||||
require.EqualError(t, err, "filter must evaluate to a boolean value")
|
||||
}
|
||||
|
|
@ -157,3 +157,10 @@ type ContainsPredicate struct {
|
|||
}
|
||||
|
||||
func (*ContainsPredicate) isPredicateExpr() {}
|
||||
|
||||
// EqualsPredicate represents t == "value".
|
||||
type EqualsPredicate struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*EqualsPredicate) isPredicateExpr() {}
|
||||
|
|
|
|||
|
|
@ -16,16 +16,10 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v, ok := val.(bool); ok {
|
||||
return &ConstantCondition{Value: v}, nil
|
||||
case int64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
case float64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
default:
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
}
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
case *exprv1.Expr_IdentExpr:
|
||||
name := v.IdentExpr.GetName()
|
||||
field, ok := schema.Field(name)
|
||||
|
|
@ -504,6 +498,8 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
|
|||
|
||||
// Handle different predicate functions
|
||||
switch predicateCall.Function {
|
||||
case "_==_":
|
||||
return buildEqualsPredicate(predicateCall, comp.IterVar)
|
||||
case "startsWith":
|
||||
return buildStartsWithPredicate(predicateCall, comp.IterVar)
|
||||
case "endsWith":
|
||||
|
|
@ -511,10 +507,44 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
|
|||
case "contains":
|
||||
return buildContainsPredicate(predicateCall, comp.IterVar)
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
|
||||
return nil, errors.Errorf(`unsupported predicate function %q in comprehension (supported: ==, startsWith, endsWith, contains)`, predicateCall.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEqualsPredicate extracts the value from t == "value".
|
||||
func buildEqualsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("equality predicate expects exactly two arguments")
|
||||
}
|
||||
|
||||
var constExpr *exprv1.Expr
|
||||
switch {
|
||||
case isIterVarExpr(call.Args[0], iterVar):
|
||||
constExpr = call.Args[1]
|
||||
case isIterVarExpr(call.Args[1], iterVar):
|
||||
constExpr = call.Args[0]
|
||||
default:
|
||||
return nil, errors.Errorf("equality predicate must compare against the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
value, err := getConstValue(constExpr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "equality argument must be a constant string")
|
||||
}
|
||||
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("equality argument must be a string")
|
||||
}
|
||||
|
||||
return &EqualsPredicate{Value: valueStr}, nil
|
||||
}
|
||||
|
||||
func isIterVarExpr(expr *exprv1.Expr, iterVar string) bool {
|
||||
target := expr.GetIdentExpr()
|
||||
return target != nil && target.GetName() == iterVar
|
||||
}
|
||||
|
||||
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
|
||||
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
// Verify the target is the iteration variable
|
||||
|
|
|
|||
|
|
@ -480,6 +480,8 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
|
|||
|
||||
// Render based on predicate type
|
||||
switch pred := cond.Predicate.(type) {
|
||||
case *EqualsPredicate:
|
||||
return r.renderTagEquals(field, pred.Value, cond.Kind)
|
||||
case *StartsWithPredicate:
|
||||
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
|
||||
case *EndsWithPredicate:
|
||||
|
|
@ -491,6 +493,22 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
|
|||
}
|
||||
}
|
||||
|
||||
// renderTagEquals generates SQL for tags.exists(t, t == "value").
|
||||
func (r *renderer) renderTagEquals(field Field, value string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, value))
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
|
||||
case DialectPostgres:
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, value)))
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
|
||||
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
|
|
|||
|
|
@ -167,7 +167,8 @@ message InstanceSetting {
|
|||
|
||||
// Metadata for a tag.
|
||||
message TagMetadata {
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
google.type.Color background_color = 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -759,7 +759,8 @@ func (x *InstanceSetting_MemoRelatedSetting) GetReactions() []string {
|
|||
// Metadata for a tag.
|
||||
type InstanceSetting_TagMetadata struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
BackgroundColor *color.Color `protobuf:"bytes,1,opt,name=background_color,json=backgroundColor,proto3" json:"background_color,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
|
|||
|
|
@ -2396,7 +2396,12 @@ components:
|
|||
backgroundColor:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/Color'
|
||||
description: Background color for the tag label.
|
||||
description: |-
|
||||
Optional background color for the tag label.
|
||||
When unset, the default tag color is used.
|
||||
blurContent:
|
||||
type: boolean
|
||||
description: Whether memos with this tag should have their content blurred.
|
||||
description: Metadata for a tag.
|
||||
InstanceSetting_TagsSetting:
|
||||
type: object
|
||||
|
|
|
|||
|
|
@ -754,7 +754,8 @@ func (x *InstanceMemoRelatedSetting) GetReactions() []string {
|
|||
|
||||
type InstanceTagMetadata struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
BackgroundColor *color.Color `protobuf:"bytes,1,opt,name=background_color,json=backgroundColor,proto3" json:"background_color,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
|
|||
|
|
@ -111,7 +111,8 @@ message InstanceMemoRelatedSetting {
|
|||
}
|
||||
|
||||
message InstanceTagMetadata {
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
google.type.Color background_color = 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -421,11 +421,10 @@ func validateInstanceTagsSetting(setting *v1pb.InstanceSetting_TagsSetting) erro
|
|||
if metadata == nil {
|
||||
return errors.Errorf("tag metadata is required for %q", tag)
|
||||
}
|
||||
if metadata.GetBackgroundColor() == nil {
|
||||
return errors.Errorf("background_color is required for %q", tag)
|
||||
}
|
||||
if err := validateInstanceColor(metadata.GetBackgroundColor()); err != nil {
|
||||
return errors.Wrapf(err, "background_color for %q", tag)
|
||||
if metadata.GetBackgroundColor() != nil {
|
||||
if err := validateInstanceColor(metadata.GetBackgroundColor()); err != nil {
|
||||
return errors.Wrapf(err, "background_color for %q", tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -35,20 +35,36 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
|||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if err := s.setMemoAttachmentsInternal(ctx, memo, request.Attachments); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.touchMemoUpdatedTimestamp(ctx, memo.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updatedMemo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to build updated memo state")
|
||||
}
|
||||
s.dispatchMemoUpdatedSideEffects(ctx, updatedMemo, parentMemo, memoMessage)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, memo *store.Memo, requestAttachments []*v1pb.Attachment) error {
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
return status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
// Delete attachments that are not in the request.
|
||||
for _, attachment := range attachments {
|
||||
found := false
|
||||
for _, requestAttachment := range request.Attachments {
|
||||
for _, requestAttachment := range requestAttachments {
|
||||
requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
return status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
if attachment.UID == requestAttachmentUID {
|
||||
found = true
|
||||
|
|
@ -60,24 +76,24 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
|||
ID: int32(attachment.ID),
|
||||
MemoID: &memo.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
return status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(request.Attachments)
|
||||
slices.Reverse(requestAttachments)
|
||||
// Update attachments' memo_id in the request.
|
||||
for index, attachment := range request.Attachments {
|
||||
for index, attachment := range requestAttachments {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
return status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
return status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if tempAttachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID)
|
||||
return status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID)
|
||||
}
|
||||
updatedTs := time.Now().Unix() + int64(index)
|
||||
if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
|
|
@ -85,11 +101,11 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
|||
MemoID: &memo.ID,
|
||||
UpdatedTs: &updatedTs,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
return status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) {
|
||||
|
|
|
|||
|
|
@ -35,18 +35,34 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
|
|||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if err := s.setMemoRelationsInternal(ctx, memo, request.Relations); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.touchMemoUpdatedTimestamp(ctx, memo.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updatedMemo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to build updated memo state")
|
||||
}
|
||||
s.dispatchMemoUpdatedSideEffects(ctx, updatedMemo, parentMemo, memoMessage)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) setMemoRelationsInternal(ctx context.Context, memo *store.Memo, relations []*v1pb.MemoRelation) error {
|
||||
referenceType := store.MemoRelationReference
|
||||
// Delete all reference relations first.
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
Type: &referenceType,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
|
||||
return status.Errorf(codes.Internal, "failed to delete memo relation")
|
||||
}
|
||||
|
||||
for _, relation := range request.Relations {
|
||||
for _, relation := range relations {
|
||||
// Ignore reflexive relations.
|
||||
if request.Name == relation.RelatedMemo.Name {
|
||||
if buildMemoName(memo.UID) == relation.RelatedMemo.Name {
|
||||
continue
|
||||
}
|
||||
// Ignore comment relations as there's no need to update a comment's relation.
|
||||
|
|
@ -56,22 +72,22 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
|
|||
}
|
||||
relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
|
||||
return status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get related memo")
|
||||
return status.Errorf(codes.Internal, "failed to get related memo")
|
||||
}
|
||||
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: convertMemoRelationTypeToStore(relation.Type),
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
|
||||
return status.Errorf(codes.Internal, "failed to upsert memo relation")
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
|
||||
|
|
|
|||
|
|
@ -469,19 +469,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
payload.Location = convertLocationToStore(request.Memo.Location)
|
||||
update.Payload = payload
|
||||
} else if path == "attachments" {
|
||||
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Attachments: request.Memo.Attachments,
|
||||
})
|
||||
if err != nil {
|
||||
if err := s.setMemoAttachmentsInternal(ctx, memo, request.Memo.Attachments); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo attachments")
|
||||
}
|
||||
} else if path == "relations" {
|
||||
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Relations: request.Memo.Relations,
|
||||
})
|
||||
if err != nil {
|
||||
if err := s.setMemoRelationsInternal(ctx, memo, request.Memo.Relations); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo relations")
|
||||
}
|
||||
}
|
||||
|
|
@ -497,44 +489,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Memo.Name,
|
||||
})
|
||||
memo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
return nil, errors.Wrap(err, "failed to build updated memo state")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
relations, err := s.loadMemoRelations(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load memo relations")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
var parentMemo *store.Memo
|
||||
if memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
// Try to dispatch webhook when memo is updated.
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: memoMessage.Name,
|
||||
Parent: memoMessage.GetParent(),
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, parentMemo),
|
||||
})
|
||||
s.dispatchMemoUpdatedSideEffects(ctx, memo, parentMemo, memoMessage)
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) touchMemoUpdatedTimestamp(ctx context.Context, memoID int32) error {
|
||||
updatedTs := time.Now().Unix()
|
||||
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memoID,
|
||||
UpdatedTs: &updatedTs,
|
||||
}); err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to update memo timestamp")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) buildUpdatedMemoState(ctx context.Context, memoID int32) (*store.Memo, *store.Memo, *v1pb.Memo, error) {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID})
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, nil, nil, errors.New("memo not found")
|
||||
}
|
||||
|
||||
memoName := buildMemoName(memo.UID)
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &memoName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to list reactions")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to list attachments")
|
||||
}
|
||||
relations, err := s.loadMemoRelations(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to load memo relations")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
||||
var parentMemo *store.Memo
|
||||
if memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
|
||||
return memo, parentMemo, memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) dispatchMemoUpdatedSideEffects(ctx context.Context, memo *store.Memo, parentMemo *store.Memo, memoMessage *v1pb.Memo) {
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: memoMessage.Name,
|
||||
Parent: memoMessage.GetParent(),
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, parentMemo),
|
||||
})
|
||||
}
|
||||
|
|
@ -187,3 +187,86 @@ func TestDeleteMemoReaction_SSEEvent(t *testing.T) {
|
|||
assert.Contains(t, payload, memo.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSetMemoAttachments_EmitsMemoUpdatedSSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "memo with attachments", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
attachment, err := svc.CreateAttachment(uctx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte("hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.SetMemoAttachments(uctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*v1pb.Attachment{
|
||||
{Name: attachment.Name},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"memo.updated"`)
|
||||
assert.Contains(t, payload, memo.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSetMemoRelations_EmitsMemoUpdatedSSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo1, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "memo one", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
memo2, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "memo two", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.SetMemoRelations(uctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: memo1.Name,
|
||||
Relations: []*v1pb.MemoRelation{
|
||||
{
|
||||
RelatedMemo: &v1pb.MemoRelation_Memo{Name: memo2.Name},
|
||||
Type: v1pb.MemoRelation_REFERENCE,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"memo.updated"`)
|
||||
assert.Contains(t, payload, memo1.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -318,6 +318,34 @@ func TestUpdateInstanceSetting(t *testing.T) {
|
|||
require.Contains(t, err.Error(), "invalid instance setting")
|
||||
})
|
||||
|
||||
t.Run("UpdateInstanceSetting - tags setting without color", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := ts.Service.UpdateInstanceSetting(ts.CreateUserContext(ctx, hostUser.ID), &v1pb.UpdateInstanceSettingRequest{
|
||||
Setting: &v1pb.InstanceSetting{
|
||||
Name: "instance/settings/TAGS",
|
||||
Value: &v1pb.InstanceSetting_TagsSetting_{
|
||||
TagsSetting: &v1pb.InstanceSetting_TagsSetting{
|
||||
Tags: map[string]*v1pb.InstanceSetting_TagMetadata{
|
||||
"spoiler": {
|
||||
BlurContent: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.GetTagsSetting())
|
||||
require.Contains(t, resp.GetTagsSetting().GetTags(), "spoiler")
|
||||
require.Nil(t, resp.GetTagsSetting().GetTags()["spoiler"].GetBackgroundColor())
|
||||
require.True(t, resp.GetTagsSetting().GetTags()["spoiler"].GetBlurContent())
|
||||
})
|
||||
|
||||
t.Run("UpdateInstanceSetting - notification setting password is write-only", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestDeleteMemoShare_VerifiesShareBelongsToMemo(t *testing.T) {
|
||||
|
|
@ -107,3 +109,107 @@ func TestGetMemoByShare_IncludesReactions(t *testing.T) {
|
|||
require.Equal(t, "👍", sharedMemo.Reactions[0].ReactionType)
|
||||
require.Equal(t, memo.Name, sharedMemo.Reactions[0].ContentId)
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForUnknownShare(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
_, err := ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: "missing-share-token",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForExpiredShare(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "share-expired")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "memo with expired share",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
expiredTs := time.Now().Add(-time.Hour).Unix()
|
||||
expiredShare, err := ts.Store.CreateMemoShare(ctx, &store.MemoShare{
|
||||
UID: "expired-share-token",
|
||||
MemoID: parseMemoIDFromNameForTest(t, ts, memo.Name),
|
||||
CreatorID: user.ID,
|
||||
ExpiresTs: &expiredTs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: expiredShare.UID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForArchivedMemo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "share-archived")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
memoResp, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "memo that will be archived",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
share, err := ts.Service.CreateMemoShare(userCtx, &apiv1.CreateMemoShareRequest{
|
||||
Parent: memoResp.Name,
|
||||
MemoShare: &apiv1.MemoShare{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoID := parseMemoIDFromNameForTest(t, ts, memoResp.Name)
|
||||
memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
archived := store.Archived
|
||||
err = ts.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
RowStatus: &archived,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
shareToken := share.Name[strings.LastIndex(share.Name, "/")+1:]
|
||||
_, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: shareToken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func parseMemoIDFromNameForTest(t *testing.T, ts *TestService, memoName string) int32 {
|
||||
t.Helper()
|
||||
|
||||
memoUID, ok := strings.CutPrefix(memoName, "memos/")
|
||||
require.True(t, ok, "memo name must start with memos/: %s", memoName)
|
||||
|
||||
memo, err := ts.Store.GetMemo(context.Background(), &store.FindMemo{UID: &memoUID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
return memo.ID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ This package implements a [Model Context Protocol (MCP)](https://modelcontextpro
|
|||
```
|
||||
POST /mcp (tool calls, initialize)
|
||||
GET /mcp (optional SSE stream for server-to-client messages)
|
||||
DELETE /mcp (optional session termination)
|
||||
```
|
||||
|
||||
Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26).
|
||||
|
|
@ -24,13 +25,22 @@ The server advertises the following MCP capabilities:
|
|||
|
||||
## Authentication
|
||||
|
||||
Every request must include a Personal Access Token (PAT):
|
||||
Public reads can be used without authentication. Personal Access Tokens (PATs) or short-lived JWT session tokens are required for:
|
||||
|
||||
- Reading non-public memos or attachments
|
||||
- Any tool that mutates data
|
||||
|
||||
When authenticating, send a Bearer token:
|
||||
|
||||
```
|
||||
Authorization: Bearer <your-PAT>
|
||||
```
|
||||
|
||||
PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests without a valid token receive `HTTP 401`.
|
||||
PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests with an invalid token receive `HTTP 401`.
|
||||
|
||||
## Origin Validation
|
||||
|
||||
For Streamable HTTP safety, requests with an `Origin` header must be same-origin with the current request host or match the configured `instance-url`. Requests without an `Origin` header, such as desktop MCP clients and CLI tools, are allowed.
|
||||
|
||||
## Tools
|
||||
|
||||
|
|
@ -38,7 +48,7 @@ PATs are long-lived tokens created in Settings → My Account → Access Tokens.
|
|||
|
||||
| Tool | Description | Required params | Optional params |
|
||||
|---|---|---|---|
|
||||
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (CEL) |
|
||||
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (supported subset of standard CEL syntax) |
|
||||
| `get_memo` | Get a single memo | `name` | — |
|
||||
| `search_memos` | Full-text search | `query` | — |
|
||||
| `create_memo` | Create a memo | `content` | `visibility` |
|
||||
|
|
@ -60,15 +70,15 @@ PATs are long-lived tokens created in Settings → My Account → Access Tokens.
|
|||
| `list_attachments` | List user's attachments | — | `page_size`, `page`, `memo` |
|
||||
| `get_attachment` | Get attachment metadata | `name` | — |
|
||||
| `delete_attachment` | Delete an attachment | `name` | — |
|
||||
| `link_attachment_to_memo` | Link attachment to memo | `name`, `memo` | — |
|
||||
| `link_attachment_to_memo` | Link attachment to a memo you own | `name`, `memo` | — |
|
||||
|
||||
### Relation Tools
|
||||
|
||||
| Tool | Description | Required params | Optional params |
|
||||
|---|---|---|---|
|
||||
| `list_memo_relations` | List relations (refs + comments) | `name` | `type` |
|
||||
| `create_memo_relation` | Create a reference relation | `name`, `related_memo` | — |
|
||||
| `delete_memo_relation` | Delete a reference relation | `name`, `related_memo` | — |
|
||||
| `create_memo_relation` | Create a reference relation from a memo you own to a memo you can read | `name`, `related_memo` | — |
|
||||
| `delete_memo_relation` | Delete a reference relation from a memo you own | `name`, `related_memo` | — |
|
||||
|
||||
### Reaction Tools
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// checkMemoAccess returns an error if the caller cannot read the memo.
|
||||
// userID == 0 means anonymous.
|
||||
func checkMemoAccess(memo *store.Memo, userID int32) error {
|
||||
if memo.RowStatus == store.Archived && memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
|
||||
switch memo.Visibility {
|
||||
case store.Protected:
|
||||
if userID == 0 {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
case store.Private:
|
||||
if memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
default:
|
||||
// store.Public and any unknown visibility: allow.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkMemoOwnership(memo *store.Memo, userID int32) error {
|
||||
if memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasMemoOwnership(memo *store.Memo, userID int32) bool {
|
||||
return memo.CreatorID == userID
|
||||
}
|
||||
|
||||
// applyVisibilityFilter restricts find to memos the caller may see.
|
||||
func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) {
|
||||
if rowStatus != nil && *rowStatus == store.Archived {
|
||||
if userID == 0 {
|
||||
impossibleCreatorID := int32(-1)
|
||||
find.CreatorID = &impossibleCreatorID
|
||||
return
|
||||
}
|
||||
find.CreatorID = &userID
|
||||
return
|
||||
}
|
||||
if userID == 0 {
|
||||
find.VisibilityList = []store.Visibility{store.Public}
|
||||
return
|
||||
}
|
||||
find.Filters = append(find.Filters, "creator_id == "+itoa32(userID)+` || visibility in ["PUBLIC", "PROTECTED"]`)
|
||||
}
|
||||
|
||||
func (s *MCPService) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment, userID int32) error {
|
||||
if attachment.CreatorID == userID {
|
||||
return nil
|
||||
}
|
||||
if attachment.MemoID == nil {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get linked memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return errors.New("linked memo not found")
|
||||
}
|
||||
return checkMemoAccess(memo, userID)
|
||||
}
|
||||
|
||||
func (s *MCPService) isAllowedOrigin(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil || originURL.Scheme == "" || originURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if sameOriginHost(originURL.Host, r.Host) {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.profile.InstanceURL == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
instanceURL, err := url.Parse(s.profile.InstanceURL)
|
||||
if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && sameOriginHost(originURL.Host, instanceURL.Host)
|
||||
}
|
||||
|
||||
func sameOriginHost(a, b string) bool {
|
||||
return strings.EqualFold(a, b)
|
||||
}
|
||||
|
||||
func itoa32(v int32) string {
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
}
|
||||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
|
|
@ -44,11 +43,22 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
|
|||
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
|
||||
|
||||
mcpGroup := echoServer.Group("")
|
||||
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
}))
|
||||
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) error {
|
||||
if !s.isAllowedOrigin(c.Request()) {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"message": "invalid origin"})
|
||||
}
|
||||
if origin := c.Request().Header.Get("Origin"); origin != "" {
|
||||
headers := c.Response().Header()
|
||||
headers.Set("Vary", "Origin")
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID")
|
||||
headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
||||
if c.Request().Method == http.MethodOptions {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
result := s.authenticator.Authenticate(c.Request().Context(), authHeader)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,275 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
teststore "github.com/usememos/memos/store/test"
|
||||
)
|
||||
|
||||
type testMCPService struct {
|
||||
service *MCPService
|
||||
store *store.Store
|
||||
}
|
||||
|
||||
func newTestMCPService(t *testing.T) *testMCPService {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
stores := teststore.NewTestingStore(ctx, t)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, stores.Close())
|
||||
})
|
||||
|
||||
svc := NewMCPService(&profile.Profile{
|
||||
Driver: "sqlite",
|
||||
InstanceURL: "https://notes.example.com",
|
||||
}, stores, "test-secret")
|
||||
return &testMCPService{
|
||||
service: svc,
|
||||
store: stores,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testMCPService) createUser(t *testing.T, username string) *store.User {
|
||||
t.Helper()
|
||||
|
||||
user, err := s.store.CreateUser(context.Background(), &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleUser,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return user
|
||||
}
|
||||
|
||||
func (s *testMCPService) createMemo(t *testing.T, creatorID int32, visibility store.Visibility, content string) *store.Memo {
|
||||
t.Helper()
|
||||
|
||||
memo, err := s.store.CreateMemo(context.Background(), &store.Memo{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: creatorID,
|
||||
RowStatus: store.Normal,
|
||||
Visibility: visibility,
|
||||
Content: content,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return memo
|
||||
}
|
||||
|
||||
func (s *testMCPService) archiveMemo(t *testing.T, memoID int32) {
|
||||
t.Helper()
|
||||
|
||||
rowStatus := store.Archived
|
||||
require.NoError(t, s.store.UpdateMemo(context.Background(), &store.UpdateMemo{
|
||||
ID: memoID,
|
||||
RowStatus: &rowStatus,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *testMCPService) createAttachment(t *testing.T, creatorID int32, memoID *int32) *store.Attachment {
|
||||
t.Helper()
|
||||
|
||||
attachment, err := s.store.CreateAttachment(context.Background(), &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: creatorID,
|
||||
Filename: "note.txt",
|
||||
Type: "text/plain",
|
||||
Size: 4,
|
||||
StorageType: storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED,
|
||||
Reference: "db://attachment/note.txt",
|
||||
MemoID: memoID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return attachment
|
||||
}
|
||||
|
||||
func withUser(ctx context.Context, userID int32) context.Context {
|
||||
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
||||
}
|
||||
|
||||
func toolRequest(name string, arguments map[string]any) mcp.CallToolRequest {
|
||||
return mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func firstText(t *testing.T, result *mcp.CallToolResult) string {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, result.Content)
|
||||
text, ok := result.Content[0].(mcp.TextContent)
|
||||
require.True(t, ok)
|
||||
return text.Text
|
||||
}
|
||||
|
||||
func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
other := ts.createUser(t, "other")
|
||||
|
||||
memo := ts.createMemo(t, owner.ID, store.Public, "archived")
|
||||
ts.archiveMemo(t, memo.ID)
|
||||
|
||||
ctx := withUser(context.Background(), other.ID)
|
||||
result, err := ts.service.handleGetMemo(ctx, toolRequest("get_memo", map[string]any{
|
||||
"name": "memos/" + memo.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.IsError)
|
||||
require.Contains(t, firstText(t, result), "permission denied")
|
||||
|
||||
_, err = ts.service.handleReadMemoResource(ctx, mcp.ReadResourceRequest{
|
||||
Params: mcp.ReadResourceParams{
|
||||
URI: "memo://memos/" + memo.UID,
|
||||
},
|
||||
})
|
||||
require.ErrorContains(t, err, "permission denied")
|
||||
}
|
||||
|
||||
func TestHandleListMemosArchivedOnlyReturnsCreatorMemos(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
other := ts.createUser(t, "other")
|
||||
|
||||
ownerMemo := ts.createMemo(t, owner.ID, store.Public, "owner archived")
|
||||
ts.archiveMemo(t, ownerMemo.ID)
|
||||
otherMemo := ts.createMemo(t, other.ID, store.Public, "other archived")
|
||||
ts.archiveMemo(t, otherMemo.ID)
|
||||
|
||||
result, err := ts.service.handleListMemos(withUser(context.Background(), owner.ID), toolRequest("list_memos", map[string]any{
|
||||
"state": "ARCHIVED",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.IsError)
|
||||
|
||||
var payload struct {
|
||||
Memos []memoJSON `json:"memos"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &payload))
|
||||
require.Len(t, payload.Memos, 1)
|
||||
require.Equal(t, "memos/"+ownerMemo.UID, payload.Memos[0].Name)
|
||||
|
||||
anonResult, err := ts.service.handleListMemos(context.Background(), toolRequest("list_memos", map[string]any{
|
||||
"state": "ARCHIVED",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, json.Unmarshal([]byte(firstText(t, anonResult)), &payload))
|
||||
require.Empty(t, payload.Memos)
|
||||
}
|
||||
|
||||
func TestHandleListMemoRelationsFiltersUnreadableTargets(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
privateUser := ts.createUser(t, "private-user")
|
||||
publicUser := ts.createUser(t, "public-user")
|
||||
|
||||
source := ts.createMemo(t, owner.ID, store.Public, "source")
|
||||
privateTarget := ts.createMemo(t, privateUser.ID, store.Private, "private")
|
||||
publicTarget := ts.createMemo(t, publicUser.ID, store.Public, "public")
|
||||
|
||||
_, err := ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
||||
MemoID: source.ID,
|
||||
RelatedMemoID: privateTarget.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
||||
MemoID: source.ID,
|
||||
RelatedMemoID: publicTarget.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
||||
"name": "memos/" + source.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.IsError)
|
||||
|
||||
var relations []relationJSON
|
||||
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &relations))
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, "memos/"+publicTarget.UID, relations[0].RelatedMemo)
|
||||
|
||||
denied, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
||||
"name": "memos/" + privateTarget.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, denied.IsError)
|
||||
require.Contains(t, firstText(t, denied), "permission denied")
|
||||
}
|
||||
|
||||
func TestHandleLinkAttachmentToMemoRequiresMemoOwnership(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
attachmentOwner := ts.createUser(t, "attachment-owner")
|
||||
memoOwner := ts.createUser(t, "memo-owner")
|
||||
|
||||
attachment := ts.createAttachment(t, attachmentOwner.ID, nil)
|
||||
memo := ts.createMemo(t, memoOwner.ID, store.Public, "target")
|
||||
|
||||
result, err := ts.service.handleLinkAttachmentToMemo(withUser(context.Background(), attachmentOwner.ID), toolRequest("link_attachment_to_memo", map[string]any{
|
||||
"name": "attachments/" + attachment.UID,
|
||||
"memo": "memos/" + memo.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.IsError)
|
||||
require.Contains(t, firstText(t, result), "permission denied")
|
||||
}
|
||||
|
||||
func TestHandleGetAttachmentDeniesArchivedLinkedMemoToNonCreator(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
other := ts.createUser(t, "other")
|
||||
|
||||
memo := ts.createMemo(t, owner.ID, store.Public, "memo")
|
||||
ts.archiveMemo(t, memo.ID)
|
||||
attachment := ts.createAttachment(t, owner.ID, &memo.ID)
|
||||
|
||||
result, err := ts.service.handleGetAttachment(withUser(context.Background(), other.ID), toolRequest("get_attachment", map[string]any{
|
||||
"name": "attachments/" + attachment.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.IsError)
|
||||
require.Contains(t, firstText(t, result), "permission denied")
|
||||
}
|
||||
|
||||
func TestIsAllowedOrigin(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
|
||||
t.Run("allow missing origin", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||
require.True(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
|
||||
t.Run("allow same origin as request host", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||
req.Header.Set("Origin", "http://localhost:5230")
|
||||
require.True(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
|
||||
t.Run("allow configured instance origin", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://127.0.0.1:5230/mcp", nil)
|
||||
req.Host = "127.0.0.1:5230"
|
||||
req.Header.Set("Origin", "https://notes.example.com")
|
||||
require.True(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
|
||||
t.Run("reject cross origin", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||
req.Header.Set("Origin", "https://evil.example.com")
|
||||
require.False(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
}
|
||||
|
|
@ -216,21 +216,8 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe
|
|||
return mcp.NewToolResultError("attachment not found"), nil
|
||||
}
|
||||
|
||||
// Check access: creator can always access; linked memo visibility applies otherwise.
|
||||
if attachment.CreatorID != userID {
|
||||
if attachment.MemoID != nil {
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get linked memo: %v", err)), nil
|
||||
}
|
||||
if memo != nil {
|
||||
if err := checkMemoAccess(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
}
|
||||
if err := s.checkAttachmentAccess(ctx, attachment, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
result, err := storeAttachmentToJSON(ctx, s.store, attachment)
|
||||
|
|
@ -302,6 +289,9 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if err := checkMemoOwnership(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
ID: attachment.ID,
|
||||
|
|
|
|||
|
|
@ -168,33 +168,6 @@ func storeMemoToJSONWithUsernames(m *store.Memo, usernamesByID map[int32]string)
|
|||
return j, nil
|
||||
}
|
||||
|
||||
// checkMemoAccess returns an error if the caller cannot read memo.
|
||||
// userID == 0 means anonymous.
|
||||
func checkMemoAccess(memo *store.Memo, userID int32) error {
|
||||
switch memo.Visibility {
|
||||
case store.Protected:
|
||||
if userID == 0 {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
case store.Private:
|
||||
if memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
default:
|
||||
// store.Public and any unknown visibility: allow
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyVisibilityFilter restricts find to memos the caller may see.
|
||||
func applyVisibilityFilter(find *store.FindMemo, userID int32) {
|
||||
if userID == 0 {
|
||||
find.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID))
|
||||
}
|
||||
}
|
||||
|
||||
// parseMemoUID extracts the UID from a "memos/<uid>" resource name.
|
||||
func parseMemoUID(name string) (string, error) {
|
||||
uid, ok := strings.CutPrefix(name, "memos/")
|
||||
|
|
@ -250,7 +223,7 @@ func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
|
|||
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
|
||||
),
|
||||
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
|
||||
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
|
||||
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
|
||||
), s.handleListMemos)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("get_memo",
|
||||
|
|
@ -337,7 +310,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
|
|||
Offset: &offset,
|
||||
OrderByPinned: req.GetBool("order_by_pinned", false),
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
applyVisibilityFilter(find, userID, rowStatus)
|
||||
if filter := req.GetString("filter", ""); filter != "" {
|
||||
find.Filters = append(find.Filters, filter)
|
||||
}
|
||||
|
|
@ -465,8 +438,8 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if memo.CreatorID != userID {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
if err := checkMemoOwnership(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
update := &store.UpdateMemo{ID: memo.ID}
|
||||
|
|
@ -533,8 +506,8 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if memo.CreatorID != userID {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
if err := checkMemoOwnership(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
|
||||
|
|
@ -561,7 +534,7 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ
|
|||
Offset: &zero,
|
||||
Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)},
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
applyVisibilityFilter(find, userID, find.RowStatus)
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
|
|
@ -40,6 +41,8 @@ func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) {
|
|||
}
|
||||
|
||||
func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
uid, err := parseMemoUID(req.GetString("name", ""))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
|
|
@ -52,6 +55,9 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
find := &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memo.ID},
|
||||
|
|
@ -85,21 +91,24 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo
|
|||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memos: %v", err)), nil
|
||||
}
|
||||
uidByID := make(map[int32]string, len(memos))
|
||||
memoByID := make(map[int32]*store.Memo, len(memos))
|
||||
for _, m := range memos {
|
||||
uidByID[m.ID] = m.UID
|
||||
memoByID[m.ID] = m
|
||||
}
|
||||
|
||||
results := make([]relationJSON, 0, len(relations))
|
||||
for _, r := range relations {
|
||||
memoUID, ok1 := uidByID[r.MemoID]
|
||||
relatedUID, ok2 := uidByID[r.RelatedMemoID]
|
||||
srcMemo, ok1 := memoByID[r.MemoID]
|
||||
relatedMemo, ok2 := memoByID[r.RelatedMemoID]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
if checkMemoAccess(srcMemo, userID) != nil || checkMemoAccess(relatedMemo, userID) != nil {
|
||||
continue
|
||||
}
|
||||
results = append(results, relationJSON{
|
||||
Memo: "memos/" + memoUID,
|
||||
RelatedMemo: "memos/" + relatedUID,
|
||||
Memo: "memos/" + srcMemo.UID,
|
||||
RelatedMemo: "memos/" + relatedMemo.UID,
|
||||
Type: string(r.Type),
|
||||
})
|
||||
}
|
||||
|
|
@ -133,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if srcMemo == nil {
|
||||
return mcp.NewToolResultError("source memo not found"), nil
|
||||
}
|
||||
if srcMemo.CreatorID != userID {
|
||||
if !hasMemoOwnership(srcMemo, userID) {
|
||||
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
|
||||
}
|
||||
|
||||
|
|
@ -144,6 +153,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if dstMemo == nil {
|
||||
return mcp.NewToolResultError("related memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(dstMemo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: srcMemo.ID,
|
||||
|
|
@ -187,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if srcMemo == nil {
|
||||
return mcp.NewToolResultError("source memo not found"), nil
|
||||
}
|
||||
if srcMemo.CreatorID != userID {
|
||||
if !hasMemoOwnership(srcMemo, userID) {
|
||||
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
|
||||
}
|
||||
|
||||
|
|
@ -198,6 +210,9 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if dstMemo == nil {
|
||||
return mcp.NewToolResultError("related memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(dstMemo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
refType := store.MemoRelationReference
|
||||
if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest)
|
|||
ExcludeContent: true,
|
||||
RowStatus: &rowStatus,
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
applyVisibilityFilter(find, userID, find.RowStatus)
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -12,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||
|
|
@ -20,7 +23,6 @@ import (
|
|||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
// Database drivers for connection verification.
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
|
@ -31,6 +33,9 @@ const (
|
|||
// Memos container settings for migration testing.
|
||||
MemosDockerImage = "neosmemo/memos"
|
||||
StableMemosVersion = "stable" // Always points to the latest stable release
|
||||
|
||||
mysqlNetworkAlias = "memos-mysql"
|
||||
postgresNetworkAlias = "memos-postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -62,12 +67,23 @@ func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error)
|
|||
return testDockerNetwork.Load(), networkErr
|
||||
}
|
||||
|
||||
func requireTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create test network")
|
||||
}
|
||||
if nw == nil {
|
||||
return nil, errors.New("test network is unavailable")
|
||||
}
|
||||
return nw, nil
|
||||
}
|
||||
|
||||
// GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test.
|
||||
func GetMySQLDSN(t *testing.T) string {
|
||||
ctx := context.Background()
|
||||
|
||||
mysqlOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
|
@ -86,7 +102,7 @@ func GetMySQLDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("3306/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
network.WithNetwork([]string{mysqlNetworkAlias}, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start MySQL container: %v", err)
|
||||
|
|
@ -167,7 +183,7 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
ctx := context.Background()
|
||||
|
||||
postgresOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
|
@ -183,7 +199,7 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("5432/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
network.WithNetwork([]string{postgresNetworkAlias}, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start PostgreSQL container: %v", err)
|
||||
|
|
@ -264,6 +280,11 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
"MEMOS_MODE": "prod",
|
||||
}
|
||||
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []testcontainers.ContainerCustomizer
|
||||
|
||||
switch cfg.Driver {
|
||||
|
|
@ -272,6 +293,12 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos"))
|
||||
}))
|
||||
case "mysql", "postgres":
|
||||
if cfg.DSN == "" {
|
||||
return nil, errors.Errorf("dsn is required for %s migration testing", cfg.Driver)
|
||||
}
|
||||
env["MEMOS_DRIVER"] = cfg.Driver
|
||||
env["MEMOS_DSN"] = cfg.DSN
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver)
|
||||
}
|
||||
|
|
@ -303,6 +330,7 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
}
|
||||
|
||||
// Apply options
|
||||
opts = append(opts, network.WithNetwork(nil, nw))
|
||||
for _, opt := range opts {
|
||||
if err := opt.Customize(&genericReq); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to apply container option")
|
||||
|
|
@ -316,3 +344,27 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
|
||||
return ctr, nil
|
||||
}
|
||||
|
||||
func getContainerDSN(driver, hostDSN string) (string, error) {
|
||||
switch driver {
|
||||
case "mysql":
|
||||
cfg, err := mysqldriver.ParseDSN(hostDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse mysql dsn")
|
||||
}
|
||||
cfg.Net = "tcp"
|
||||
cfg.Addr = net.JoinHostPort(mysqlNetworkAlias, "3306")
|
||||
return cfg.FormatDSN(), nil
|
||||
case "postgres":
|
||||
u, err := url.Parse(hostDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse postgres dsn")
|
||||
}
|
||||
u.Host = net.JoinHostPort(postgresNetworkAlias, "5432")
|
||||
return u.String(), nil
|
||||
case "sqlite":
|
||||
return hostDSN, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported driver for container dsn: %s", driver)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -257,6 +257,34 @@ func TestInstanceSettingTagsSetting(t *testing.T) {
|
|||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingTagsSettingWithoutColor(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_TAGS,
|
||||
Value: &storepb.InstanceSetting_TagsSetting{
|
||||
TagsSetting: &storepb.InstanceTagsSetting{
|
||||
Tags: map[string]*storepb.InstanceTagMetadata{
|
||||
"spoiler": {
|
||||
BlurContent: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tagsSetting, err := ts.GetInstanceTagsSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, tagsSetting.Tags, "spoiler")
|
||||
require.Nil(t, tagsSetting.Tags["spoiler"].GetBackgroundColor())
|
||||
require.True(t, tagsSetting.Tags["spoiler"].GetBlurContent())
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingNotificationSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
|
|
|||
|
|
@ -730,6 +730,31 @@ func TestMemoFilterTagsExistsContains(t *testing.T) {
|
|||
require.Len(t, memos, 1, "Should find 1 non-todo memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-1231", tc.User.ID).
|
||||
Content("Memo with exact numeric tag").
|
||||
Tags("1231", "project"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-1231-suffix", tc.User.ID).
|
||||
Content("Memo with related tag").
|
||||
Tags("tag/1231", "other"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-other", tc.User.ID).
|
||||
Content("Memo with different tag").
|
||||
Tags("9999"))
|
||||
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t == "1231")`)
|
||||
require.Len(t, memos, 1, "Should find only the memo with exact matching tag")
|
||||
require.Equal(t, "memo-1231", memos[0].UID)
|
||||
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t == "1231")`)
|
||||
require.Len(t, memos, 2, "Should exclude only the memo with exact matching tag")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,274 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestMigrationFromV0262PreservesLegacyData(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping container-based upgrade test in short mode")
|
||||
}
|
||||
if os.Getenv("SKIP_CONTAINER_TESTS") == "1" {
|
||||
t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
driver := getDriverFromEnv()
|
||||
|
||||
cfg, hostDSN := prepareV0262MigrationTest(t, driver)
|
||||
t.Logf("Starting Memos %s container for %s schema bootstrap...", cfg.Version, driver)
|
||||
container, err := StartMemosContainer(ctx, cfg)
|
||||
require.NoError(t, err, "failed to start v0.26.2 memos container")
|
||||
t.Cleanup(func() {
|
||||
if container != nil {
|
||||
_ = container.Terminate(ctx)
|
||||
}
|
||||
})
|
||||
|
||||
legacyStore := NewTestingStoreWithDSN(ctx, t, driver, hostDSN)
|
||||
require.Eventually(t, func() bool {
|
||||
setting, err := legacyStore.GetInstanceBasicSetting(ctx)
|
||||
return err == nil && setting != nil && setting.SchemaVersion != ""
|
||||
}, 45*time.Second, 500*time.Millisecond, "legacy schema should be initialized by old container")
|
||||
|
||||
settingBeforeSeed, err := legacyStore.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Legacy schema version before migration: %s", settingBeforeSeed.SchemaVersion)
|
||||
|
||||
err = container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to stop v0.26.2 memos container")
|
||||
container = nil
|
||||
|
||||
db := openMigrationSQLDB(t, driver, hostDSN)
|
||||
defer db.Close()
|
||||
|
||||
seedLegacyMigrationData(ctx, t, driver, db)
|
||||
|
||||
count, err := countSystemSetting(ctx, db, "STORAGE")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, count, "v0.26.2 database should not have a STORAGE setting before migration")
|
||||
|
||||
ts := NewTestingStoreWithDSN(ctx, t, driver, hostDSN)
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "migration from v0.26.2 should succeed for %s", driver)
|
||||
|
||||
currentVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
currentSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentVersion, currentSetting.SchemaVersion, "schema version should be updated")
|
||||
|
||||
storageSetting, err := ts.GetInstanceStorageSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, storepb.InstanceStorageSetting_DATABASE, storageSetting.StorageType, "existing installs should stay on DATABASE storage")
|
||||
|
||||
idps, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idps, 2)
|
||||
idpUIDsByName := map[string]string{}
|
||||
for _, idp := range idps {
|
||||
idpUIDsByName[idp.Name] = idp.Uid
|
||||
}
|
||||
require.Equal(t, "00000191", idpUIDsByName["Legacy Google"])
|
||||
require.Equal(t, "00000192", idpUIDsByName["Legacy GitHub"])
|
||||
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.NotNil(t, inboxes[0].Message)
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inboxes[0].Message.Type)
|
||||
require.Equal(t, int32(102), inboxes[0].Message.GetMemoComment().MemoId)
|
||||
require.Equal(t, int32(101), inboxes[0].Message.GetMemoComment().RelatedMemoId)
|
||||
|
||||
activityExists, err := tableExists(ctx, db, driver, "activity")
|
||||
require.NoError(t, err)
|
||||
require.False(t, activityExists, "activity table should be removed after migration")
|
||||
|
||||
memoShareExists, err := tableExists(ctx, db, driver, "memo_share")
|
||||
require.NoError(t, err)
|
||||
require.True(t, memoShareExists, "memo_share table should be created")
|
||||
|
||||
share, err := ts.CreateMemoShare(ctx, &store.MemoShare{
|
||||
UID: "post-upgrade-share",
|
||||
MemoID: 101,
|
||||
CreatorID: 11,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "post-upgrade-share", share.UID)
|
||||
|
||||
postUpgradeUser, err := createTestingUserWithRole(ctx, ts, "postupgrade", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
postUpgradeMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "post-upgrade-memo-v0262",
|
||||
CreatorID: postUpgradeUser.ID,
|
||||
Content: "created after v0.26.2 migration",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "created after v0.26.2 migration", postUpgradeMemo.Content)
|
||||
}
|
||||
|
||||
func prepareV0262MigrationTest(t *testing.T, driver string) (MemosContainerConfig, string) {
|
||||
t.Helper()
|
||||
|
||||
const version = "0.26.2"
|
||||
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
dataDir := t.TempDir()
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DataDir: dataDir,
|
||||
}, fmt.Sprintf("%s/memos_prod.db", dataDir)
|
||||
case "mysql":
|
||||
hostDSN := GetMySQLDSN(t)
|
||||
containerDSN, err := getContainerDSN(driver, hostDSN)
|
||||
require.NoError(t, err)
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DSN: containerDSN,
|
||||
}, hostDSN
|
||||
case "postgres":
|
||||
hostDSN := GetPostgresDSN(t)
|
||||
containerDSN, err := getContainerDSN(driver, hostDSN)
|
||||
require.NoError(t, err)
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DSN: containerDSN,
|
||||
}, hostDSN
|
||||
default:
|
||||
t.Fatalf("unsupported driver: %s", driver)
|
||||
return MemosContainerConfig{}, ""
|
||||
}
|
||||
}
|
||||
|
||||
func openMigrationSQLDB(t *testing.T, driver, dsn string) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open(driver, dsn)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Ping())
|
||||
return db
|
||||
}
|
||||
|
||||
func seedLegacyMigrationData(ctx context.Context, t *testing.T, driver string, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
execMigrationSQL(t, db, legacyInsertUserSQL(driver, 11, "owner"))
|
||||
execMigrationSQL(t, db, legacyInsertUserSQL(driver, 12, "commenter"))
|
||||
execMigrationSQL(t, db, legacyInsertMemoSQL(101, 11, "legacy-parent", "parent memo"))
|
||||
execMigrationSQL(t, db, legacyInsertMemoSQL(102, 12, "legacy-comment", "comment memo"))
|
||||
execMigrationSQL(t, db, legacyInsertActivitySQL(201, 12))
|
||||
execMigrationSQL(t, db, legacyInsertInboxSQL(301, 12, 11, 201))
|
||||
execMigrationSQL(t, db, legacyInsertIDPSQL(401, "Legacy Google"))
|
||||
execMigrationSQL(t, db, legacyInsertIDPSQL(402, "Legacy GitHub"))
|
||||
|
||||
var message string
|
||||
err := db.QueryRowContext(ctx, "SELECT message FROM inbox WHERE id = 301").Scan(&message)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, message, "\"activityId\":201")
|
||||
require.NotContains(t, message, "\"memoComment\"")
|
||||
}
|
||||
|
||||
func execMigrationSQL(t *testing.T, db *sql.DB, query string) {
|
||||
t.Helper()
|
||||
_, err := db.Exec(query)
|
||||
require.NoError(t, err, "failed to execute SQL: %s", query)
|
||||
}
|
||||
|
||||
func legacyInsertUserSQL(driver string, id int, username string) string {
|
||||
table := "user"
|
||||
switch driver {
|
||||
case "mysql":
|
||||
table = "`user`"
|
||||
case "postgres", "sqlite":
|
||||
table = `"user"`
|
||||
default:
|
||||
// Keep the unquoted fallback for unknown test drivers.
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO %s (id, username, role, email, nickname, password_hash, avatar_url, description) VALUES (%d, '%s', 'USER', '%s@example.com', '%s', 'legacy-hash', '', 'legacy user')",
|
||||
table, id, username, username, username,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertMemoSQL(id, creatorID int, uid, content string) string {
|
||||
payload := "{}"
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO memo (id, uid, creator_id, content, visibility, payload) VALUES (%d, '%s', %d, '%s', 'PRIVATE', '%s')",
|
||||
id, uid, creatorID, content, payload,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertActivitySQL(id, creatorID int) string {
|
||||
payload := `{"memoComment":{"memoId":102,"relatedMemoId":101}}`
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO activity (id, creator_id, type, level, payload) VALUES (%d, %d, 'MEMO_COMMENT', 'INFO', '%s')",
|
||||
id, creatorID, payload,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertInboxSQL(id, senderID, receiverID, activityID int) string {
|
||||
message := fmt.Sprintf(`{"type":"MEMO_COMMENT","activityId":%d}`, activityID)
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO inbox (id, sender_id, receiver_id, status, message) VALUES (%d, %d, %d, 'UNREAD', '%s')",
|
||||
id, senderID, receiverID, message,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertIDPSQL(id int, name string) string {
|
||||
config := `{"clientId":"legacy-client","clientSecret":"legacy-secret","authUrl":"https://example.com/auth","tokenUrl":"https://example.com/token","userInfoUrl":"https://example.com/userinfo"}`
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO idp (id, name, type, identifier_filter, config) VALUES (%d, '%s', 'OAUTH2', '', '%s')",
|
||||
id, name, config,
|
||||
)
|
||||
}
|
||||
|
||||
func countSystemSetting(ctx context.Context, db *sql.DB, name string) (int, error) {
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = ?", name).Scan(&count)
|
||||
if err == nil {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = $1", name).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func tableExists(ctx context.Context, db *sql.DB, driver, table string) (bool, error) {
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
var name string
|
||||
err := db.QueryRowContext(ctx, "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
case "mysql":
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", table).Scan(&count)
|
||||
return count > 0, err
|
||||
case "postgres":
|
||||
var regclass sql.NullString
|
||||
err := db.QueryRowContext(ctx, "SELECT to_regclass($1)", "public."+table).Scan(®class)
|
||||
return regclass.Valid && strings.EqualFold(regclass.String, table), err
|
||||
default:
|
||||
return false, errors.Errorf("unsupported driver: %s", driver)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,25 +1,49 @@
|
|||
import { debounce } from "lodash-es";
|
||||
|
||||
export const CACHE_DEBOUNCE_DELAY = 500;
|
||||
|
||||
const pendingSaves = new Map<string, ReturnType<typeof window.setTimeout>>();
|
||||
|
||||
export const cacheService = {
|
||||
key: (username: string, cacheKey?: string): string => {
|
||||
return `${username}-${cacheKey || ""}`;
|
||||
},
|
||||
|
||||
save: debounce((key: string, content: string) => {
|
||||
if (content.trim()) {
|
||||
localStorage.setItem(key, content);
|
||||
} else {
|
||||
localStorage.removeItem(key);
|
||||
save: (key: string, content: string) => {
|
||||
const pendingSave = pendingSaves.get(key);
|
||||
if (pendingSave) {
|
||||
window.clearTimeout(pendingSave);
|
||||
}
|
||||
}, CACHE_DEBOUNCE_DELAY),
|
||||
|
||||
const timeoutId = window.setTimeout(() => {
|
||||
pendingSaves.delete(key);
|
||||
|
||||
if (content.trim()) {
|
||||
localStorage.setItem(key, content);
|
||||
} else {
|
||||
localStorage.removeItem(key);
|
||||
}
|
||||
}, CACHE_DEBOUNCE_DELAY);
|
||||
|
||||
pendingSaves.set(key, timeoutId);
|
||||
},
|
||||
|
||||
load(key: string): string {
|
||||
return localStorage.getItem(key) || "";
|
||||
},
|
||||
|
||||
clear(key: string): void {
|
||||
const pendingSave = pendingSaves.get(key);
|
||||
if (pendingSave) {
|
||||
window.clearTimeout(pendingSave);
|
||||
pendingSaves.delete(key);
|
||||
}
|
||||
|
||||
localStorage.removeItem(key);
|
||||
},
|
||||
|
||||
clearAll(): void {
|
||||
for (const timeoutId of pendingSaves.values()) {
|
||||
window.clearTimeout(timeoutId);
|
||||
}
|
||||
pendingSaves.clear();
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,42 +1,20 @@
|
|||
import { FileAudioIcon, FileIcon, PaperclipIcon } from "lucide-react";
|
||||
import { FileIcon, PaperclipIcon } from "lucide-react";
|
||||
import { useMemo } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { Attachment } from "@/types/proto/api/v1/attachment_service_pb";
|
||||
import { getAttachmentType, getAttachmentUrl } from "@/utils/attachment";
|
||||
import { formatFileSize, getFileTypeLabel } from "@/utils/format";
|
||||
import { getAttachmentUrl } from "@/utils/attachment";
|
||||
import SectionHeader from "../SectionHeader";
|
||||
import AttachmentCard from "./AttachmentCard";
|
||||
import AudioAttachmentItem from "./AudioAttachmentItem";
|
||||
import { getAttachmentMetadata, isImageAttachment, separateAttachments } from "./attachmentViewHelpers";
|
||||
|
||||
interface AttachmentListViewProps {
|
||||
attachments: Attachment[];
|
||||
onImagePreview?: (urls: string[], index: number) => void;
|
||||
}
|
||||
|
||||
const isImageAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "image/*";
|
||||
const isVideoAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "video/*";
|
||||
const isAudioAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "audio/*";
|
||||
|
||||
const separateAttachments = (attachments: Attachment[]) => {
|
||||
const visual: Attachment[] = [];
|
||||
const audio: Attachment[] = [];
|
||||
const docs: Attachment[] = [];
|
||||
|
||||
for (const attachment of attachments) {
|
||||
if (isImageAttachment(attachment) || isVideoAttachment(attachment)) {
|
||||
visual.push(attachment);
|
||||
} else if (isAudioAttachment(attachment)) {
|
||||
audio.push(attachment);
|
||||
} else {
|
||||
docs.push(attachment);
|
||||
}
|
||||
}
|
||||
|
||||
return { visual, audio, docs };
|
||||
};
|
||||
|
||||
const DocumentItem = ({ attachment }: { attachment: Attachment }) => {
|
||||
const fileTypeLabel = getFileTypeLabel(attachment.type);
|
||||
const fileSizeLabel = attachment.size ? formatFileSize(Number(attachment.size)) : undefined;
|
||||
const { fileTypeLabel, fileSizeLabel } = getAttachmentMetadata(attachment);
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-1 px-1 py-1 rounded text-xs text-muted-foreground hover:text-foreground hover:bg-accent/20 transition-colors whitespace-nowrap">
|
||||
|
|
@ -62,22 +40,6 @@ const DocumentItem = ({ attachment }: { attachment: Attachment }) => {
|
|||
);
|
||||
};
|
||||
|
||||
const AudioItem = ({ attachment }: { attachment: Attachment }) => {
|
||||
const sourceUrl = getAttachmentUrl(attachment);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1 px-1 py-1">
|
||||
<div className="flex items-center gap-1 text-xs text-muted-foreground">
|
||||
<FileAudioIcon className="w-3 h-3 shrink-0" />
|
||||
<span className="truncate" title={attachment.filename}>
|
||||
{attachment.filename}
|
||||
</span>
|
||||
</div>
|
||||
<audio src={sourceUrl} controls preload="metadata" className="w-full h-8" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface VisualItemProps {
|
||||
attachment: Attachment;
|
||||
onImageClick?: (url: string) => void;
|
||||
|
|
@ -114,9 +76,9 @@ const VisualGrid = ({ attachments, onImageClick }: { attachments: Attachment[];
|
|||
);
|
||||
|
||||
const AudioList = ({ attachments }: { attachments: Attachment[] }) => (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex flex-col gap-2">
|
||||
{attachments.map((attachment) => (
|
||||
<AudioItem key={attachment.name} attachment={attachment} />
|
||||
<AudioAttachmentItem key={attachment.name} attachment={attachment} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,178 @@
|
|||
import { FileAudioIcon, PauseIcon, PlayIcon } from "lucide-react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import type { Attachment } from "@/types/proto/api/v1/attachment_service_pb";
|
||||
import { getAttachmentUrl } from "@/utils/attachment";
|
||||
import { formatAudioTime, getAttachmentMetadata } from "./attachmentViewHelpers";
|
||||
|
||||
const AUDIO_PLAYBACK_RATES = [1, 1.5, 2] as const;
|
||||
|
||||
interface AudioProgressBarProps {
|
||||
attachment: Attachment;
|
||||
currentTime: number;
|
||||
duration: number;
|
||||
progressPercent: number;
|
||||
onSeek: (value: string) => void;
|
||||
}
|
||||
|
||||
const AudioProgressBar = ({ attachment, currentTime, duration, progressPercent, onSeek }: AudioProgressBarProps) => (
|
||||
<div className="mt-2 flex items-center gap-2.5">
|
||||
<div className="relative flex h-4 min-w-0 flex-1 items-center">
|
||||
<div className="absolute inset-x-0 h-1 rounded-full bg-muted/75" />
|
||||
<div className="absolute left-0 h-1 rounded-full bg-foreground/20" style={{ width: `${Math.min(progressPercent, 100)}%` }} />
|
||||
<input
|
||||
type="range"
|
||||
min={0}
|
||||
max={duration || 1}
|
||||
step={0.1}
|
||||
value={Math.min(currentTime, duration || 0)}
|
||||
onChange={(e) => onSeek(e.target.value)}
|
||||
aria-label={`Seek ${attachment.filename}`}
|
||||
className="relative z-10 h-4 w-full cursor-pointer appearance-none bg-transparent outline-none disabled:cursor-default
|
||||
[&::-webkit-slider-runnable-track]:h-1 [&::-webkit-slider-runnable-track]:rounded-full
|
||||
[&::-webkit-slider-runnable-track]:bg-transparent
|
||||
[&::-webkit-slider-thumb]:mt-[-3px] [&::-webkit-slider-thumb]:size-2 [&::-webkit-slider-thumb]:appearance-none
|
||||
[&::-webkit-slider-thumb]:rounded-full [&::-webkit-slider-thumb]:border [&::-webkit-slider-thumb]:border-border/50
|
||||
[&::-webkit-slider-thumb]:bg-background/95
|
||||
[&::-moz-range-track]:h-1 [&::-moz-range-track]:rounded-full [&::-moz-range-track]:bg-transparent
|
||||
[&::-moz-range-thumb]:size-2 [&::-moz-range-thumb]:rounded-full [&::-moz-range-thumb]:border
|
||||
[&::-moz-range-thumb]:border-border/50 [&::-moz-range-thumb]:bg-background/95"
|
||||
disabled={duration === 0}
|
||||
/>
|
||||
</div>
|
||||
<div className="shrink-0 text-[11px] tabular-nums text-muted-foreground">
|
||||
{formatAudioTime(currentTime)} / {duration > 0 ? formatAudioTime(duration) : "--:--"}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
const AudioAttachmentItem = ({ attachment }: { attachment: Attachment }) => {
|
||||
const sourceUrl = getAttachmentUrl(attachment);
|
||||
const audioRef = useRef<HTMLAudioElement>(null);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
const [currentTime, setCurrentTime] = useState(0);
|
||||
const [duration, setDuration] = useState(0);
|
||||
const [playbackRate, setPlaybackRate] = useState<(typeof AUDIO_PLAYBACK_RATES)[number]>(1);
|
||||
const { fileTypeLabel, fileSizeLabel } = getAttachmentMetadata(attachment);
|
||||
const progressPercent = duration > 0 ? (currentTime / duration) * 100 : 0;
|
||||
|
||||
useEffect(() => {
|
||||
if (!audioRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
audioRef.current.playbackRate = playbackRate;
|
||||
}, [playbackRate]);
|
||||
|
||||
const togglePlayback = async () => {
|
||||
const audio = audioRef.current;
|
||||
|
||||
if (!audio) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (audio.paused) {
|
||||
try {
|
||||
await audio.play();
|
||||
} catch {
|
||||
setIsPlaying(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
};
|
||||
|
||||
const handleSeek = (value: string) => {
|
||||
const audio = audioRef.current;
|
||||
const nextTime = Number(value);
|
||||
|
||||
if (!audio || Number.isNaN(nextTime)) {
|
||||
return;
|
||||
}
|
||||
|
||||
audio.currentTime = nextTime;
|
||||
setCurrentTime(nextTime);
|
||||
};
|
||||
|
||||
const handlePlaybackRateChange = () => {
|
||||
const currentRateIndex = AUDIO_PLAYBACK_RATES.findIndex((rate) => rate === playbackRate);
|
||||
const nextRate = AUDIO_PLAYBACK_RATES[(currentRateIndex + 1) % AUDIO_PLAYBACK_RATES.length];
|
||||
setPlaybackRate(nextRate);
|
||||
};
|
||||
|
||||
const handleDuration = (value: number) => {
|
||||
setDuration(Number.isFinite(value) ? value : 0);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="rounded-xl border border-border/35 bg-background/70 px-2.5 py-2.5">
|
||||
<div className="flex items-start gap-2.5">
|
||||
<div className="mt-0.5 flex size-8 shrink-0 items-center justify-center rounded-lg bg-muted/55 text-muted-foreground">
|
||||
<FileAudioIcon className="size-3.5" />
|
||||
</div>
|
||||
|
||||
<div className="flex min-w-0 flex-1 items-start justify-between gap-3">
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="truncate text-sm font-medium leading-5 text-foreground" title={attachment.filename}>
|
||||
{attachment.filename}
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center gap-x-1.5 gap-y-0.5 text-xs leading-4 text-muted-foreground">
|
||||
<span>{fileTypeLabel}</span>
|
||||
{fileSizeLabel && (
|
||||
<>
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
<span>{fileSizeLabel}</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-0.5 flex shrink-0 items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={handlePlaybackRateChange}
|
||||
className="inline-flex h-6 items-center justify-center px-1 text-[11px] font-medium text-muted-foreground transition-colors hover:text-foreground"
|
||||
aria-label={`Playback speed ${playbackRate}x for ${attachment.filename}`}
|
||||
>
|
||||
{playbackRate}x
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={togglePlayback}
|
||||
className="inline-flex size-6.5 items-center justify-center rounded-md border border-border/45 bg-background/85 text-foreground transition-colors hover:bg-muted/45"
|
||||
aria-label={isPlaying ? `Pause ${attachment.filename}` : `Play ${attachment.filename}`}
|
||||
>
|
||||
{isPlaying ? <PauseIcon className="size-3" /> : <PlayIcon className="size-3 translate-x-[0.5px]" />}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<AudioProgressBar
|
||||
attachment={attachment}
|
||||
currentTime={currentTime}
|
||||
duration={duration}
|
||||
progressPercent={progressPercent}
|
||||
onSeek={handleSeek}
|
||||
/>
|
||||
|
||||
<audio
|
||||
ref={audioRef}
|
||||
src={sourceUrl}
|
||||
preload="metadata"
|
||||
className="hidden"
|
||||
onLoadedMetadata={(e) => handleDuration(e.currentTarget.duration)}
|
||||
onDurationChange={(e) => handleDuration(e.currentTarget.duration)}
|
||||
onTimeUpdate={(e) => setCurrentTime(e.currentTarget.currentTime)}
|
||||
onPlay={() => setIsPlaying(true)}
|
||||
onPause={() => setIsPlaying(false)}
|
||||
onEnded={() => {
|
||||
setIsPlaying(false);
|
||||
setCurrentTime(0);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default AudioAttachmentItem;
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import type { Attachment } from "@/types/proto/api/v1/attachment_service_pb";
|
||||
import { getAttachmentType } from "@/utils/attachment";
|
||||
import { formatFileSize, getFileTypeLabel } from "@/utils/format";
|
||||
|
||||
export interface AttachmentGroups {
|
||||
visual: Attachment[];
|
||||
audio: Attachment[];
|
||||
docs: Attachment[];
|
||||
}
|
||||
|
||||
export interface AttachmentMetadata {
|
||||
fileTypeLabel: string;
|
||||
fileSizeLabel?: string;
|
||||
}
|
||||
|
||||
export const isImageAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "image/*";
|
||||
export const isVideoAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "video/*";
|
||||
export const isAudioAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "audio/*";
|
||||
|
||||
export const separateAttachments = (attachments: Attachment[]): AttachmentGroups => {
|
||||
const groups: AttachmentGroups = {
|
||||
visual: [],
|
||||
audio: [],
|
||||
docs: [],
|
||||
};
|
||||
|
||||
for (const attachment of attachments) {
|
||||
if (isImageAttachment(attachment) || isVideoAttachment(attachment)) {
|
||||
groups.visual.push(attachment);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isAudioAttachment(attachment)) {
|
||||
groups.audio.push(attachment);
|
||||
continue;
|
||||
}
|
||||
|
||||
groups.docs.push(attachment);
|
||||
}
|
||||
|
||||
return groups;
|
||||
};
|
||||
|
||||
export const getAttachmentMetadata = (attachment: Attachment): AttachmentMetadata => ({
|
||||
fileTypeLabel: getFileTypeLabel(attachment.type),
|
||||
fileSizeLabel: attachment.size ? formatFileSize(Number(attachment.size)) : undefined,
|
||||
});
|
||||
|
||||
export const formatAudioTime = (seconds: number): string => {
|
||||
if (!Number.isFinite(seconds) || seconds < 0) {
|
||||
return "0:00";
|
||||
}
|
||||
|
||||
const rounded = Math.floor(seconds);
|
||||
const hours = Math.floor(rounded / 3600);
|
||||
const minutes = Math.floor((rounded % 3600) / 60);
|
||||
const secs = rounded % 60;
|
||||
|
||||
if (hours > 0) {
|
||||
return `${hours}:${minutes.toString().padStart(2, "0")}:${secs.toString().padStart(2, "0")}`;
|
||||
}
|
||||
|
||||
return `${minutes}:${secs.toString().padStart(2, "0")}`;
|
||||
};
|
||||
|
|
@ -22,8 +22,7 @@ import SettingGroup from "./SettingGroup";
|
|||
import SettingSection from "./SettingSection";
|
||||
import SettingTable from "./SettingTable";
|
||||
|
||||
// Fallback to white when no color is stored.
|
||||
const tagColorToHex = (color?: { red?: number; green?: number; blue?: number }): string => colorToHex(color) ?? "#ffffff";
|
||||
const DEFAULT_TAG_COLOR = "#ffffff";
|
||||
|
||||
// Converts a CSS hex string to a google.type.Color message.
|
||||
const hexToColor = (hex: string) =>
|
||||
|
|
@ -33,24 +32,36 @@ const hexToColor = (hex: string) =>
|
|||
blue: parseInt(hex.slice(5, 7), 16) / 255,
|
||||
});
|
||||
|
||||
interface LocalTagMeta {
|
||||
color?: string;
|
||||
blur: boolean;
|
||||
}
|
||||
|
||||
const toLocalTagMeta = (meta: {
|
||||
backgroundColor?: { red?: number; green?: number; blue?: number };
|
||||
blurContent: boolean;
|
||||
}): LocalTagMeta => ({
|
||||
color: colorToHex(meta.backgroundColor),
|
||||
blur: meta.blurContent,
|
||||
});
|
||||
|
||||
const TagsSection = () => {
|
||||
const t = useTranslate();
|
||||
const { tagsSetting: originalSetting, updateSetting, fetchSetting } = useInstance();
|
||||
const { data: tagCounts = {} } = useTagCounts(false);
|
||||
|
||||
// Local state: map of tagName → hex color string for editing.
|
||||
const [localTags, setLocalTags] = useState<Record<string, string>>(() =>
|
||||
Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, tagColorToHex(meta.backgroundColor)])),
|
||||
// Local state: map of tagName → { color, blur } for editing.
|
||||
const [localTags, setLocalTags] = useState<Record<string, LocalTagMeta>>(() =>
|
||||
Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])),
|
||||
);
|
||||
const [newTagName, setNewTagName] = useState("");
|
||||
const [newTagColor, setNewTagColor] = useState("#ffffff");
|
||||
const [newTagColor, setNewTagColor] = useState<string | undefined>(undefined);
|
||||
const [newTagBlur, setNewTagBlur] = useState(false);
|
||||
|
||||
// Sync local state when the fetched setting arrives (the fetch is async and
|
||||
// completes after mount, so localTags would be empty without this sync).
|
||||
useEffect(() => {
|
||||
setLocalTags(
|
||||
Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, tagColorToHex(meta.backgroundColor)])),
|
||||
);
|
||||
setLocalTags(Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])));
|
||||
}, [originalSetting.tags]);
|
||||
|
||||
// All known tag names: union of saved entries and tags used in memos.
|
||||
|
|
@ -68,8 +79,8 @@ const TagsSection = () => {
|
|||
[localTags],
|
||||
);
|
||||
|
||||
const originalHexMap = useMemo(
|
||||
() => Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, tagColorToHex(meta.backgroundColor)])),
|
||||
const originalMetaMap = useMemo(
|
||||
() => Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])),
|
||||
[originalSetting.tags],
|
||||
);
|
||||
const hasChanges = !isEqual(localTags, originalHexMap);
|
||||
|
|
@ -78,6 +89,10 @@ const TagsSection = () => {
|
|||
setLocalTags((prev) => ({ ...prev, [tagName]: hex }));
|
||||
};
|
||||
|
||||
const handleClearColor = (tagName: string) => {
|
||||
setLocalTags((prev) => ({ ...prev, [tagName]: { ...prev[tagName], color: undefined } }));
|
||||
};
|
||||
|
||||
const handleRemoveTag = (tagName: string) => {
|
||||
setLocalTags((prev) => {
|
||||
const next = { ...prev };
|
||||
|
|
@ -99,7 +114,8 @@ const TagsSection = () => {
|
|||
}
|
||||
setLocalTags((prev) => ({ ...prev, [name]: newTagColor }));
|
||||
setNewTagName("");
|
||||
setNewTagColor("#ffffff");
|
||||
setNewTagColor(undefined);
|
||||
setNewTagBlur(false);
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
|
|
@ -107,7 +123,10 @@ const TagsSection = () => {
|
|||
const tags = Object.fromEntries(
|
||||
Object.entries(localTags).map(([name, hex]) => [
|
||||
name,
|
||||
create(InstanceSetting_TagMetadataSchema, { backgroundColor: hexToColor(hex) }),
|
||||
create(InstanceSetting_TagMetadataSchema, {
|
||||
blurContent: meta.blur,
|
||||
...(meta.color ? { backgroundColor: hexToColor(meta.color) } : {}),
|
||||
}),
|
||||
]),
|
||||
);
|
||||
await updateSetting(
|
||||
|
|
@ -144,9 +163,15 @@ const TagsSection = () => {
|
|||
<input
|
||||
type="color"
|
||||
className="w-8 h-8 cursor-pointer rounded border border-border bg-transparent p-0.5"
|
||||
value={localTags[row.name]}
|
||||
value={localTags[row.name].color ?? DEFAULT_TAG_COLOR}
|
||||
onChange={(e) => handleColorChange(row.name, e.target.value)}
|
||||
/>
|
||||
<Button variant="ghost" size="sm" onClick={() => handleClearColor(row.name)} disabled={!localTags[row.name].color}>
|
||||
{t("common.clear")}
|
||||
</Button>
|
||||
{!localTags[row.name].color && (
|
||||
<span className="text-xs text-muted-foreground">{t("setting.tags.using-default-color")}</span>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
|
|
@ -185,15 +210,28 @@ const TagsSection = () => {
|
|||
<input
|
||||
type="color"
|
||||
className="w-8 h-8 cursor-pointer rounded border border-border bg-transparent p-0.5"
|
||||
value={newTagColor}
|
||||
value={newTagColor ?? DEFAULT_TAG_COLOR}
|
||||
onChange={(e) => setNewTagColor(e.target.value)}
|
||||
/>
|
||||
<Button variant="ghost" size="sm" onClick={() => setNewTagColor(undefined)} disabled={!newTagColor}>
|
||||
{t("common.clear")}
|
||||
</Button>
|
||||
<label className="flex items-center gap-1.5 text-sm text-muted-foreground">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="w-4 h-4 cursor-pointer"
|
||||
checked={newTagBlur}
|
||||
onChange={(e) => setNewTagBlur(e.target.checked)}
|
||||
/>
|
||||
{t("setting.tags.blur-content")}
|
||||
</label>
|
||||
<Button variant="outline" onClick={handleAddTag} disabled={!newTagName.trim()}>
|
||||
<PlusIcon className="w-4 h-4 mr-1.5" />
|
||||
{t("common.add")}
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground mt-1">{t("setting.tags.tag-pattern-hint")}</p>
|
||||
{!newTagColor && <p className="text-xs text-muted-foreground">{t("setting.tags.using-default-color")}</p>}
|
||||
</SettingGroup>
|
||||
|
||||
<div className="w-full flex justify-end">
|
||||
|
|
|
|||
|
|
@ -474,14 +474,15 @@
|
|||
"tags": {
|
||||
"label": "Tags",
|
||||
"title": "Tag metadata",
|
||||
"description": "Assign display colors to tags instance-wide. Tag names are treated as anchored regex patterns.",
|
||||
"description": "Assign optional display colors to tags instance-wide, or blur matching memo content. Tag names are treated as anchored regex patterns.",
|
||||
"background-color": "Background color",
|
||||
"no-tags-configured": "No tag metadata configured.",
|
||||
"tag-name": "Tag name",
|
||||
"tag-name-placeholder": "e.g. work or project/.*",
|
||||
"tag-already-exists": "Tag already exists.",
|
||||
"tag-pattern-hint": "Tag name or regex pattern (e.g. project/.* matches all project/ tags)",
|
||||
"invalid-regex": "Invalid or unsafe regex pattern."
|
||||
"invalid-regex": "Invalid or unsafe regex pattern.",
|
||||
"using-default-color": "Using default color."
|
||||
}
|
||||
},
|
||||
"tag": {
|
||||
|
|
|
|||
|
|
@ -414,7 +414,8 @@ export const InstanceSetting_MemoRelatedSettingSchema: GenMessage<InstanceSettin
|
|||
*/
|
||||
export type InstanceSetting_TagMetadata = Message<"memos.api.v1.InstanceSetting.TagMetadata"> & {
|
||||
/**
|
||||
* Background color for the tag label.
|
||||
* Optional background color for the tag label.
|
||||
* When unset, the default tag color is used.
|
||||
*
|
||||
* @generated from field: google.type.Color background_color = 1;
|
||||
*/
|
||||
|
|
|
|||
Loading…
Reference in New Issue