diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index ee82ea22e..f563dde68 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -126,7 +126,11 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR } } - memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments) + relations, err := s.loadMemoRelations(ctx, memo) + if err != nil { + return nil, errors.Wrap(err, "failed to load memo relations") + } + memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments, relations) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -266,12 +270,19 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment) } + // RELATIONS (batch load to avoid N+1) + relationMap, err := s.batchConvertMemoRelations(ctx, memos) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to batch load memo relations") + } + for _, memo := range memos { memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID) reactions := reactionMap[memoName] attachments := attachmentMap[memo.ID] + relations := relationMap[memo.ID] - memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments) + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -327,7 +338,11 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest return nil, status.Errorf(codes.Internal, "failed to list attachments") } - memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments) + relations, err := s.loadMemoRelations(ctx, memo) + if err != nil { + return nil, errors.Wrap(err, "failed to load memo relations") + } + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -462,7 +477,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR return nil, status.Errorf(codes.Internal, "failed to list attachments") } - memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments) + relations, err := s.loadMemoRelations(ctx, memo) + if err != nil { + return nil, errors.Wrap(err, "failed to load memo relations") + } + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -521,7 +540,8 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR return nil, status.Errorf(codes.Internal, "failed to list attachments") } - if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments); err == nil { + deleteRelations, _ := s.loadMemoRelations(ctx, memo) + if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, deleteRelations); err == nil { // Try to dispatch webhook when memo is deleted. if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil { slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err)) @@ -725,13 +745,20 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment) } + // RELATIONS (batch load to avoid N+1) + relationMap, err := s.batchConvertMemoRelations(ctx, memos) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to batch load memo relations") + } + var memosResponse []*v1pb.Memo for _, m := range memos { memoName := memoIDToNameMap[m.ID] reactions := memoReactionsMap[memoName] attachments := attachmentMap[m.ID] + relations := relationMap[m.ID] - memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments) + memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments, relations) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } diff --git a/server/router/api/v1/memo_service_converter.go b/server/router/api/v1/memo_service_converter.go index 325c1a0b2..fb585c1dc 100644 --- a/server/router/api/v1/memo_service_converter.go +++ b/server/router/api/v1/memo_service_converter.go @@ -13,7 +13,7 @@ import ( "github.com/usememos/memos/store" ) -func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment) (*v1pb.Memo, error) { +func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation) (*v1pb.Memo, error) { displayTs := memo.CreatedTs instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx) if err != nil { @@ -47,20 +47,18 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem } memoMessage.Reactions = []*v1pb.Reaction{} - for _, reaction := range reactions { reactionResponse := convertReactionFromStore(reaction) memoMessage.Reactions = append(memoMessage.Reactions, reactionResponse) } - listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name}) - if err != nil { - return nil, errors.Wrap(err, "failed to list memo relations") + if relations != nil { + memoMessage.Relations = relations + } else { + memoMessage.Relations = []*v1pb.MemoRelation{} } - memoMessage.Relations = listMemoRelationsResponse.Relations memoMessage.Attachments = []*v1pb.Attachment{} - for _, attachment := range attachments { attachmentResponse := convertAttachmentFromStore(attachment) memoMessage.Attachments = append(memoMessage.Attachments, attachmentResponse) @@ -75,6 +73,116 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem return memoMessage, nil } +// batchConvertMemoRelations batch-loads relations for a list of memos and returns +// a map from memo ID to its converted relations. This avoids N+1 queries when listing memos. +func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*store.Memo) (map[int32][]*v1pb.MemoRelation, error) { + if len(memos) == 0 { + return map[int32][]*v1pb.MemoRelation{}, nil + } + + currentUser, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get user") + } + var memoFilter string + if currentUser == nil { + memoFilter = `visibility == "PUBLIC"` + } else { + memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID) + } + + memoIDs := make([]int32, len(memos)) + memoIDSet := make(map[int32]bool, len(memos)) + for i, m := range memos { + memoIDs[i] = m.ID + memoIDSet[m.ID] = true + } + + // Single batch query to get all relations involving any of these memos. + allRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: memoIDs, + MemoFilter: &memoFilter, + }) + if err != nil { + return nil, errors.Wrap(err, "failed to batch list memo relations") + } + + // Collect all memo IDs referenced in relations that we need to resolve. + neededIDs := make(map[int32]bool) + for _, r := range allRelations { + neededIDs[r.MemoID] = true + neededIDs[r.RelatedMemoID] = true + } + + // Build ID→UID map from the memos we already have. + memoIDToUID := make(map[int32]string, len(memos)) + memoIDToContent := make(map[int32]string, len(memos)) + for _, m := range memos { + memoIDToUID[m.ID] = m.UID + memoIDToContent[m.ID] = m.Content + delete(neededIDs, m.ID) + } + + // Batch fetch any additional memos referenced by relations that we don't already have. + if len(neededIDs) > 0 { + extraIDs := make([]int32, 0, len(neededIDs)) + for id := range neededIDs { + extraIDs = append(extraIDs, id) + } + extraMemos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: extraIDs}) + if err != nil { + return nil, errors.Wrap(err, "failed to batch fetch related memos") + } + for _, m := range extraMemos { + memoIDToUID[m.ID] = m.UID + memoIDToContent[m.ID] = m.Content + } + } + + // Build the result map: memo ID → its relations (both directions). + result := make(map[int32][]*v1pb.MemoRelation, len(memos)) + for _, r := range allRelations { + memoUID, ok1 := memoIDToUID[r.MemoID] + relatedUID, ok2 := memoIDToUID[r.RelatedMemoID] + if !ok1 || !ok2 { + continue + } + + memoSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.MemoID]) + relatedSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.RelatedMemoID]) + relation := &v1pb.MemoRelation{ + Memo: &v1pb.MemoRelation_Memo{ + Name: fmt.Sprintf("%s%s", MemoNamePrefix, memoUID), + Snippet: memoSnippet, + }, + RelatedMemo: &v1pb.MemoRelation_Memo{ + Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedUID), + Snippet: relatedSnippet, + }, + Type: convertMemoRelationTypeFromStore(r.Type), + } + + // Add to the memo that owns this relation (both directions). + if memoIDSet[r.MemoID] { + result[r.MemoID] = append(result[r.MemoID], relation) + } + if memoIDSet[r.RelatedMemoID] { + result[r.RelatedMemoID] = append(result[r.RelatedMemoID], relation) + } + } + + return result, nil +} + +// loadMemoRelations loads relations for a single memo and converts them to API format. +func (s *APIV1Service) loadMemoRelations(ctx context.Context, memo *store.Memo) ([]*v1pb.MemoRelation, error) { + relationMap, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo}) + if err != nil { + return nil, err + } + return relationMap[memo.ID], nil +} + func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property { if property == nil { return nil diff --git a/server/router/api/v1/test/test_helper.go b/server/router/api/v1/test/test_helper.go index c3afdb38b..94b42b688 100644 --- a/server/router/api/v1/test/test_helper.go +++ b/server/router/api/v1/test/test_helper.go @@ -27,13 +27,15 @@ func NewTestService(t *testing.T) *TestService { // Create a test store with SQLite testStore := teststore.NewTestingStore(ctx, t) - // Create a test profile + // Create a test profile with a temp directory for file storage, + // so tests that create attachments don't leave artifacts in the source tree. testProfile := &profile.Profile{ Demo: true, Version: "test-1.0.0", InstanceURL: "http://localhost:8080", Driver: "sqlite", DSN: ":memory:", + Data: t.TempDir(), } // Create APIV1Service with nil grpcServer since we're testing direct calls diff --git a/store/db/mysql/memo_relation.go b/store/db/mysql/memo_relation.go index 71b73be6f..5b8c8391a 100644 --- a/store/db/mysql/memo_relation.go +++ b/store/db/mysql/memo_relation.go @@ -42,6 +42,18 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation if find.Type != nil { where, args = append(where, "`type` = ?"), append(args, find.Type) } + if len(find.MemoIDList) > 0 { + placeholders := make([]string, len(find.MemoIDList)) + for i, id := range find.MemoIDList { + placeholders[i] = "?" + args = append(args, id) + } + inClause := strings.Join(placeholders, ", ") + for _, id := range find.MemoIDList { + args = append(args, id) + } + where = append(where, fmt.Sprintf("(`memo_id` IN (%s) OR `related_memo_id` IN (%s))", inClause, inClause)) + } if find.MemoFilter != nil { engine, err := filter.DefaultEngine() if err != nil { diff --git a/store/db/postgres/memo_relation.go b/store/db/postgres/memo_relation.go index a2f2817c7..c22c5c786 100644 --- a/store/db/postgres/memo_relation.go +++ b/store/db/postgres/memo_relation.go @@ -49,6 +49,20 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation if find.Type != nil { where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type) } + if len(find.MemoIDList) > 0 { + memoPlaceholders := make([]string, len(find.MemoIDList)) + for i, id := range find.MemoIDList { + memoPlaceholders[i] = placeholder(len(args) + 1) + args = append(args, id) + } + relatedPlaceholders := make([]string, len(find.MemoIDList)) + for i, id := range find.MemoIDList { + relatedPlaceholders[i] = placeholder(len(args) + 1) + args = append(args, id) + } + where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", + strings.Join(memoPlaceholders, ", "), strings.Join(relatedPlaceholders, ", "))) + } if find.MemoFilter != nil { engine, err := filter.DefaultEngine() if err != nil { diff --git a/store/db/sqlite/memo_relation.go b/store/db/sqlite/memo_relation.go index 5eed62e74..f65d70118 100644 --- a/store/db/sqlite/memo_relation.go +++ b/store/db/sqlite/memo_relation.go @@ -49,6 +49,19 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation if find.Type != nil { where, args = append(where, "type = ?"), append(args, find.Type) } + if len(find.MemoIDList) > 0 { + placeholders := make([]string, len(find.MemoIDList)) + for i, id := range find.MemoIDList { + placeholders[i] = "?" + args = append(args, id) + } + inClause := strings.Join(placeholders, ", ") + // Duplicate args for the second IN clause. + for _, id := range find.MemoIDList { + args = append(args, id) + } + where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", inClause, inClause)) + } if find.MemoFilter != nil { engine, err := filter.DefaultEngine() if err != nil { diff --git a/store/memo_relation.go b/store/memo_relation.go index 61b022884..0d63ae32a 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -24,6 +24,8 @@ type FindMemoRelation struct { RelatedMemoID *int32 Type *MemoRelationType MemoFilter *string + // MemoIDList matches relations where memo_id OR related_memo_id is in the list. + MemoIDList []int32 } type DeleteMemoRelation struct { diff --git a/store/test/memo_relation_test.go b/store/test/memo_relation_test.go index 9cfba6997..5abaaa3f0 100644 --- a/store/test/memo_relation_test.go +++ b/store/test/memo_relation_test.go @@ -638,6 +638,270 @@ func TestMemoRelationBidirectional(t *testing.T) { ts.Close() } +func TestMemoRelationListByMemoIDList(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create 3 memos. + memoA, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-a", + CreatorID: user.ID, + Content: "memo A content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoB, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-b", + CreatorID: user.ID, + Content: "memo B content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoC, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-c", + CreatorID: user.ID, + Content: "memo C content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoD, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-d", + CreatorID: user.ID, + Content: "memo D content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // A -> B (reference) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoA.ID, + RelatedMemoID: memoB.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // A -> C (comment) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoA.ID, + RelatedMemoID: memoC.ID, + Type: store.MemoRelationComment, + }) + require.NoError(t, err) + + // D -> B (reference) — B appears as related_memo_id + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoD.ID, + RelatedMemoID: memoB.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Batch query for memos A and B: should return all 3 relations + // (A->B because A is in list, A->C because A is in list, D->B because B is in list) + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{memoA.ID, memoB.ID}, + }) + require.NoError(t, err) + require.Len(t, relations, 3) + + // Batch query for memo C only: should return 1 relation (A->C because C is related_memo_id) + relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{memoC.ID}, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + require.Equal(t, memoA.ID, relations[0].MemoID) + require.Equal(t, memoC.ID, relations[0].RelatedMemoID) + + // Batch query for memo D only: should return 1 relation (D->B because D is memo_id) + relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{memoD.ID}, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + require.Equal(t, memoD.ID, relations[0].MemoID) + require.Equal(t, memoB.ID, relations[0].RelatedMemoID) + + ts.Close() +} + +func TestMemoRelationListByMemoIDListEmpty(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + memo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-no-relations", + CreatorID: user.ID, + Content: "memo with no relations", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Batch query with a memo that has no relations. + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{memo.ID}, + }) + require.NoError(t, err) + require.Len(t, relations, 0) + + // Empty MemoIDList should not filter by MemoIDList (returns based on other filters). + relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{}, + }) + require.NoError(t, err) + require.Len(t, relations, 0) + + ts.Close() +} + +func TestMemoRelationListByMemoIDListWithTypeFilter(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + memoA, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-a", + CreatorID: user.ID, + Content: "memo A content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoB, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-b", + CreatorID: user.ID, + Content: "memo B content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoC, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-c", + CreatorID: user.ID, + Content: "memo C content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // A -> B (reference) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoA.ID, + RelatedMemoID: memoB.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // A -> C (comment) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoA.ID, + RelatedMemoID: memoC.ID, + Type: store.MemoRelationComment, + }) + require.NoError(t, err) + + // Batch query with type filter: only references + refType := store.MemoRelationReference + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{memoA.ID}, + Type: &refType, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + require.Equal(t, store.MemoRelationReference, relations[0].Type) + + // Batch query with type filter: only comments + commentType := store.MemoRelationComment + relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{memoA.ID}, + Type: &commentType, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + require.Equal(t, store.MemoRelationComment, relations[0].Type) + + ts.Close() +} + +func TestMemoRelationListByMemoIDListBothDirections(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + memoA, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-a", + CreatorID: user.ID, + Content: "memo A content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoB, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-b", + CreatorID: user.ID, + Content: "memo B content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoX, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-x", + CreatorID: user.ID, + Content: "memo X content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // X -> A (A appears as related_memo_id) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoX.ID, + RelatedMemoID: memoA.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // A -> B (A appears as memo_id) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoA.ID, + RelatedMemoID: memoB.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Query with MemoIDList=[A]: should find both relations (A as source and A as target). + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoIDList: []int32{memoA.ID}, + }) + require.NoError(t, err) + require.Len(t, relations, 2) + + // Verify we got both directions. + memoIDs := map[int32]bool{} + relatedIDs := map[int32]bool{} + for _, r := range relations { + memoIDs[r.MemoID] = true + relatedIDs[r.RelatedMemoID] = true + } + require.True(t, memoIDs[memoX.ID], "should include X->A relation") + require.True(t, memoIDs[memoA.ID], "should include A->B relation") + require.True(t, relatedIDs[memoA.ID], "should include X->A relation") + require.True(t, relatedIDs[memoB.ID], "should include A->B relation") + + ts.Close() +} + func TestMemoRelationMultipleRelationsToSameMemo(t *testing.T) { t.Parallel() ctx := context.Background()