diff --git a/plugin/storage/s3/s3.go b/plugin/storage/s3/s3.go index 6d326c7e2..c6a57714b 100644 --- a/plugin/storage/s3/s3.go +++ b/plugin/storage/s3/s3.go @@ -93,6 +93,18 @@ func (c *Client) GetObject(ctx context.Context, key string) ([]byte, error) { return buffer.Bytes(), nil } +// GetObjectStream retrieves an object from S3 as a stream. +func (c *Client) GetObjectStream(ctx context.Context, key string) (io.ReadCloser, error) { + output, err := c.Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: c.Bucket, + Key: aws.String(key), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to get object") + } + return output.Body, nil +} + // DeleteObject deletes an object in S3. func (c *Client) DeleteObject(ctx context.Context, key string) error { _, err := c.Client.DeleteObject(ctx, &s3.DeleteObjectInput{ diff --git a/server/router/api/v1/user_service_stats.go b/server/router/api/v1/user_service_stats.go index ed9ecd8a4..a2009ab5a 100644 --- a/server/router/api/v1/user_service_stats.go +++ b/server/router/api/v1/user_service_stats.go @@ -42,66 +42,79 @@ func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUser memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected} } } - memos, err := s.Store.ListMemos(ctx, memoFind) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err) - } userMemoStatMap := make(map[int32]*v1pb.UserStats) - for _, memo := range memos { - // Initialize user stats if not exists - if _, exists := userMemoStatMap[memo.CreatorID]; !exists { - userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{ - Name: fmt.Sprintf("users/%d/stats", memo.CreatorID), - TagCount: make(map[string]int32), - MemoDisplayTimestamps: []*timestamppb.Timestamp{}, - PinnedMemos: []string{}, - MemoTypeStats: &v1pb.UserStats_MemoTypeStats{ - LinkCount: 0, - CodeCount: 0, - TodoCount: 0, - UndoCount: 0, - }, + limit := 1000 + offset := 0 + memoFind.Limit = &limit + memoFind.Offset = &offset + + for { + memos, err := s.Store.ListMemos(ctx, memoFind) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err) + } + if len(memos) == 0 { + break + } + + for _, memo := range memos { + // Initialize user stats if not exists + if _, exists := userMemoStatMap[memo.CreatorID]; !exists { + userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{ + Name: fmt.Sprintf("users/%d/stats", memo.CreatorID), + TagCount: make(map[string]int32), + MemoDisplayTimestamps: []*timestamppb.Timestamp{}, + PinnedMemos: []string{}, + MemoTypeStats: &v1pb.UserStats_MemoTypeStats{ + LinkCount: 0, + CodeCount: 0, + TodoCount: 0, + UndoCount: 0, + }, + } + } + + stats := userMemoStatMap[memo.CreatorID] + + // Add display timestamp + displayTs := memo.CreatedTs + if instanceMemoRelatedSetting.DisplayWithUpdateTime { + displayTs = memo.UpdatedTs + } + stats.MemoDisplayTimestamps = append(stats.MemoDisplayTimestamps, timestamppb.New(time.Unix(displayTs, 0))) + + // Count memo stats + stats.TotalMemoCount++ + + // Count tags and other properties + if memo.Payload != nil { + for _, tag := range memo.Payload.Tags { + stats.TagCount[tag]++ + } + if memo.Payload.Property != nil { + if memo.Payload.Property.HasLink { + stats.MemoTypeStats.LinkCount++ + } + if memo.Payload.Property.HasCode { + stats.MemoTypeStats.CodeCount++ + } + if memo.Payload.Property.HasTaskList { + stats.MemoTypeStats.TodoCount++ + } + if memo.Payload.Property.HasIncompleteTasks { + stats.MemoTypeStats.UndoCount++ + } + } + } + + // Track pinned memos + if memo.Pinned { + stats.PinnedMemos = append(stats.PinnedMemos, fmt.Sprintf("users/%d/memos/%d", memo.CreatorID, memo.ID)) } } - stats := userMemoStatMap[memo.CreatorID] - - // Add display timestamp - displayTs := memo.CreatedTs - if instanceMemoRelatedSetting.DisplayWithUpdateTime { - displayTs = memo.UpdatedTs - } - stats.MemoDisplayTimestamps = append(stats.MemoDisplayTimestamps, timestamppb.New(time.Unix(displayTs, 0))) - - // Count memo stats - stats.TotalMemoCount++ - - // Count tags and other properties - if memo.Payload != nil { - for _, tag := range memo.Payload.Tags { - stats.TagCount[tag]++ - } - if memo.Payload.Property != nil { - if memo.Payload.Property.HasLink { - stats.MemoTypeStats.LinkCount++ - } - if memo.Payload.Property.HasCode { - stats.MemoTypeStats.CodeCount++ - } - if memo.Payload.Property.HasTaskList { - stats.MemoTypeStats.TodoCount++ - } - if memo.Payload.Property.HasIncompleteTasks { - stats.MemoTypeStats.UndoCount++ - } - } - } - - // Track pinned memos - if memo.Pinned { - stats.PinnedMemos = append(stats.PinnedMemos, fmt.Sprintf("users/%d/memos/%d", memo.CreatorID, memo.ID)) - } + offset += limit } userMemoStats := []*v1pb.UserStats{} @@ -141,11 +154,6 @@ func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserSt memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected} } - memos, err := s.Store.ListMemos(ctx, memoFind) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err) - } - instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx) if err != nil { return nil, errors.Wrap(err, "failed to get instance memo related setting") @@ -158,36 +166,56 @@ func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserSt todoCount := int32(0) undoCount := int32(0) pinnedMemos := []string{} + totalMemoCount := int32(0) - for _, memo := range memos { - displayTs := memo.CreatedTs - if instanceMemoRelatedSetting.DisplayWithUpdateTime { - displayTs = memo.UpdatedTs + limit := 1000 + offset := 0 + memoFind.Limit = &limit + memoFind.Offset = &offset + + for { + memos, err := s.Store.ListMemos(ctx, memoFind) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err) } - displayTimestamps = append(displayTimestamps, timestamppb.New(time.Unix(displayTs, 0))) - // Count different memo types based on content. - if memo.Payload != nil { - for _, tag := range memo.Payload.Tags { - tagCount[tag]++ + if len(memos) == 0 { + break + } + + totalMemoCount += int32(len(memos)) + + for _, memo := range memos { + displayTs := memo.CreatedTs + if instanceMemoRelatedSetting.DisplayWithUpdateTime { + displayTs = memo.UpdatedTs } - if memo.Payload.Property != nil { - if memo.Payload.Property.HasLink { - linkCount++ + displayTimestamps = append(displayTimestamps, timestamppb.New(time.Unix(displayTs, 0))) + // Count different memo types based on content. + if memo.Payload != nil { + for _, tag := range memo.Payload.Tags { + tagCount[tag]++ } - if memo.Payload.Property.HasCode { - codeCount++ - } - if memo.Payload.Property.HasTaskList { - todoCount++ - } - if memo.Payload.Property.HasIncompleteTasks { - undoCount++ + if memo.Payload.Property != nil { + if memo.Payload.Property.HasLink { + linkCount++ + } + if memo.Payload.Property.HasCode { + codeCount++ + } + if memo.Payload.Property.HasTaskList { + todoCount++ + } + if memo.Payload.Property.HasIncompleteTasks { + undoCount++ + } } } + if memo.Pinned { + pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID)) + } } - if memo.Pinned { - pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID)) - } + + offset += limit } userStats := &v1pb.UserStats{ @@ -195,7 +223,7 @@ func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserSt MemoDisplayTimestamps: displayTimestamps, TagCount: tagCount, PinnedMemos: pinnedMemos, - TotalMemoCount: int32(len(memos)), + TotalMemoCount: totalMemoCount, MemoTypeStats: &v1pb.UserStats_MemoTypeStats{ LinkCount: linkCount, CodeCount: codeCount, diff --git a/server/router/fileserver/fileserver.go b/server/router/fileserver/fileserver.go index 6630ec601..bac7a46e1 100644 --- a/server/router/fileserver/fileserver.go +++ b/server/router/fileserver/fileserver.go @@ -340,6 +340,55 @@ func (*FileServerService) isImageType(mimeType string) bool { return mimeType == "image/png" || mimeType == "image/jpeg" } +// getAttachmentReader returns a reader for the attachment content. +func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) { + // For local storage, read the file from the local disk. + if attachment.StorageType == storepb.AttachmentStorageType_LOCAL { + attachmentPath := filepath.FromSlash(attachment.Reference) + if !filepath.IsAbs(attachmentPath) { + attachmentPath = filepath.Join(s.Profile.Data, attachmentPath) + } + + file, err := os.Open(attachmentPath) + if err != nil { + if os.IsNotExist(err) { + return nil, errors.Wrap(err, "file not found") + } + return nil, errors.Wrap(err, "failed to open the file") + } + return file, nil + } + // For S3 storage, download the file from S3. + if attachment.StorageType == storepb.AttachmentStorageType_S3 { + if attachment.Payload == nil { + return nil, errors.New("attachment payload is missing") + } + s3Object := attachment.Payload.GetS3Object() + if s3Object == nil { + return nil, errors.New("S3 object payload is missing") + } + if s3Object.S3Config == nil { + return nil, errors.New("S3 config is missing") + } + if s3Object.Key == "" { + return nil, errors.New("S3 object key is missing") + } + + s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config) + if err != nil { + return nil, errors.Wrap(err, "failed to create S3 client") + } + + reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key) + if err != nil { + return nil, errors.Wrap(err, "failed to get object from S3") + } + return reader, nil + } + // For database storage, return the blob from the database. + return io.NopCloser(bytes.NewReader(attachment.Blob)), nil +} + // getAttachmentBlob retrieves the binary content of an attachment from storage. func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) { // For local storage, read the file from the local disk. @@ -441,13 +490,14 @@ func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachme } // Generate the thumbnail - blob, err := s.getAttachmentBlob(attachment) + reader, err := s.getAttachmentReader(attachment) if err != nil { - return nil, errors.Wrap(err, "failed to get attachment blob") + return nil, errors.Wrap(err, "failed to get attachment reader") } + defer reader.Close() // Decode image - this is memory intensive - img, err := imaging.Decode(bytes.NewReader(blob), imaging.AutoOrientation(true)) + img, err := imaging.Decode(reader, imaging.AutoOrientation(true)) if err != nil { return nil, errors.Wrap(err, "failed to decode thumbnail image") }