mirror of https://github.com/usememos/memos.git
perf: batch load memo relations when listing memos (#5692)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3d4f793f97
commit
1e82714a52
|
|
@ -126,7 +126,11 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
|||
}
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments)
|
||||
relations, err := s.loadMemoRelations(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load memo relations")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, nil, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -266,12 +270,19 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
|
|||
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
|
||||
}
|
||||
|
||||
// RELATIONS (batch load to avoid N+1)
|
||||
relationMap, err := s.batchConvertMemoRelations(ctx, memos)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
|
||||
}
|
||||
|
||||
for _, memo := range memos {
|
||||
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
reactions := reactionMap[memoName]
|
||||
attachments := attachmentMap[memo.ID]
|
||||
relations := relationMap[memo.ID]
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -327,7 +338,11 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest
|
|||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
|
||||
relations, err := s.loadMemoRelations(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load memo relations")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -462,7 +477,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments)
|
||||
relations, err := s.loadMemoRelations(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load memo relations")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -521,7 +540,8 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
|||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments); err == nil {
|
||||
deleteRelations, _ := s.loadMemoRelations(ctx, memo)
|
||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, deleteRelations); err == nil {
|
||||
// Try to dispatch webhook when memo is deleted.
|
||||
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
|
||||
|
|
@ -725,13 +745,20 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
|
|||
attachmentMap[*attachment.MemoID] = append(attachmentMap[*attachment.MemoID], attachment)
|
||||
}
|
||||
|
||||
// RELATIONS (batch load to avoid N+1)
|
||||
relationMap, err := s.batchConvertMemoRelations(ctx, memos)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
|
||||
}
|
||||
|
||||
var memosResponse []*v1pb.Memo
|
||||
for _, m := range memos {
|
||||
memoName := memoIDToNameMap[m.ID]
|
||||
reactions := memoReactionsMap[memoName]
|
||||
attachments := attachmentMap[m.ID]
|
||||
relations := relationMap[m.ID]
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments)
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment) (*v1pb.Memo, error) {
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation) (*v1pb.Memo, error) {
|
||||
displayTs := memo.CreatedTs
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -47,20 +47,18 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
|
|||
}
|
||||
|
||||
memoMessage.Reactions = []*v1pb.Reaction{}
|
||||
|
||||
for _, reaction := range reactions {
|
||||
reactionResponse := convertReactionFromStore(reaction)
|
||||
memoMessage.Reactions = append(memoMessage.Reactions, reactionResponse)
|
||||
}
|
||||
|
||||
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo relations")
|
||||
if relations != nil {
|
||||
memoMessage.Relations = relations
|
||||
} else {
|
||||
memoMessage.Relations = []*v1pb.MemoRelation{}
|
||||
}
|
||||
memoMessage.Relations = listMemoRelationsResponse.Relations
|
||||
|
||||
memoMessage.Attachments = []*v1pb.Attachment{}
|
||||
|
||||
for _, attachment := range attachments {
|
||||
attachmentResponse := convertAttachmentFromStore(attachment)
|
||||
memoMessage.Attachments = append(memoMessage.Attachments, attachmentResponse)
|
||||
|
|
@ -75,6 +73,116 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
|
|||
return memoMessage, nil
|
||||
}
|
||||
|
||||
// batchConvertMemoRelations batch-loads relations for a list of memos and returns
|
||||
// a map from memo ID to its converted relations. This avoids N+1 queries when listing memos.
|
||||
func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*store.Memo) (map[int32][]*v1pb.MemoRelation, error) {
|
||||
if len(memos) == 0 {
|
||||
return map[int32][]*v1pb.MemoRelation{}, nil
|
||||
}
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
var memoFilter string
|
||||
if currentUser == nil {
|
||||
memoFilter = `visibility == "PUBLIC"`
|
||||
} else {
|
||||
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
}
|
||||
|
||||
memoIDs := make([]int32, len(memos))
|
||||
memoIDSet := make(map[int32]bool, len(memos))
|
||||
for i, m := range memos {
|
||||
memoIDs[i] = m.ID
|
||||
memoIDSet[m.ID] = true
|
||||
}
|
||||
|
||||
// Single batch query to get all relations involving any of these memos.
|
||||
allRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: memoIDs,
|
||||
MemoFilter: &memoFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to batch list memo relations")
|
||||
}
|
||||
|
||||
// Collect all memo IDs referenced in relations that we need to resolve.
|
||||
neededIDs := make(map[int32]bool)
|
||||
for _, r := range allRelations {
|
||||
neededIDs[r.MemoID] = true
|
||||
neededIDs[r.RelatedMemoID] = true
|
||||
}
|
||||
|
||||
// Build ID→UID map from the memos we already have.
|
||||
memoIDToUID := make(map[int32]string, len(memos))
|
||||
memoIDToContent := make(map[int32]string, len(memos))
|
||||
for _, m := range memos {
|
||||
memoIDToUID[m.ID] = m.UID
|
||||
memoIDToContent[m.ID] = m.Content
|
||||
delete(neededIDs, m.ID)
|
||||
}
|
||||
|
||||
// Batch fetch any additional memos referenced by relations that we don't already have.
|
||||
if len(neededIDs) > 0 {
|
||||
extraIDs := make([]int32, 0, len(neededIDs))
|
||||
for id := range neededIDs {
|
||||
extraIDs = append(extraIDs, id)
|
||||
}
|
||||
extraMemos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: extraIDs})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to batch fetch related memos")
|
||||
}
|
||||
for _, m := range extraMemos {
|
||||
memoIDToUID[m.ID] = m.UID
|
||||
memoIDToContent[m.ID] = m.Content
|
||||
}
|
||||
}
|
||||
|
||||
// Build the result map: memo ID → its relations (both directions).
|
||||
result := make(map[int32][]*v1pb.MemoRelation, len(memos))
|
||||
for _, r := range allRelations {
|
||||
memoUID, ok1 := memoIDToUID[r.MemoID]
|
||||
relatedUID, ok2 := memoIDToUID[r.RelatedMemoID]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
|
||||
memoSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.MemoID])
|
||||
relatedSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.RelatedMemoID])
|
||||
relation := &v1pb.MemoRelation{
|
||||
Memo: &v1pb.MemoRelation_Memo{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memoUID),
|
||||
Snippet: memoSnippet,
|
||||
},
|
||||
RelatedMemo: &v1pb.MemoRelation_Memo{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedUID),
|
||||
Snippet: relatedSnippet,
|
||||
},
|
||||
Type: convertMemoRelationTypeFromStore(r.Type),
|
||||
}
|
||||
|
||||
// Add to the memo that owns this relation (both directions).
|
||||
if memoIDSet[r.MemoID] {
|
||||
result[r.MemoID] = append(result[r.MemoID], relation)
|
||||
}
|
||||
if memoIDSet[r.RelatedMemoID] {
|
||||
result[r.RelatedMemoID] = append(result[r.RelatedMemoID], relation)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// loadMemoRelations loads relations for a single memo and converts them to API format.
|
||||
func (s *APIV1Service) loadMemoRelations(ctx context.Context, memo *store.Memo) ([]*v1pb.MemoRelation, error) {
|
||||
relationMap, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return relationMap[memo.ID], nil
|
||||
}
|
||||
|
||||
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
|
||||
if property == nil {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -27,13 +27,15 @@ func NewTestService(t *testing.T) *TestService {
|
|||
// Create a test store with SQLite
|
||||
testStore := teststore.NewTestingStore(ctx, t)
|
||||
|
||||
// Create a test profile
|
||||
// Create a test profile with a temp directory for file storage,
|
||||
// so tests that create attachments don't leave artifacts in the source tree.
|
||||
testProfile := &profile.Profile{
|
||||
Demo: true,
|
||||
Version: "test-1.0.0",
|
||||
InstanceURL: "http://localhost:8080",
|
||||
Driver: "sqlite",
|
||||
DSN: ":memory:",
|
||||
Data: t.TempDir(),
|
||||
}
|
||||
|
||||
// Create APIV1Service with nil grpcServer since we're testing direct calls
|
||||
|
|
|
|||
|
|
@ -42,6 +42,18 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
|
|||
if find.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, find.Type)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
placeholders := make([]string, len(find.MemoIDList))
|
||||
for i, id := range find.MemoIDList {
|
||||
placeholders[i] = "?"
|
||||
args = append(args, id)
|
||||
}
|
||||
inClause := strings.Join(placeholders, ", ")
|
||||
for _, id := range find.MemoIDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("(`memo_id` IN (%s) OR `related_memo_id` IN (%s))", inClause, inClause))
|
||||
}
|
||||
if find.MemoFilter != nil {
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -49,6 +49,20 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
|
|||
if find.Type != nil {
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
memoPlaceholders := make([]string, len(find.MemoIDList))
|
||||
for i, id := range find.MemoIDList {
|
||||
memoPlaceholders[i] = placeholder(len(args) + 1)
|
||||
args = append(args, id)
|
||||
}
|
||||
relatedPlaceholders := make([]string, len(find.MemoIDList))
|
||||
for i, id := range find.MemoIDList {
|
||||
relatedPlaceholders[i] = placeholder(len(args) + 1)
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))",
|
||||
strings.Join(memoPlaceholders, ", "), strings.Join(relatedPlaceholders, ", ")))
|
||||
}
|
||||
if find.MemoFilter != nil {
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -49,6 +49,19 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
|
|||
if find.Type != nil {
|
||||
where, args = append(where, "type = ?"), append(args, find.Type)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
placeholders := make([]string, len(find.MemoIDList))
|
||||
for i, id := range find.MemoIDList {
|
||||
placeholders[i] = "?"
|
||||
args = append(args, id)
|
||||
}
|
||||
inClause := strings.Join(placeholders, ", ")
|
||||
// Duplicate args for the second IN clause.
|
||||
for _, id := range find.MemoIDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", inClause, inClause))
|
||||
}
|
||||
if find.MemoFilter != nil {
|
||||
engine, err := filter.DefaultEngine()
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ type FindMemoRelation struct {
|
|||
RelatedMemoID *int32
|
||||
Type *MemoRelationType
|
||||
MemoFilter *string
|
||||
// MemoIDList matches relations where memo_id OR related_memo_id is in the list.
|
||||
MemoIDList []int32
|
||||
}
|
||||
|
||||
type DeleteMemoRelation struct {
|
||||
|
|
|
|||
|
|
@ -638,6 +638,270 @@ func TestMemoRelationBidirectional(t *testing.T) {
|
|||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListByMemoIDList(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create 3 memos.
|
||||
memoA, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-a",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo A content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoB, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-b",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo B content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoC, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-c",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo C content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoD, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-d",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo D content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// A -> B (reference)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoA.ID,
|
||||
RelatedMemoID: memoB.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// A -> C (comment)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoA.ID,
|
||||
RelatedMemoID: memoC.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// D -> B (reference) — B appears as related_memo_id
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoD.ID,
|
||||
RelatedMemoID: memoB.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Batch query for memos A and B: should return all 3 relations
|
||||
// (A->B because A is in list, A->C because A is in list, D->B because B is in list)
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memoA.ID, memoB.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 3)
|
||||
|
||||
// Batch query for memo C only: should return 1 relation (A->C because C is related_memo_id)
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memoC.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, memoA.ID, relations[0].MemoID)
|
||||
require.Equal(t, memoC.ID, relations[0].RelatedMemoID)
|
||||
|
||||
// Batch query for memo D only: should return 1 relation (D->B because D is memo_id)
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memoD.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, memoD.ID, relations[0].MemoID)
|
||||
require.Equal(t, memoB.ID, relations[0].RelatedMemoID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListByMemoIDListEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-no-relations",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo with no relations",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Batch query with a memo that has no relations.
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memo.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 0)
|
||||
|
||||
// Empty MemoIDList should not filter by MemoIDList (returns based on other filters).
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListByMemoIDListWithTypeFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memoA, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-a",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo A content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoB, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-b",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo B content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoC, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-c",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo C content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// A -> B (reference)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoA.ID,
|
||||
RelatedMemoID: memoB.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// A -> C (comment)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoA.ID,
|
||||
RelatedMemoID: memoC.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Batch query with type filter: only references
|
||||
refType := store.MemoRelationReference
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memoA.ID},
|
||||
Type: &refType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, store.MemoRelationReference, relations[0].Type)
|
||||
|
||||
// Batch query with type filter: only comments
|
||||
commentType := store.MemoRelationComment
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memoA.ID},
|
||||
Type: &commentType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, store.MemoRelationComment, relations[0].Type)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListByMemoIDListBothDirections(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memoA, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-a",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo A content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoB, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-b",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo B content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoX, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-x",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo X content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// X -> A (A appears as related_memo_id)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoX.ID,
|
||||
RelatedMemoID: memoA.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// A -> B (A appears as memo_id)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoA.ID,
|
||||
RelatedMemoID: memoB.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query with MemoIDList=[A]: should find both relations (A as source and A as target).
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memoA.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 2)
|
||||
|
||||
// Verify we got both directions.
|
||||
memoIDs := map[int32]bool{}
|
||||
relatedIDs := map[int32]bool{}
|
||||
for _, r := range relations {
|
||||
memoIDs[r.MemoID] = true
|
||||
relatedIDs[r.RelatedMemoID] = true
|
||||
}
|
||||
require.True(t, memoIDs[memoX.ID], "should include X->A relation")
|
||||
require.True(t, memoIDs[memoA.ID], "should include A->B relation")
|
||||
require.True(t, relatedIDs[memoA.ID], "should include X->A relation")
|
||||
require.True(t, relatedIDs[memoB.ID], "should include A->B relation")
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationMultipleRelationsToSameMemo(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
|
|
|||
Loading…
Reference in New Issue