package v1 import ( "bytes" "context" "encoding/binary" "fmt" "io" "os" "path/filepath" "regexp" "strings" "time" "github.com/lithammer/shortuuid/v4" "github.com/pkg/errors" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/usememos/memos/internal/profile" "github.com/usememos/memos/internal/util" "github.com/usememos/memos/plugin/storage/s3" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) const ( // The upload memory buffer is 32 MiB. // It should be kept low, so RAM usage doesn't get out of control. // This is unrelated to maximum upload size limit, which is now set through system setting. MaxUploadBufferSizeBytes = 32 << 20 MebiByte = 1024 * 1024 // ThumbnailCacheFolder is the folder name where the thumbnail images are stored. ThumbnailCacheFolder = ".thumbnail_cache" ) var SupportedThumbnailMimeTypes = []string{ "image/png", "image/jpeg", } func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) { user, err := s.GetCurrentUser(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } if user == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } // Validate required fields if request.Attachment == nil { return nil, status.Errorf(codes.InvalidArgument, "attachment is required") } if request.Attachment.Filename == "" { return nil, status.Errorf(codes.InvalidArgument, "filename is required") } if !validateFilename(request.Attachment.Filename) { return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format") } if request.Attachment.Type == "" { return nil, status.Errorf(codes.InvalidArgument, "type is required") } // Use provided attachment_id or generate a new one attachmentUID := request.AttachmentId if attachmentUID == "" { attachmentUID = shortuuid.New() } create := &store.Attachment{ UID: attachmentUID, CreatorID: user.ID, Filename: request.Attachment.Filename, Type: request.Attachment.Type, } instanceStorageSetting, err := s.Store.GetInstanceStorageSetting(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get instance storage setting: %v", err) } size := binary.Size(request.Attachment.Content) uploadSizeLimit := int(instanceStorageSetting.UploadSizeLimitMb) * MebiByte if uploadSizeLimit == 0 { uploadSizeLimit = MaxUploadBufferSizeBytes } if size > uploadSizeLimit { return nil, status.Errorf(codes.InvalidArgument, "file size exceeds the limit") } create.Size = int64(size) create.Blob = request.Attachment.Content if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil { return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err) } if request.Attachment.Memo != nil { memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) } memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err) } if memo == nil { return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo) } create.MemoID = &memo.ID } attachment, err := s.Store.CreateAttachment(ctx, create) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err) } return convertAttachmentFromStore(attachment), nil } func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) { user, err := s.GetCurrentUser(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } if user == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } // Set default page size pageSize := int(request.PageSize) if pageSize <= 0 { pageSize = 50 } if pageSize > 1000 { pageSize = 1000 } // Parse page token for offset offset := 0 if request.PageToken != "" { // Simple implementation: page token is the offset as string // In production, you might want to use encrypted tokens if parsed, err := fmt.Sscanf(request.PageToken, "%d", &offset); err != nil || parsed != 1 { return nil, status.Errorf(codes.InvalidArgument, "invalid page token") } } findAttachment := &store.FindAttachment{ CreatorID: &user.ID, Limit: &pageSize, Offset: &offset, } attachments, err := s.Store.ListAttachments(ctx, findAttachment) if err != nil { return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err) } response := &v1pb.ListAttachmentsResponse{} for _, attachment := range attachments { response.Attachments = append(response.Attachments, convertAttachmentFromStore(attachment)) } // For simplicity, set total size to the number of returned attachments. // In a full implementation, you'd want a separate count query response.TotalSize = int32(len(response.Attachments)) // Set next page token if we got the full page size (indicating there might be more) if len(attachments) == pageSize { response.NextPageToken = fmt.Sprintf("%d", offset+pageSize) } return response, nil } func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttachmentRequest) (*v1pb.Attachment, error) { attachmentUID, err := ExtractAttachmentUIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err) } attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err) } if attachment == nil { return nil, status.Errorf(codes.NotFound, "attachment not found") } return convertAttachmentFromStore(attachment), nil } func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.UpdateAttachmentRequest) (*v1pb.Attachment, error) { attachmentUID, err := ExtractAttachmentUIDFromName(request.Attachment.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err) } if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 { return nil, status.Errorf(codes.InvalidArgument, "update mask is required") } attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err) } currentTs := time.Now().Unix() update := &store.UpdateAttachment{ ID: attachment.ID, UpdatedTs: ¤tTs, } for _, field := range request.UpdateMask.Paths { if field == "filename" { if !validateFilename(request.Attachment.Filename) { return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format") } update.Filename = &request.Attachment.Filename } } if err := s.Store.UpdateAttachment(ctx, update); err != nil { return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err) } return s.GetAttachment(ctx, &v1pb.GetAttachmentRequest{ Name: request.Attachment.Name, }) } func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.DeleteAttachmentRequest) (*emptypb.Empty, error) { attachmentUID, err := ExtractAttachmentUIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err) } user, err := s.GetCurrentUser(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } if user == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{ UID: &attachmentUID, CreatorID: &user.ID, }) if err != nil { return nil, status.Errorf(codes.Internal, "failed to find attachment: %v", err) } if attachment == nil { return nil, status.Errorf(codes.NotFound, "attachment not found") } // Delete the attachment from the database. if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ ID: attachment.ID, }); err != nil { return nil, status.Errorf(codes.Internal, "failed to delete attachment: %v", err) } return &emptypb.Empty{}, nil } func convertAttachmentFromStore(attachment *store.Attachment) *v1pb.Attachment { attachmentMessage := &v1pb.Attachment{ Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID), CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)), Filename: attachment.Filename, Type: attachment.Type, Size: attachment.Size, } if attachment.MemoUID != nil && *attachment.MemoUID != "" { memoName := fmt.Sprintf("%s%s", MemoNamePrefix, *attachment.MemoUID) attachmentMessage.Memo = &memoName } if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 { attachmentMessage.ExternalLink = attachment.Reference } return attachmentMessage } // SaveAttachmentBlob save the blob of attachment based on the storage config. func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Attachment) error { instanceStorageSetting, err := stores.GetInstanceStorageSetting(ctx) if err != nil { return errors.Wrap(err, "Failed to find instance storage setting") } if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_LOCAL { filepathTemplate := "assets/{timestamp}_{filename}" if instanceStorageSetting.FilepathTemplate != "" { filepathTemplate = instanceStorageSetting.FilepathTemplate } internalPath := filepathTemplate if !strings.Contains(internalPath, "{filename}") { internalPath = filepath.Join(internalPath, "{filename}") } internalPath = replaceFilenameWithPathTemplate(internalPath, create.Filename) internalPath = filepath.ToSlash(internalPath) // Ensure the directory exists. osPath := filepath.FromSlash(internalPath) if !filepath.IsAbs(osPath) { osPath = filepath.Join(profile.Data, osPath) } dir := filepath.Dir(osPath) if err = os.MkdirAll(dir, os.ModePerm); err != nil { return errors.Wrap(err, "Failed to create directory") } dst, err := os.Create(osPath) if err != nil { return errors.Wrap(err, "Failed to create file") } defer dst.Close() // Write the blob to the file. if err := os.WriteFile(osPath, create.Blob, 0644); err != nil { return errors.Wrap(err, "Failed to write file") } create.Reference = internalPath create.Blob = nil create.StorageType = storepb.AttachmentStorageType_LOCAL } else if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_S3 { s3Config := instanceStorageSetting.S3Config if s3Config == nil { return errors.Errorf("No activated external storage found") } s3Client, err := s3.NewClient(ctx, s3Config) if err != nil { return errors.Wrap(err, "Failed to create s3 client") } filepathTemplate := instanceStorageSetting.FilepathTemplate if !strings.Contains(filepathTemplate, "{filename}") { filepathTemplate = filepath.Join(filepathTemplate, "{filename}") } filepathTemplate = replaceFilenameWithPathTemplate(filepathTemplate, create.Filename) key, err := s3Client.UploadObject(ctx, filepathTemplate, create.Type, bytes.NewReader(create.Blob)) if err != nil { return errors.Wrap(err, "Failed to upload via s3 client") } presignURL, err := s3Client.PresignGetObject(ctx, key) if err != nil { return errors.Wrap(err, "Failed to presign via s3 client") } create.Reference = presignURL create.Blob = nil create.StorageType = storepb.AttachmentStorageType_S3 create.Payload = &storepb.AttachmentPayload{ Payload: &storepb.AttachmentPayload_S3Object_{ S3Object: &storepb.AttachmentPayload_S3Object{ S3Config: s3Config, Key: key, LastPresignedTime: timestamppb.New(time.Now()), }, }, } } return nil } func (s *APIV1Service) GetAttachmentBlob(attachment *store.Attachment) ([]byte, error) { // For local storage, read the file from the local disk. if attachment.StorageType == storepb.AttachmentStorageType_LOCAL { attachmentPath := filepath.FromSlash(attachment.Reference) if !filepath.IsAbs(attachmentPath) { attachmentPath = filepath.Join(s.Profile.Data, attachmentPath) } file, err := os.Open(attachmentPath) if err != nil { if os.IsNotExist(err) { return nil, errors.Wrap(err, "file not found") } return nil, errors.Wrap(err, "failed to open the file") } defer file.Close() blob, err := io.ReadAll(file) if err != nil { return nil, errors.Wrap(err, "failed to read the file") } return blob, nil } // For S3 storage, download the file from S3. if attachment.StorageType == storepb.AttachmentStorageType_S3 { if attachment.Payload == nil { return nil, errors.New("attachment payload is missing") } s3Object := attachment.Payload.GetS3Object() if s3Object == nil { return nil, errors.New("S3 object payload is missing") } if s3Object.S3Config == nil { return nil, errors.New("S3 config is missing") } if s3Object.Key == "" { return nil, errors.New("S3 object key is missing") } s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config) if err != nil { return nil, errors.Wrap(err, "failed to create S3 client") } blob, err := s3Client.GetObject(context.Background(), s3Object.Key) if err != nil { return nil, errors.Wrap(err, "failed to get object from S3") } return blob, nil } // For database storage, return the blob from the database. return attachment.Blob, nil } var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`) func replaceFilenameWithPathTemplate(path, filename string) string { t := time.Now() path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string { switch s { case "{filename}": return filename case "{timestamp}": return fmt.Sprintf("%d", t.Unix()) case "{year}": return fmt.Sprintf("%d", t.Year()) case "{month}": return fmt.Sprintf("%02d", t.Month()) case "{day}": return fmt.Sprintf("%02d", t.Day()) case "{hour}": return fmt.Sprintf("%02d", t.Hour()) case "{minute}": return fmt.Sprintf("%02d", t.Minute()) case "{second}": return fmt.Sprintf("%02d", t.Second()) case "{uuid}": return util.GenUUID() default: return s } }) return path } func validateFilename(filename string) bool { // Reject path traversal attempts and make sure no additional directories are created if !filepath.IsLocal(filename) || strings.ContainsAny(filename, "/\\") { return false } // Reject filenames starting or ending with spaces or periods if strings.HasPrefix(filename, " ") || strings.HasSuffix(filename, " ") || strings.HasPrefix(filename, ".") || strings.HasSuffix(filename, ".") { return false } return true }