diff --git a/server/router/rss/rss.go b/server/router/rss/rss.go index 53ea86a20..863ac4ecd 100644 --- a/server/router/rss/rss.go +++ b/server/router/rss/rss.go @@ -211,18 +211,24 @@ func (s *RSSService) generateRSSFromMemoList(ctx context.Context, memoList []*st creatorMap[user.ID] = user } else { // Multi-user feed - batch load all unique creators - creatorIDs := make(map[int32]bool) + creatorIDList := []int32{} + creatorIDMap := make(map[int32]bool) for _, memo := range memoList[:itemCountLimit] { - creatorIDs[memo.CreatorID] = true + if !creatorIDMap[memo.CreatorID] { + creatorIDList = append(creatorIDList, memo.CreatorID) + creatorIDMap[memo.CreatorID] = true + } } - // Batch load all users with a single query by getting all users and filtering - // Note: This is more efficient than N separate queries - for creatorID := range creatorIDs { - creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &creatorID}) - if err == nil && creator != nil { - creatorMap[creatorID] = creator - } + // Batch load all users with a single query + users, err := s.Store.ListUsers(ctx, &store.FindUser{ + IDList: creatorIDList, + }) + if err != nil { + return "", lastModified, err + } + for _, creator := range users { + creatorMap[creator.ID] = creator } } diff --git a/store/db/mysql/user.go b/store/db/mysql/user.go index 4403e07ff..7334fad13 100644 --- a/store/db/mysql/user.go +++ b/store/db/mysql/user.go @@ -91,6 +91,19 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User if v := find.ID; v != nil { where, args = append(where, "`id` = ?"), append(args, *v) } + if len(find.IDList) > 0 { + placeholders := make([]string, 0, len(find.IDList)) + for range find.IDList { + placeholders = append(placeholders, "?") + } + where, args = append(where, fmt.Sprintf("`id` IN (%s)", strings.Join(placeholders, ", "))), append(args, func() []any { + list := make([]any, 0, len(find.IDList)) + for _, id := range find.IDList { + list = append(list, id) + } + return list + }()...) + } if v := find.Username; v != nil { where, args = append(where, "`username` = ?"), append(args, *v) } diff --git a/store/db/postgres/user.go b/store/db/postgres/user.go index 0be4aa8b8..4eb41028b 100644 --- a/store/db/postgres/user.go +++ b/store/db/postgres/user.go @@ -94,6 +94,14 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User if v := find.ID; v != nil { where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) } + if len(find.IDList) > 0 { + holders := make([]string, 0, len(find.IDList)) + for range find.IDList { + holders = append(holders, placeholder(len(args)+1)) + args = append(args, find.IDList[len(holders)-1]) + } + where = append(where, fmt.Sprintf("id IN (%s)", strings.Join(holders, ", "))) + } if v := find.Username; v != nil { where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v) } diff --git a/store/db/sqlite/user.go b/store/db/sqlite/user.go index b5cb906bd..0c4c555de 100644 --- a/store/db/sqlite/user.go +++ b/store/db/sqlite/user.go @@ -95,6 +95,19 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User if v := find.ID; v != nil { where, args = append(where, "id = ?"), append(args, *v) } + if len(find.IDList) > 0 { + placeholders := make([]string, 0, len(find.IDList)) + for range find.IDList { + placeholders = append(placeholders, "?") + } + where, args = append(where, fmt.Sprintf("id IN (%s)", strings.Join(placeholders, ", "))), append(args, func() []any { + list := make([]any, 0, len(find.IDList)) + for _, id := range find.IDList { + list = append(list, id) + } + return list + }()...) + } if v := find.Username; v != nil { where, args = append(where, "username = ?"), append(args, *v) } diff --git a/store/test/user_test.go b/store/test/user_test.go index 1b71af72e..4ee6ae15f 100644 --- a/store/test/user_test.go +++ b/store/test/user_test.go @@ -40,6 +40,36 @@ func TestUserStore(t *testing.T) { ts.Close() } +func TestUserListByIDList(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Create 5 users + var userIDs []int32 + for i := 0; i < 5; i++ { + user, err := createTestingUserWithRole(ctx, ts, fmt.Sprintf("user_list_%d", i), store.RoleUser) + require.NoError(t, err) + userIDs = append(userIDs, user.ID) + } + + // List users by IDList (3 out of 5) + targetIDs := userIDs[1:4] + users, err := ts.ListUsers(ctx, &store.FindUser{IDList: targetIDs}) + require.NoError(t, err) + require.Equal(t, 3, len(users)) + + foundIDs := make(map[int32]bool) + for _, u := range users { + foundIDs[u.ID] = true + } + for _, id := range targetIDs { + require.True(t, foundIDs[id]) + } + + ts.Close() +} + func TestUserGetByID(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/store/user.go b/store/user.go index 8fb149539..03c56acf8 100644 --- a/store/user.go +++ b/store/user.go @@ -71,7 +71,9 @@ type UpdateUser struct { } type FindUser struct { - ID *int32 + ID *int32 + IDList []int32 + RowStatus *RowStatus Username *string Role *Role