diff --git a/.dockerignore b/.dockerignore index a0ae8ea54..d465134d4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,13 @@ web/node_modules +web/dist +.git +.github +build/ +tmp/ +memos +*.md +.gitignore +.golangci.yaml +.dockerignore +docs/ +.DS_Store \ No newline at end of file diff --git a/.github/workflows/backend-tests.yml b/.github/workflows/backend-tests.yml index a79e2c85f..4ba0fb980 100644 --- a/.github/workflows/backend-tests.yml +++ b/.github/workflows/backend-tests.yml @@ -11,37 +11,79 @@ on: - "go.sum" - "**.go" +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: go-static-checks: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-go@v6 with: go-version: 1.25 - check-latest: true cache: true + cache-dependency-path: go.sum + - name: Verify go.mod is tidy run: | go mod tidy -go=1.25 git diff --exit-code + - name: golangci-lint - uses: golangci/golangci-lint-action@v8 + uses: golangci/golangci-lint-action@v9 with: version: v2.4.0 - args: --verbose --timeout=3m - skip-cache: true + args: --timeout=3m go-tests: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + test-group: + - store + - server + - plugin + - other steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-go@v6 with: go-version: 1.25 - check-latest: true cache: true - - name: Run all tests - run: go test -v ./... | tee test.log; exit ${PIPESTATUS[0]} - - name: Pretty print tests running time - run: grep --color=never -e '--- PASS:' -e '--- FAIL:' test.log | sed 's/[:()]//g' | awk '{print $2,$3,$4}' | sort -t' ' -nk3 -r | awk '{sum += $3; print $1,$2,$3,sum"s"}' + cache-dependency-path: go.sum + + - name: Run tests - ${{ matrix.test-group }} + run: | + case "${{ matrix.test-group }}" in + store) + # Run store tests for all drivers (sqlite, mysql, postgres) + # The TestMain in store/test runs all drivers when DRIVER is not set + # Note: We run without -race for container tests due to testcontainers race issues + go test -v -coverprofile=coverage.out -covermode=atomic ./store/... + ;; + server) + go test -v -race -coverprofile=coverage.out -covermode=atomic ./server/... + ;; + plugin) + go test -v -race -coverprofile=coverage.out -covermode=atomic ./plugin/... + ;; + other) + go test -v -race -coverprofile=coverage.out -covermode=atomic \ + ./cmd/... ./internal/... ./proto/... + ;; + esac + env: + DRIVER: ${{ matrix.test-group == 'store' && '' || 'sqlite' }} + + - name: Upload coverage + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + uses: codecov/codecov-action@v5 + with: + files: ./coverage.out + flags: ${{ matrix.test-group }} + fail_ci_if_error: false diff --git a/.github/workflows/build-and-push-canary-image.yml b/.github/workflows/build-and-push-canary-image.yml index 9908bc412..1b8c0b980 100644 --- a/.github/workflows/build-and-push-canary-image.yml +++ b/.github/workflows/build-and-push-canary-image.yml @@ -4,36 +4,128 @@ on: push: branches: [main] -env: - DOCKER_PLATFORMS: | - linux/amd64 - linux/arm64 - concurrency: group: ${{ github.workflow }}-${{ github.repository }} cancel-in-progress: true jobs: - build-and-push-canary-image: + build-frontend: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - uses: pnpm/action-setup@v4.2.0 + with: + version: 10 + - uses: actions/setup-node@v6 + with: + node-version: "22" + cache: pnpm + cache-dependency-path: "web/pnpm-lock.yaml" + - name: Get pnpm store directory + id: pnpm-cache + shell: bash + run: echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT + - name: Setup pnpm cache + uses: actions/cache@v5 + with: + path: ${{ steps.pnpm-cache.outputs.STORE_PATH }} + key: ${{ runner.os }}-pnpm-store-${{ hashFiles('web/pnpm-lock.yaml') }} + restore-keys: ${{ runner.os }}-pnpm-store- + - run: pnpm install --frozen-lockfile + working-directory: web + - name: Run frontend build + run: pnpm release + working-directory: web + + - name: Upload frontend artifacts + uses: actions/upload-artifact@v6 + with: + name: frontend-dist + path: server/router/frontend/dist + retention-days: 1 + + build-push: + needs: build-frontend + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + strategy: + fail-fast: false + matrix: + platform: + - linux/amd64 + - linux/arm64 + steps: + - uses: actions/checkout@v6 + + - name: Download frontend artifacts + uses: actions/download-artifact@v7 + with: + name: frontend-dist + path: server/router/frontend/dist + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_HUB_USERNAME }} + password: ${{ secrets.DOCKER_HUB_TOKEN }} + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ github.token }} + + - name: Build and push by digest + id: build + uses: docker/build-push-action@v6 + with: + context: . + file: ./scripts/Dockerfile + platforms: ${{ matrix.platform }} + cache-from: type=gha,scope=build-${{ matrix.platform }} + cache-to: type=gha,mode=max,scope=build-${{ matrix.platform }} + outputs: type=image,name=neosmemo/memos,push-by-digest=true,name-canonical=true,push=true + + - name: Export digest + run: | + mkdir -p /tmp/digests + digest="${{ steps.build.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + uses: actions/upload-artifact@v6 + with: + name: digests-${{ strategy.job-index }} + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + + merge: + needs: build-push runs-on: ubuntu-latest permissions: contents: read packages: write steps: - - uses: actions/checkout@v5 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + - name: Download digests + uses: actions/download-artifact@v7 with: - platforms: ${{ env.DOCKER_PLATFORMS }} + pattern: digests-* + merge-multiple: true + path: /tmp/digests - name: Set up Docker Buildx - id: buildx uses: docker/setup-buildx-action@v3 - with: - version: latest - install: true - platforms: ${{ env.DOCKER_PLATFORMS }} - name: Docker meta id: meta @@ -60,32 +152,15 @@ jobs: username: ${{ github.actor }} password: ${{ github.token }} - # Frontend build. - - uses: pnpm/action-setup@v4.1.0 - with: - version: 10 - - uses: actions/setup-node@v5 - with: - node-version: "22" - cache: pnpm - cache-dependency-path: "web/pnpm-lock.yaml" - - run: pnpm install - working-directory: web - - name: Run frontend build - run: pnpm release - working-directory: web + - name: Create manifest list and push + working-directory: /tmp/digests + run: | + docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ + $(printf 'neosmemo/memos@sha256:%s ' *) + env: + DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }} - - name: Build and Push - id: docker_build - uses: docker/build-push-action@v6 - with: - context: . - file: ./scripts/Dockerfile - platforms: ${{ env.DOCKER_PLATFORMS }} - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max - build-args: | - BUILDKIT_INLINE_CACHE=1 + - name: Inspect images + run: | + docker buildx imagetools inspect neosmemo/memos:canary + docker buildx imagetools inspect ghcr.io/usememos/memos:canary diff --git a/.github/workflows/build-and-push-stable-image.yml b/.github/workflows/build-and-push-stable-image.yml index 52a94a920..db6ad0556 100644 --- a/.github/workflows/build-and-push-stable-image.yml +++ b/.github/workflows/build-and-push-stable-image.yml @@ -7,41 +7,84 @@ on: tags: - "v*.*.*" -env: - DOCKER_PLATFORMS: | - linux/amd64 - linux/arm/v7 - linux/arm64 - jobs: - build-and-push-image: + prepare: + runs-on: ubuntu-latest + outputs: + version: ${{ steps.version.outputs.version }} + steps: + - name: Extract version + id: version + run: | + if [[ "$GITHUB_REF_TYPE" == "tag" ]]; then + echo "version=${GITHUB_REF_NAME#v}" >> $GITHUB_OUTPUT + else + echo "version=${GITHUB_REF_NAME#release/}" >> $GITHUB_OUTPUT + fi + + build-frontend: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - uses: pnpm/action-setup@v4.2.0 + with: + version: 10 + - uses: actions/setup-node@v6 + with: + node-version: "22" + cache: pnpm + cache-dependency-path: "web/pnpm-lock.yaml" + - name: Get pnpm store directory + id: pnpm-cache + shell: bash + run: echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT + - name: Setup pnpm cache + uses: actions/cache@v5 + with: + path: ${{ steps.pnpm-cache.outputs.STORE_PATH }} + key: ${{ runner.os }}-pnpm-store-${{ hashFiles('web/pnpm-lock.yaml') }} + restore-keys: ${{ runner.os }}-pnpm-store- + - run: pnpm install --frozen-lockfile + working-directory: web + - name: Run frontend build + run: pnpm release + working-directory: web + + - name: Upload frontend artifacts + uses: actions/upload-artifact@v6 + with: + name: frontend-dist + path: server/router/frontend/dist + retention-days: 1 + + build-push: + needs: [prepare, build-frontend] runs-on: ubuntu-latest permissions: contents: read packages: write + strategy: + fail-fast: false + matrix: + platform: + - linux/amd64 + - linux/arm/v7 + - linux/arm64 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 + + - name: Download frontend artifacts + uses: actions/download-artifact@v7 + with: + name: frontend-dist + path: server/router/frontend/dist - name: Set up QEMU uses: docker/setup-qemu-action@v3 - with: - platforms: ${{ env.DOCKER_PLATFORMS }} - name: Set up Docker Buildx - id: buildx uses: docker/setup-buildx-action@v3 - with: - version: latest - install: true - platforms: ${{ env.DOCKER_PLATFORMS }} - - - name: Extract version - run: | - if [[ "$GITHUB_REF_TYPE" == "tag" ]]; then - echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV - else - echo "VERSION=${GITHUB_REF_NAME#release/}" >> $GITHUB_ENV - fi - name: Login to Docker Hub uses: docker/login-action@v3 @@ -56,6 +99,48 @@ jobs: username: ${{ github.actor }} password: ${{ github.token }} + - name: Build and push by digest + id: build + uses: docker/build-push-action@v6 + with: + context: . + file: ./scripts/Dockerfile + platforms: ${{ matrix.platform }} + cache-from: type=gha,scope=build-${{ matrix.platform }} + cache-to: type=gha,mode=max,scope=build-${{ matrix.platform }} + outputs: type=image,name=neosmemo/memos,push-by-digest=true,name-canonical=true,push=true + + - name: Export digest + run: | + mkdir -p /tmp/digests + digest="${{ steps.build.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + uses: actions/upload-artifact@v6 + with: + name: digests-${{ strategy.job-index }} + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + + merge: + needs: [prepare, build-push] + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + steps: + - name: Download digests + uses: actions/download-artifact@v7 + with: + pattern: digests-* + merge-multiple: true + path: /tmp/digests + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Docker meta id: meta uses: docker/metadata-action@v5 @@ -64,40 +149,36 @@ jobs: neosmemo/memos ghcr.io/usememos/memos tags: | - type=semver,pattern={{version}},value=${{ env.VERSION }} - type=semver,pattern={{major}}.{{minor}},value=${{ env.VERSION }} + type=semver,pattern={{version}},value=${{ needs.prepare.outputs.version }} + type=semver,pattern={{major}}.{{minor}},value=${{ needs.prepare.outputs.version }} type=raw,value=stable flavor: | latest=false labels: | - org.opencontainers.image.version=${{ env.VERSION }} + org.opencontainers.image.version=${{ needs.prepare.outputs.version }} - # Frontend build. - - uses: pnpm/action-setup@v4.1.0 + - name: Login to Docker Hub + uses: docker/login-action@v3 with: - version: 10 - - uses: actions/setup-node@v5 - with: - node-version: "22" - cache: pnpm - cache-dependency-path: "web/pnpm-lock.yaml" - - run: pnpm install - working-directory: web - - name: Run frontend build - run: pnpm release - working-directory: web + username: ${{ secrets.DOCKER_HUB_USERNAME }} + password: ${{ secrets.DOCKER_HUB_TOKEN }} - - name: Build and Push - id: docker_build - uses: docker/build-push-action@v6 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 with: - context: . - file: ./scripts/Dockerfile - platforms: ${{ env.DOCKER_PLATFORMS }} - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max - build-args: | - BUILDKIT_INLINE_CACHE=1 + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ github.token }} + + - name: Create manifest list and push + working-directory: /tmp/digests + run: | + docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ + $(printf 'neosmemo/memos@sha256:%s ' *) + env: + DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }} + + - name: Inspect images + run: | + docker buildx imagetools inspect neosmemo/memos:stable + docker buildx imagetools inspect ghcr.io/usememos/memos:stable diff --git a/.github/workflows/frontend-tests.yml b/.github/workflows/frontend-tests.yml index 2bb6e59b7..04e807949 100644 --- a/.github/workflows/frontend-tests.yml +++ b/.github/workflows/frontend-tests.yml @@ -13,11 +13,11 @@ jobs: static-checks: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 - - uses: pnpm/action-setup@v4.1.0 + - uses: actions/checkout@v6 + - uses: pnpm/action-setup@v4.2.0 with: version: 9 - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: "20" cache: pnpm @@ -31,11 +31,11 @@ jobs: frontend-build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 - - uses: pnpm/action-setup@v4.1.0 + - uses: actions/checkout@v6 + - uses: pnpm/action-setup@v4.2.0 with: version: 9 - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: "20" cache: pnpm diff --git a/.github/workflows/proto-linter.yml b/.github/workflows/proto-linter.yml index d183fb31d..0bb8e977e 100644 --- a/.github/workflows/proto-linter.yml +++ b/.github/workflows/proto-linter.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup buf diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index b1bf2c238..ad330a6f7 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -11,7 +11,7 @@ jobs: issues: write steps: - - uses: actions/stale@v10.0.0 + - uses: actions/stale@v10.1.1 with: days-before-issue-stale: 14 days-before-issue-close: 7 diff --git a/AGENTS.md b/AGENTS.md index fe20cce0d..3ed1a6170 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -165,7 +165,7 @@ type Driver interface { 4. Demo mode: Seed with demo data **Schema Versioning:** -- Stored in `instance_setting` table (key: `bb.general.version`) +- Stored in `system_setting` table - Format: `major.minor.patch` - Migration files: `store/migration/{driver}/{version}/NN__description.sql` - See: `store/migrator.go:21-414` @@ -191,7 +191,7 @@ cd proto && buf generate ```bash # Start dev server -go run ./cmd/memos --mode dev --port 8081 +go run ./cmd/memos --port 8081 # Run all tests go test ./... @@ -458,7 +458,7 @@ cd web && pnpm lint | Variable | Default | Description | |----------|----------|-------------| -| `MEMOS_MODE` | `dev` | Mode: `dev`, `prod`, `demo` | +| `MEMOS_DEMO` | `false` | Enable demo mode | | `MEMOS_PORT` | `8081` | HTTP port | | `MEMOS_ADDR` | `` | Bind address (empty = all) | | `MEMOS_DATA` | `~/.memos` | Data directory | @@ -503,13 +503,6 @@ cd web && pnpm lint ## Common Tasks -### Debugging Database Issues - -1. Check connection string in logs -2. Verify `store/db/{driver}/migration/` files exist -3. Check schema version: `SELECT * FROM instance_setting WHERE key = 'bb.general.version'` -4. Test migration: `go test ./store/test/... -v` - ### Debugging API Issues 1. Check Connect interceptor logs: `server/router/api/v1/connect_interceptors.go:79-105` @@ -571,7 +564,7 @@ Each plugin has its own README with usage examples. ## Security Notes -- JWT secrets must be kept secret (`MEMOS_MODE=prod` generates random secret) +- JWT secrets must be kept secret (generated on first run in production mode) - Personal Access Tokens stored as SHA-256 hashes in database - CSRF protection via SameSite cookies - CORS enabled for all origins (configure for production) diff --git a/README.md b/README.md index dbc8ee2b0..f36e6c9f2 100644 --- a/README.md +++ b/README.md @@ -22,10 +22,10 @@ An open-source, self-hosted note-taking service. Your thoughts, your data, your --- -[**LambdaTest** - Cross-browser testing cloud](https://www.lambdatest.com/?utm_source=memos&utm_medium=sponsor) +[**TestMu AI** - The world’s first full-stack Agentic AI Quality Engineering platform](https://www.testmu.ai/?utm_source=memos&utm_medium=sponsor) - - LambdaTest - Cross-browser testing cloud + + TestMu AI ## Overview diff --git a/cmd/memos/main.go b/cmd/memos/main.go index 48dad202d..bb182b8f7 100644 --- a/cmd/memos/main.go +++ b/cmd/memos/main.go @@ -25,7 +25,7 @@ var ( Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`, Run: func(_ *cobra.Command, _ []string) { instanceProfile := &profile.Profile{ - Mode: viper.GetString("mode"), + Demo: viper.GetBool("demo"), Addr: viper.GetString("addr"), Port: viper.GetInt("port"), UNIXSock: viper.GetString("unix-sock"), @@ -33,10 +33,12 @@ var ( Driver: viper.GetString("driver"), DSN: viper.GetString("dsn"), InstanceURL: viper.GetString("instance-url"), - Version: version.GetCurrentVersion(viper.GetString("mode")), } + instanceProfile.Version = version.GetCurrentVersion() + if err := instanceProfile.Validate(); err != nil { - panic(err) + slog.Error("failed to validate profile", "error", err) + return } ctx, cancel := context.WithCancel(context.Background()) @@ -71,6 +73,7 @@ var ( if err != http.ErrServerClosed { slog.Error("failed to start server", "error", err) cancel() + return } } @@ -89,11 +92,11 @@ var ( ) func init() { - viper.SetDefault("mode", "dev") + viper.SetDefault("demo", false) viper.SetDefault("driver", "sqlite") viper.SetDefault("port", 8081) - rootCmd.PersistentFlags().String("mode", "dev", `mode of server, can be "prod" or "dev" or "demo"`) + rootCmd.PersistentFlags().Bool("demo", false, "enable demo mode") rootCmd.PersistentFlags().String("addr", "", "address of server") rootCmd.PersistentFlags().Int("port", 8081, "port of server") rootCmd.PersistentFlags().String("unix-sock", "", "path to the unix socket, overrides --addr and --port") @@ -102,7 +105,7 @@ func init() { rootCmd.PersistentFlags().String("dsn", "", "database source name(aka. DSN)") rootCmd.PersistentFlags().String("instance-url", "", "the url of your memos instance") - if err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")); err != nil { + if err := viper.BindPFlag("demo", rootCmd.PersistentFlags().Lookup("demo")); err != nil { panic(err) } if err := viper.BindPFlag("addr", rootCmd.PersistentFlags().Lookup("addr")); err != nil { @@ -129,15 +132,12 @@ func init() { viper.SetEnvPrefix("memos") viper.AutomaticEnv() - if err := viper.BindEnv("instance-url", "MEMOS_INSTANCE_URL"); err != nil { - panic(err) - } } func printGreetings(profile *profile.Profile) { fmt.Printf("Memos %s started successfully!\n", profile.Version) - if profile.IsDev() { + if profile.Demo { fmt.Fprint(os.Stderr, "Development mode is enabled\n") if profile.DSN != "" { fmt.Fprintf(os.Stderr, "Database: %s\n", profile.DSN) @@ -147,7 +147,6 @@ func printGreetings(profile *profile.Profile) { // Server information fmt.Printf("Data directory: %s\n", profile.Data) fmt.Printf("Database driver: %s\n", profile.Driver) - fmt.Printf("Mode: %s\n", profile.Mode) // Connection information if len(profile.UNIXSock) == 0 { @@ -170,6 +169,6 @@ func printGreetings(profile *profile.Profile) { func main() { if err := rootCmd.Execute(); err != nil { - panic(err) + os.Exit(1) } } diff --git a/go.mod b/go.mod index cd7f5b5e1..fc85404e7 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.18.16 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.4 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 + github.com/docker/docker v28.5.1+incompatible github.com/go-sql-driver/mysql v1.9.3 github.com/google/cel-go v0.26.1 github.com/google/uuid v1.6.0 @@ -50,7 +51,6 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/distribution/reference v0.6.0 // indirect - github.com/docker/docker v28.5.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect diff --git a/internal/profile/profile.go b/internal/profile/profile.go index 8d551d669..30579a313 100644 --- a/internal/profile/profile.go +++ b/internal/profile/profile.go @@ -13,8 +13,8 @@ import ( // Profile is the configuration to start main server. type Profile struct { - // Mode can be "prod" or "dev" or "demo" - Mode string + // Demo indicates if the server is in demo mode + Demo bool // Addr is the binding address for server Addr string // Port is the binding port for server @@ -34,15 +34,12 @@ type Profile struct { InstanceURL string } -func (p *Profile) IsDev() bool { - return p.Mode != "prod" -} - func checkDataDir(dataDir string) (string, error) { // Convert to absolute path if relative path is supplied. if !filepath.IsAbs(dataDir) { - relativeDir := filepath.Join(filepath.Dir(os.Args[0]), dataDir) - absDir, err := filepath.Abs(relativeDir) + // Use current working directory, not the binary's directory + // This ensures we use the actual working directory where the process runs + absDir, err := filepath.Abs(dataDir) if err != nil { return "", err } @@ -58,21 +55,35 @@ func checkDataDir(dataDir string) (string, error) { } func (p *Profile) Validate() error { - if p.Mode != "demo" && p.Mode != "dev" && p.Mode != "prod" { - p.Mode = "demo" - } - - if p.Mode == "prod" && p.Data == "" { + // Set default data directory if not specified + if p.Data == "" { if runtime.GOOS == "windows" { p.Data = filepath.Join(os.Getenv("ProgramData"), "memos") - if _, err := os.Stat(p.Data); os.IsNotExist(err) { - if err := os.MkdirAll(p.Data, 0770); err != nil { - slog.Error("failed to create data directory", slog.String("data", p.Data), slog.String("error", err.Error())) - return err - } - } } else { - p.Data = "/var/opt/memos" + // On Linux/macOS, check if /var/opt/memos exists and is writable (Docker scenario) + if info, err := os.Stat("/var/opt/memos"); err == nil && info.IsDir() { + // Check if we can write to this directory + testFile := filepath.Join("/var/opt/memos", ".write-test") + if err := os.WriteFile(testFile, []byte("test"), 0600); err == nil { + os.Remove(testFile) + p.Data = "/var/opt/memos" + } else { + // /var/opt/memos exists but is not writable, use current directory + slog.Warn("/var/opt/memos is not writable, using current directory") + p.Data = "." + } + } else { + // /var/opt/memos doesn't exist, use current directory (local development) + p.Data = "." + } + } + } + + // Create data directory if it doesn't exist + if _, err := os.Stat(p.Data); os.IsNotExist(err) { + if err := os.MkdirAll(p.Data, 0770); err != nil { + slog.Error("failed to create data directory", slog.String("data", p.Data), slog.String("error", err.Error())) + return err } } @@ -84,7 +95,11 @@ func (p *Profile) Validate() error { p.Data = dataDir if p.Driver == "sqlite" && p.DSN == "" { - dbFile := fmt.Sprintf("memos_%s.db", p.Mode) + mode := "prod" + if p.Demo { + mode = "demo" + } + dbFile := fmt.Sprintf("memos_%s.db", mode) p.DSN = filepath.Join(dataDir, dbFile) } diff --git a/internal/version/version.go b/internal/version/version.go index 2fdc62aef..af1bbc2ec 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -9,15 +9,9 @@ import ( // Version is the service current released version. // Semantic versioning: https://semver.org/ -var Version = "0.25.3" +var Version = "0.26.0" -// DevVersion is the service current development version. -var DevVersion = "0.25.3" - -func GetCurrentVersion(mode string) string { - if mode == "dev" || mode == "demo" { - return DevVersion - } +func GetCurrentVersion() string { return Version } diff --git a/plugin/email/email_test.go b/plugin/email/email_test.go index 7927512ec..f3eebee09 100644 --- a/plugin/email/email_test.go +++ b/plugin/email/email_test.go @@ -1,11 +1,11 @@ package email import ( - "sync" "testing" "time" "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" ) func TestSend(t *testing.T) { @@ -106,34 +106,22 @@ func TestSendAsyncConcurrent(t *testing.T) { FromEmail: "test@example.com", } - // Send multiple emails concurrently - var wg sync.WaitGroup + g := errgroup.Group{} count := 5 for i := 0; i < count; i++ { - wg.Add(1) - go func() { - defer wg.Done() + g.Go(func() error { message := &Message{ To: []string{"recipient@example.com"}, Subject: "Concurrent Test", Body: "Test body", } SendAsync(config, message) - }() + return nil + }) } - // Should complete without deadlock - done := make(chan bool) - go func() { - wg.Wait() - done <- true - }() - - select { - case <-done: - // Success - case <-time.After(1 * time.Second): - t.Fatal("SendAsync calls did not complete in time") + if err := g.Wait(); err != nil { + t.Fatalf("SendAsync calls failed: %v", err) } } diff --git a/plugin/filter/ir.go b/plugin/filter/ir.go index cfdefc9d4..10cb13df1 100644 --- a/plugin/filter/ir.go +++ b/plugin/filter/ir.go @@ -114,3 +114,46 @@ type FunctionValue struct { } func (*FunctionValue) isValueExpr() {} + +// ListComprehensionCondition represents CEL macros like exists(), all(), filter(). +type ListComprehensionCondition struct { + Kind ComprehensionKind + Field string // The list field to iterate over (e.g., "tags") + IterVar string // The iteration variable name (e.g., "t") + Predicate PredicateExpr // The predicate to evaluate for each element +} + +func (*ListComprehensionCondition) isCondition() {} + +// ComprehensionKind enumerates the types of list comprehensions. +type ComprehensionKind string + +const ( + ComprehensionExists ComprehensionKind = "exists" +) + +// PredicateExpr represents predicates used in comprehensions. +type PredicateExpr interface { + isPredicateExpr() +} + +// StartsWithPredicate represents t.startsWith("prefix"). +type StartsWithPredicate struct { + Prefix string +} + +func (*StartsWithPredicate) isPredicateExpr() {} + +// EndsWithPredicate represents t.endsWith("suffix"). +type EndsWithPredicate struct { + Suffix string +} + +func (*EndsWithPredicate) isPredicateExpr() {} + +// ContainsPredicate represents t.contains("substring"). +type ContainsPredicate struct { + Substring string +} + +func (*ContainsPredicate) isPredicateExpr() {} diff --git a/plugin/filter/parser.go b/plugin/filter/parser.go index 76bb1630b..36e52d1db 100644 --- a/plugin/filter/parser.go +++ b/plugin/filter/parser.go @@ -36,6 +36,8 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) { return nil, errors.Errorf("identifier %q is not boolean", name) } return &FieldPredicateCondition{Field: name}, nil + case *exprv1.Expr_ComprehensionExpr: + return buildComprehensionCondition(v.ComprehensionExpr, schema) default: return nil, errors.New("unsupported top-level expression") } @@ -415,3 +417,170 @@ func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) { func timeNowUnix() int64 { return time.Now().Unix() } + +// buildComprehensionCondition handles CEL comprehension expressions (exists, all, etc.). +func buildComprehensionCondition(comp *exprv1.Expr_Comprehension, schema Schema) (Condition, error) { + // Determine the comprehension kind by examining the loop initialization and step + kind, err := detectComprehensionKind(comp) + if err != nil { + return nil, err + } + + // Get the field being iterated over + iterRangeIdent := comp.IterRange.GetIdentExpr() + if iterRangeIdent == nil { + return nil, errors.New("comprehension range must be a field identifier") + } + fieldName := iterRangeIdent.GetName() + + // Validate the field + field, ok := schema.Field(fieldName) + if !ok { + return nil, errors.Errorf("unknown field %q in comprehension", fieldName) + } + if field.Kind != FieldKindJSONList { + return nil, errors.Errorf("field %q does not support comprehension (must be a list)", fieldName) + } + + // Extract the predicate from the loop step + predicate, err := extractPredicate(comp, schema) + if err != nil { + return nil, err + } + + return &ListComprehensionCondition{ + Kind: kind, + Field: fieldName, + IterVar: comp.IterVar, + Predicate: predicate, + }, nil +} + +// detectComprehensionKind determines if this is an exists() macro. +// Only exists() is currently supported. +func detectComprehensionKind(comp *exprv1.Expr_Comprehension) (ComprehensionKind, error) { + // Check the accumulator initialization + accuInit := comp.AccuInit.GetConstExpr() + if accuInit == nil { + return "", errors.New("comprehension accumulator must be initialized with a constant") + } + + // exists() starts with false and uses OR (||) in loop step + if !accuInit.GetBoolValue() { + if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_||_" { + return ComprehensionExists, nil + } + } + + // all() starts with true and uses AND (&&) - not supported + if accuInit.GetBoolValue() { + if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_&&_" { + return "", errors.New("all() comprehension is not supported; use exists() instead") + } + } + + return "", errors.New("unsupported comprehension type; only exists() is supported") +} + +// extractPredicate extracts the predicate expression from the comprehension loop step. +func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, error) { + // The loop step is: @result || predicate(t) for exists + // or: @result && predicate(t) for all + step := comp.LoopStep.GetCallExpr() + if step == nil { + return nil, errors.New("comprehension loop step must be a call expression") + } + + if len(step.Args) != 2 { + return nil, errors.New("comprehension loop step must have two arguments") + } + + // The predicate is the second argument + predicateExpr := step.Args[1] + predicateCall := predicateExpr.GetCallExpr() + if predicateCall == nil { + return nil, errors.New("comprehension predicate must be a function call") + } + + // Handle different predicate functions + switch predicateCall.Function { + case "startsWith": + return buildStartsWithPredicate(predicateCall, comp.IterVar) + case "endsWith": + return buildEndsWithPredicate(predicateCall, comp.IterVar) + case "contains": + return buildContainsPredicate(predicateCall, comp.IterVar) + default: + return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function) + } +} + +// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix"). +func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) { + // Verify the target is the iteration variable + if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar { + return nil, errors.Errorf("startsWith target must be the iteration variable %q", iterVar) + } + + if len(call.Args) != 1 { + return nil, errors.New("startsWith expects exactly one argument") + } + + prefix, err := getConstValue(call.Args[0]) + if err != nil { + return nil, errors.Wrap(err, "startsWith argument must be a constant string") + } + + prefixStr, ok := prefix.(string) + if !ok { + return nil, errors.New("startsWith argument must be a string") + } + + return &StartsWithPredicate{Prefix: prefixStr}, nil +} + +// buildEndsWithPredicate extracts the pattern from t.endsWith("suffix"). +func buildEndsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) { + if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar { + return nil, errors.Errorf("endsWith target must be the iteration variable %q", iterVar) + } + + if len(call.Args) != 1 { + return nil, errors.New("endsWith expects exactly one argument") + } + + suffix, err := getConstValue(call.Args[0]) + if err != nil { + return nil, errors.Wrap(err, "endsWith argument must be a constant string") + } + + suffixStr, ok := suffix.(string) + if !ok { + return nil, errors.New("endsWith argument must be a string") + } + + return &EndsWithPredicate{Suffix: suffixStr}, nil +} + +// buildContainsPredicate extracts the pattern from t.contains("substring"). +func buildContainsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) { + if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar { + return nil, errors.Errorf("contains target must be the iteration variable %q", iterVar) + } + + if len(call.Args) != 1 { + return nil, errors.New("contains expects exactly one argument") + } + + substring, err := getConstValue(call.Args[0]) + if err != nil { + return nil, errors.Wrap(err, "contains argument must be a constant string") + } + + substringStr, ok := substring.(string) + if !ok { + return nil, errors.New("contains argument must be a string") + } + + return &ContainsPredicate{Substring: substringStr}, nil +} diff --git a/plugin/filter/render.go b/plugin/filter/render.go index 9d3bc60af..c00de417e 100644 --- a/plugin/filter/render.go +++ b/plugin/filter/render.go @@ -74,6 +74,8 @@ func (r *renderer) renderCondition(cond Condition) (renderResult, error) { return r.renderElementInCondition(c) case *ContainsCondition: return r.renderContainsCondition(c) + case *ListComprehensionCondition: + return r.renderListComprehension(c) case *ConstantCondition: if c.Value { return renderResult{trivial: true}, nil @@ -461,13 +463,108 @@ func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResul } } +func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (renderResult, error) { + field, ok := r.schema.Field(cond.Field) + if !ok { + return renderResult{}, errors.Errorf("unknown field %q", cond.Field) + } + + if field.Kind != FieldKindJSONList { + return renderResult{}, errors.Errorf("field %q is not a JSON list", cond.Field) + } + + // Render based on predicate type + switch pred := cond.Predicate.(type) { + case *StartsWithPredicate: + return r.renderTagStartsWith(field, pred.Prefix, cond.Kind) + case *EndsWithPredicate: + return r.renderTagEndsWith(field, pred.Suffix, cond.Kind) + case *ContainsPredicate: + return r.renderTagContains(field, pred.Substring, cond.Kind) + default: + return renderResult{}, errors.Errorf("unsupported predicate type %T in comprehension", pred) + } +} + +// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")). +func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) { + arrayExpr := jsonArrayExpr(r.dialect, field) + + switch r.dialect { + case DialectSQLite, DialectMySQL: + // Match exact tag or tags with this prefix (hierarchical support) + exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, prefix)) + prefixMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s%%`, prefix)) + condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil + + case DialectPostgres: + // Use PostgreSQL's powerful JSON operators + exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, prefix))) + prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(fmt.Sprintf(`%%"%s%%`, prefix))) + condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil + + default: + return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect) + } +} + +// renderTagEndsWith generates SQL for tags.exists(t, t.endsWith("suffix")). +func (r *renderer) renderTagEndsWith(field Field, suffix string, _ ComprehensionKind) (renderResult, error) { + arrayExpr := jsonArrayExpr(r.dialect, field) + pattern := fmt.Sprintf(`%%%s"%%`, suffix) + + likeExpr := r.buildJSONArrayLike(arrayExpr, pattern) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil +} + +// renderTagContains generates SQL for tags.exists(t, t.contains("substring")). +func (r *renderer) renderTagContains(field Field, substring string, _ ComprehensionKind) (renderResult, error) { + arrayExpr := jsonArrayExpr(r.dialect, field) + pattern := fmt.Sprintf(`%%%s%%`, substring) + + likeExpr := r.buildJSONArrayLike(arrayExpr, pattern) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil +} + +// buildJSONArrayLike builds a LIKE expression for matching within a JSON array. +// Returns the LIKE clause without NULL/empty checks. +func (r *renderer) buildJSONArrayLike(arrayExpr, pattern string) string { + switch r.dialect { + case DialectSQLite, DialectMySQL: + return fmt.Sprintf("%s LIKE %s", arrayExpr, r.addArg(pattern)) + case DialectPostgres: + return fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(pattern)) + default: + return "" + } +} + +// wrapWithNullCheck wraps a condition with NULL and empty array checks. +// This ensures we don't match against NULL or empty JSON arrays. +func (r *renderer) wrapWithNullCheck(arrayExpr, condition string) string { + var nullCheck string + switch r.dialect { + case DialectSQLite: + nullCheck = fmt.Sprintf("%s IS NOT NULL AND %s != '[]'", arrayExpr, arrayExpr) + case DialectMySQL: + nullCheck = fmt.Sprintf("%s IS NOT NULL AND JSON_LENGTH(%s) > 0", arrayExpr, arrayExpr) + case DialectPostgres: + nullCheck = fmt.Sprintf("%s IS NOT NULL AND jsonb_array_length(%s) > 0", arrayExpr, arrayExpr) + default: + return condition + } + return fmt.Sprintf("(%s AND %s)", condition, nullCheck) +} + func (r *renderer) jsonBoolPredicate(field Field) (string, error) { expr := jsonExtractExpr(r.dialect, field) switch r.dialect { case DialectSQLite: return fmt.Sprintf("%s IS TRUE", expr), nil case DialectMySQL: - return fmt.Sprintf("%s = CAST('true' AS JSON)", expr), nil + return fmt.Sprintf("COALESCE(%s, CAST('false' AS JSON)) = CAST('true' AS JSON)", expr), nil case DialectPostgres: return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil default: diff --git a/plugin/filter/schema.go b/plugin/filter/schema.go index c172eb62a..f2f8b0e4a 100644 --- a/plugin/filter/schema.go +++ b/plugin/filter/schema.go @@ -256,7 +256,7 @@ func NewAttachmentSchema() Schema { Name: "filename", Kind: FieldKindScalar, Type: FieldTypeString, - Column: Column{Table: "resource", Name: "filename"}, + Column: Column{Table: "attachment", Name: "filename"}, SupportsContains: true, Expressions: map[DialectName]string{}, }, @@ -264,14 +264,14 @@ func NewAttachmentSchema() Schema { Name: "mime_type", Kind: FieldKindScalar, Type: FieldTypeString, - Column: Column{Table: "resource", Name: "type"}, + Column: Column{Table: "attachment", Name: "type"}, Expressions: map[DialectName]string{}, }, "create_time": { Name: "create_time", Kind: FieldKindScalar, Type: FieldTypeTimestamp, - Column: Column{Table: "resource", Name: "created_ts"}, + Column: Column{Table: "attachment", Name: "created_ts"}, Expressions: map[DialectName]string{ // MySQL stores created_ts as TIMESTAMP, needs conversion to epoch DialectMySQL: "UNIX_TIMESTAMP(%s)", @@ -284,7 +284,7 @@ func NewAttachmentSchema() Schema { Name: "memo_id", Kind: FieldKindScalar, Type: FieldTypeInt, - Column: Column{Table: "resource", Name: "memo_id"}, + Column: Column{Table: "attachment", Name: "memo_id"}, Expressions: map[DialectName]string{}, AllowedComparisonOps: map[ComparisonOperator]bool{ CompareEq: true, diff --git a/proto/api/v1/instance_service.proto b/proto/api/v1/instance_service.proto index ebe9ed2f1..f4ce3d501 100644 --- a/proto/api/v1/instance_service.proto +++ b/proto/api/v1/instance_service.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package memos.api.v1; +import "api/v1/user_service.proto"; import "google/api/annotations.proto"; import "google/api/client.proto"; import "google/api/field_behavior.proto"; @@ -34,18 +35,18 @@ service InstanceService { // Instance profile message containing basic instance information. message InstanceProfile { - // The name of instance owner. - // Format: users/{user} - string owner = 1; - // Version is the current version of instance. string version = 2; - // Mode is the instance mode (e.g. "prod", "dev" or "demo"). - string mode = 3; + // Demo indicates if the instance is in demo mode. + bool demo = 3; // Instance URL is the URL of the instance. string instance_url = 6; + + // The first administrator who set up this instance. + // When null, instance requires initial setup (creating the first admin account). + User admin = 7; } // Request for instance profile. diff --git a/proto/api/v1/memo_service.proto b/proto/api/v1/memo_service.proto index fd3c4ac76..0a26b6011 100644 --- a/proto/api/v1/memo_service.proto +++ b/proto/api/v1/memo_service.proto @@ -173,11 +173,13 @@ message Memo { (google.api.resource_reference) = {type: "memos.api.v1/User"} ]; - // Output only. The creation timestamp. - google.protobuf.Timestamp create_time = 4 [(google.api.field_behavior) = OUTPUT_ONLY]; + // The creation timestamp. + // If not set on creation, the server will set it to the current time. + google.protobuf.Timestamp create_time = 4 [(google.api.field_behavior) = OPTIONAL]; - // Output only. The last update timestamp. - google.protobuf.Timestamp update_time = 5 [(google.api.field_behavior) = OUTPUT_ONLY]; + // The last update timestamp. + // If not set on creation, the server will set it to the current time. + google.protobuf.Timestamp update_time = 5 [(google.api.field_behavior) = OPTIONAL]; // The display timestamp of the memo. google.protobuf.Timestamp display_time = 6 [(google.api.field_behavior) = OPTIONAL]; diff --git a/proto/api/v1/user_service.proto b/proto/api/v1/user_service.proto index 4505451ec..53883acbb 100644 --- a/proto/api/v1/user_service.proto +++ b/proto/api/v1/user_service.proto @@ -203,13 +203,10 @@ message User { // User role enumeration. enum Role { - // Unspecified role. ROLE_UNSPECIFIED = 0; - // Host role with full system access. - HOST = 1; - // Admin role with administrative privileges. + // Admin role with system access. ADMIN = 2; - // Regular user role. + // User role with limited access. USER = 3; } } diff --git a/proto/gen/api/v1/instance_service.pb.go b/proto/gen/api/v1/instance_service.pb.go index c98eb3b88..5be2dd4bd 100644 --- a/proto/gen/api/v1/instance_service.pb.go +++ b/proto/gen/api/v1/instance_service.pb.go @@ -138,15 +138,15 @@ func (InstanceSetting_StorageSetting_StorageType) EnumDescriptor() ([]byte, []in // Instance profile message containing basic instance information. type InstanceProfile struct { state protoimpl.MessageState `protogen:"open.v1"` - // The name of instance owner. - // Format: users/{user} - Owner string `protobuf:"bytes,1,opt,name=owner,proto3" json:"owner,omitempty"` // Version is the current version of instance. Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` - // Mode is the instance mode (e.g. "prod", "dev" or "demo"). - Mode string `protobuf:"bytes,3,opt,name=mode,proto3" json:"mode,omitempty"` + // Demo indicates if the instance is in demo mode. + Demo bool `protobuf:"varint,3,opt,name=demo,proto3" json:"demo,omitempty"` // Instance URL is the URL of the instance. - InstanceUrl string `protobuf:"bytes,6,opt,name=instance_url,json=instanceUrl,proto3" json:"instance_url,omitempty"` + InstanceUrl string `protobuf:"bytes,6,opt,name=instance_url,json=instanceUrl,proto3" json:"instance_url,omitempty"` + // The first administrator who set up this instance. + // When null, instance requires initial setup (creating the first admin account). + Admin *User `protobuf:"bytes,7,opt,name=admin,proto3" json:"admin,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -181,13 +181,6 @@ func (*InstanceProfile) Descriptor() ([]byte, []int) { return file_api_v1_instance_service_proto_rawDescGZIP(), []int{0} } -func (x *InstanceProfile) GetOwner() string { - if x != nil { - return x.Owner - } - return "" -} - func (x *InstanceProfile) GetVersion() string { if x != nil { return x.Version @@ -195,11 +188,11 @@ func (x *InstanceProfile) GetVersion() string { return "" } -func (x *InstanceProfile) GetMode() string { +func (x *InstanceProfile) GetDemo() bool { if x != nil { - return x.Mode + return x.Demo } - return "" + return false } func (x *InstanceProfile) GetInstanceUrl() string { @@ -209,6 +202,13 @@ func (x *InstanceProfile) GetInstanceUrl() string { return "" } +func (x *InstanceProfile) GetAdmin() *User { + if x != nil { + return x.Admin + } + return nil +} + // Request for instance profile. type GetInstanceProfileRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -875,12 +875,12 @@ var File_api_v1_instance_service_proto protoreflect.FileDescriptor const file_api_v1_instance_service_proto_rawDesc = "" + "\n" + - "\x1dapi/v1/instance_service.proto\x12\fmemos.api.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a google/protobuf/field_mask.proto\"x\n" + - "\x0fInstanceProfile\x12\x14\n" + - "\x05owner\x18\x01 \x01(\tR\x05owner\x12\x18\n" + + "\x1dapi/v1/instance_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a google/protobuf/field_mask.proto\"\x8c\x01\n" + + "\x0fInstanceProfile\x12\x18\n" + "\aversion\x18\x02 \x01(\tR\aversion\x12\x12\n" + - "\x04mode\x18\x03 \x01(\tR\x04mode\x12!\n" + - "\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\"\x1b\n" + + "\x04demo\x18\x03 \x01(\bR\x04demo\x12!\n" + + "\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\x12(\n" + + "\x05admin\x18\a \x01(\v2\x12.memos.api.v1.UserR\x05admin\"\x1b\n" + "\x19GetInstanceProfileRequest\"\x99\x0f\n" + "\x0fInstanceSetting\x12\x17\n" + "\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12W\n" + @@ -970,28 +970,30 @@ var file_api_v1_instance_service_proto_goTypes = []any{ (*InstanceSetting_MemoRelatedSetting)(nil), // 9: memos.api.v1.InstanceSetting.MemoRelatedSetting (*InstanceSetting_GeneralSetting_CustomProfile)(nil), // 10: memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile (*InstanceSetting_StorageSetting_S3Config)(nil), // 11: memos.api.v1.InstanceSetting.StorageSetting.S3Config - (*fieldmaskpb.FieldMask)(nil), // 12: google.protobuf.FieldMask + (*User)(nil), // 12: memos.api.v1.User + (*fieldmaskpb.FieldMask)(nil), // 13: google.protobuf.FieldMask } var file_api_v1_instance_service_proto_depIdxs = []int32{ - 7, // 0: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting - 8, // 1: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting - 9, // 2: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting - 4, // 3: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting - 12, // 4: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> google.protobuf.FieldMask - 10, // 5: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile - 1, // 6: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType - 11, // 7: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config - 3, // 8: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest - 5, // 9: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest - 6, // 10: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest - 2, // 11: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile - 4, // 12: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting - 4, // 13: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting - 11, // [11:14] is the sub-list for method output_type - 8, // [8:11] is the sub-list for method input_type - 8, // [8:8] is the sub-list for extension type_name - 8, // [8:8] is the sub-list for extension extendee - 0, // [0:8] is the sub-list for field type_name + 12, // 0: memos.api.v1.InstanceProfile.admin:type_name -> memos.api.v1.User + 7, // 1: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting + 8, // 2: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting + 9, // 3: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting + 4, // 4: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting + 13, // 5: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> google.protobuf.FieldMask + 10, // 6: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile + 1, // 7: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType + 11, // 8: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config + 3, // 9: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest + 5, // 10: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest + 6, // 11: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest + 2, // 12: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile + 4, // 13: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting + 4, // 14: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting + 12, // [12:15] is the sub-list for method output_type + 9, // [9:12] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name } func init() { file_api_v1_instance_service_proto_init() } @@ -999,6 +1001,7 @@ func file_api_v1_instance_service_proto_init() { if File_api_v1_instance_service_proto != nil { return } + file_api_v1_user_service_proto_init() file_api_v1_instance_service_proto_msgTypes[2].OneofWrappers = []any{ (*InstanceSetting_GeneralSetting_)(nil), (*InstanceSetting_StorageSetting_)(nil), diff --git a/proto/gen/api/v1/memo_service.pb.go b/proto/gen/api/v1/memo_service.pb.go index e102f9e51..b99fab571 100644 --- a/proto/gen/api/v1/memo_service.pb.go +++ b/proto/gen/api/v1/memo_service.pb.go @@ -222,9 +222,11 @@ type Memo struct { // The name of the creator. // Format: users/{user} Creator string `protobuf:"bytes,3,opt,name=creator,proto3" json:"creator,omitempty"` - // Output only. The creation timestamp. + // The creation timestamp. + // If not set on creation, the server will set it to the current time. CreateTime *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=create_time,json=createTime,proto3" json:"create_time,omitempty"` - // Output only. The last update timestamp. + // The last update timestamp. + // If not set on creation, the server will set it to the current time. UpdateTime *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty"` // The display timestamp of the memo. DisplayTime *timestamppb.Timestamp `protobuf:"bytes,6,opt,name=display_time,json=displayTime,proto3" json:"display_time,omitempty"` @@ -1816,9 +1818,9 @@ const file_api_v1_memo_service_proto_rawDesc = "" + "\x05state\x18\x02 \x01(\x0e2\x13.memos.api.v1.StateB\x03\xe0A\x02R\x05state\x123\n" + "\acreator\x18\x03 \x01(\tB\x19\xe0A\x03\xfaA\x13\n" + "\x11memos.api.v1/UserR\acreator\x12@\n" + - "\vcreate_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" + + "\vcreate_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\n" + "createTime\x12@\n" + - "\vupdate_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" + + "\vupdate_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\n" + "updateTime\x12B\n" + "\fdisplay_time\x18\x06 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\vdisplayTime\x12\x1d\n" + "\acontent\x18\a \x01(\tB\x03\xe0A\x02R\acontent\x12=\n" + diff --git a/proto/gen/api/v1/user_service.pb.go b/proto/gen/api/v1/user_service.pb.go index 82d76ad12..9f2d32e0d 100644 --- a/proto/gen/api/v1/user_service.pb.go +++ b/proto/gen/api/v1/user_service.pb.go @@ -29,13 +29,10 @@ const ( type User_Role int32 const ( - // Unspecified role. User_ROLE_UNSPECIFIED User_Role = 0 - // Host role with full system access. - User_HOST User_Role = 1 - // Admin role with administrative privileges. + // Admin role with system access. User_ADMIN User_Role = 2 - // Regular user role. + // User role with limited access. User_USER User_Role = 3 ) @@ -43,13 +40,11 @@ const ( var ( User_Role_name = map[int32]string{ 0: "ROLE_UNSPECIFIED", - 1: "HOST", 2: "ADMIN", 3: "USER", } User_Role_value = map[string]int32{ "ROLE_UNSPECIFIED": 0, - "HOST": 1, "ADMIN": 2, "USER": 3, } @@ -2509,7 +2504,7 @@ var File_api_v1_user_service_proto protoreflect.FileDescriptor const file_api_v1_user_service_proto_rawDesc = "" + "\n" + - "\x19api/v1/user_service.proto\x12\fmemos.api.v1\x1a\x13api/v1/common.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xcb\x04\n" + + "\x19api/v1/user_service.proto\x12\fmemos.api.v1\x1a\x13api/v1/common.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xc1\x04\n" + "\x04User\x12\x17\n" + "\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x120\n" + "\x04role\x18\x02 \x01(\x0e2\x17.memos.api.v1.User.RoleB\x03\xe0A\x02R\x04role\x12\x1f\n" + @@ -2525,10 +2520,9 @@ const file_api_v1_user_service_proto_rawDesc = "" + " \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" + "createTime\x12@\n" + "\vupdate_time\x18\v \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" + - "updateTime\";\n" + + "updateTime\"1\n" + "\x04Role\x12\x14\n" + - "\x10ROLE_UNSPECIFIED\x10\x00\x12\b\n" + - "\x04HOST\x10\x01\x12\t\n" + + "\x10ROLE_UNSPECIFIED\x10\x00\x12\t\n" + "\x05ADMIN\x10\x02\x12\b\n" + "\x04USER\x10\x03:7\xeaA4\n" + "\x11memos.api.v1/User\x12\fusers/{user}\x1a\x04name*\x05users2\x04user\"\x9d\x01\n" + diff --git a/proto/gen/openapi.yaml b/proto/gen/openapi.yaml index b52c35b20..d8ed9eed1 100644 --- a/proto/gen/openapi.yaml +++ b/proto/gen/openapi.yaml @@ -2132,20 +2132,21 @@ components: InstanceProfile: type: object properties: - owner: - type: string - description: |- - The name of instance owner. - Format: users/{user} version: type: string description: Version is the current version of instance. - mode: - type: string - description: Mode is the instance mode (e.g. "prod", "dev" or "demo"). + demo: + type: boolean + description: Demo indicates if the instance is in demo mode. instanceUrl: type: string description: Instance URL is the URL of the instance. + admin: + allOf: + - $ref: '#/components/schemas/User' + description: |- + The first administrator who set up this instance. + When null, instance requires initial setup (creating the first admin account). description: Instance profile message containing basic instance information. InstanceSetting: type: object @@ -2470,14 +2471,16 @@ components: The name of the creator. Format: users/{user} createTime: - readOnly: true type: string - description: Output only. The creation timestamp. + description: |- + The creation timestamp. + If not set on creation, the server will set it to the current time. format: date-time updateTime: - readOnly: true type: string - description: Output only. The last update timestamp. + description: |- + The last update timestamp. + If not set on creation, the server will set it to the current time. format: date-time displayTime: type: string @@ -2857,7 +2860,6 @@ components: role: enum: - ROLE_UNSPECIFIED - - HOST - ADMIN - USER type: string diff --git a/scripts/Dockerfile b/scripts/Dockerfile index c58a6347c..ed2894c2b 100644 --- a/scripts/Dockerfile +++ b/scripts/Dockerfile @@ -1,30 +1,56 @@ -FROM golang:1.25-alpine AS backend +FROM --platform=$BUILDPLATFORM golang:1.25-alpine AS backend WORKDIR /backend-build + +# Install build dependencies +RUN apk add --no-cache git ca-certificates + +# Copy go mod files and download dependencies (cached layer) COPY go.mod go.sum ./ -RUN go mod download +RUN --mount=type=cache,target=/go/pkg/mod \ + go mod download + +# Copy source code (use .dockerignore to exclude unnecessary files) COPY . . + # Please build frontend first, so that the static files are available. # Refer to `pnpm release` in package.json for the build command. +ARG TARGETOS TARGETARCH VERSION COMMIT RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - go build -ldflags="-s -w" -o memos ./cmd/memos + CGO_ENABLED=0 GOOS=$TARGETOS GOARCH=$TARGETARCH \ + go build \ + -trimpath \ + -ldflags="-s -w -extldflags '-static'" \ + -tags netgo,osusergo \ + -o memos \ + ./cmd/memos -FROM alpine:latest AS monolithic -WORKDIR /usr/local/memos +# Use minimal Alpine with security updates +FROM alpine:3.21 AS monolithic -RUN apk add --no-cache tzdata -ENV TZ="UTC" +# Install runtime dependencies and create non-root user in single layer +RUN apk add --no-cache tzdata ca-certificates su-exec && \ + addgroup -g 10001 -S nonroot && \ + adduser -u 10001 -S -G nonroot -h /var/opt/memos nonroot && \ + mkdir -p /var/opt/memos /usr/local/memos && \ + chown -R nonroot:nonroot /var/opt/memos -COPY --from=backend /backend-build/memos /usr/local/memos/ -COPY ./scripts/entrypoint.sh /usr/local/memos/ +# Copy binary and entrypoint to /usr/local/memos +COPY --from=backend /backend-build/memos /usr/local/memos/memos +COPY --from=backend --chmod=755 /backend-build/scripts/entrypoint.sh /usr/local/memos/entrypoint.sh + +# Run as root to fix permissions, entrypoint will drop to nonroot +USER root + +# Set working directory to the writable volume +WORKDIR /var/opt/memos + +# Data directory +VOLUME /var/opt/memos + +ENV TZ="UTC" \ + MEMOS_PORT="5230" EXPOSE 5230 -# Directory to store the data, which can be referenced as the mounting point. -RUN mkdir -p /var/opt/memos -VOLUME /var/opt/memos - -ENV MEMOS_MODE="prod" -ENV MEMOS_PORT="5230" - -ENTRYPOINT ["./entrypoint.sh", "./memos"] +ENTRYPOINT ["/usr/local/memos/entrypoint.sh", "/usr/local/memos/memos"] diff --git a/scripts/build.sh b/scripts/build.sh index ef687d061..15fbe2b71 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -29,4 +29,4 @@ go build -o "$OUTPUT" ./cmd/memos echo "Build successful!" echo "To run the application, execute the following command:" -echo "$OUTPUT --mode dev" +echo "$OUTPUT" diff --git a/scripts/compose.yaml b/scripts/compose.yaml index d5e33c3b0..a8a746c48 100644 --- a/scripts/compose.yaml +++ b/scripts/compose.yaml @@ -1,6 +1,6 @@ services: memos: - image: neosmemo/memos:latest + image: neosmemo/memos:stable container_name: memos volumes: - ~/.memos/:/var/opt/memos diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh index ec62af83d..710469df9 100755 --- a/scripts/entrypoint.sh +++ b/scripts/entrypoint.sh @@ -1,5 +1,19 @@ #!/usr/bin/env sh +# Fix ownership of data directory for users upgrading from older versions +# where files were created as root +MEMOS_UID=${MEMOS_UID:-10001} +MEMOS_GID=${MEMOS_GID:-10001} +DATA_DIR="/var/opt/memos" + +if [ "$(id -u)" = "0" ]; then + # Running as root, fix permissions and drop to nonroot + if [ -d "$DATA_DIR" ]; then + chown -R "$MEMOS_UID:$MEMOS_GID" "$DATA_DIR" 2>/dev/null || true + fi + exec su-exec "$MEMOS_UID:$MEMOS_GID" "$0" "$@" +fi + file_env() { var="$1" fileVar="${var}_FILE" diff --git a/server/auth/token_test.go b/server/auth/token_test.go index 3b4262dd1..3932016ec 100644 --- a/server/auth/token_test.go +++ b/server/auth/token_test.go @@ -74,7 +74,7 @@ func TestParseAccessTokenV2(t *testing.T) { }) t.Run("parses token with different roles", func(t *testing.T) { - roles := []string{"USER", "ADMIN", "HOST"} + roles := []string{"USER", "ADMIN"} for _, role := range roles { token, _, err := GenerateAccessTokenV2(1, "testuser", role, "ACTIVE", secret) require.NoError(t, err) diff --git a/server/router/api/v1/acl_config.go b/server/router/api/v1/acl_config.go index 4045a7f5c..9958900b2 100644 --- a/server/router/api/v1/acl_config.go +++ b/server/router/api/v1/acl_config.go @@ -29,8 +29,9 @@ var PublicMethods = map[string]struct{}{ "/memos.api.v1.IdentityProviderService/ListIdentityProviders": {}, // Memo Service - public memos (visibility filtering done in service layer) - "/memos.api.v1.MemoService/GetMemo": {}, - "/memos.api.v1.MemoService/ListMemos": {}, + "/memos.api.v1.MemoService/GetMemo": {}, + "/memos.api.v1.MemoService/ListMemos": {}, + "/memos.api.v1.MemoService/ListMemoComments": {}, } // IsPublicMethod checks if a procedure path is public (no authentication required). diff --git a/server/router/api/v1/attachment_exif_test.go b/server/router/api/v1/attachment_exif_test.go new file mode 100644 index 000000000..2b0d5d7c6 --- /dev/null +++ b/server/router/api/v1/attachment_exif_test.go @@ -0,0 +1,191 @@ +package v1 + +import ( + "bytes" + "image" + "image/color" + "image/jpeg" + "testing" + + "github.com/disintegration/imaging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShouldStripExif(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mimeType string + expected bool + }{ + { + name: "JPEG should strip EXIF", + mimeType: "image/jpeg", + expected: true, + }, + { + name: "JPG should strip EXIF", + mimeType: "image/jpg", + expected: true, + }, + { + name: "TIFF should strip EXIF", + mimeType: "image/tiff", + expected: true, + }, + { + name: "WebP should strip EXIF", + mimeType: "image/webp", + expected: true, + }, + { + name: "HEIC should strip EXIF", + mimeType: "image/heic", + expected: true, + }, + { + name: "HEIF should strip EXIF", + mimeType: "image/heif", + expected: true, + }, + { + name: "PNG should not strip EXIF", + mimeType: "image/png", + expected: false, + }, + { + name: "GIF should not strip EXIF", + mimeType: "image/gif", + expected: false, + }, + { + name: "text file should not strip EXIF", + mimeType: "text/plain", + expected: false, + }, + { + name: "PDF should not strip EXIF", + mimeType: "application/pdf", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := shouldStripExif(tt.mimeType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStripImageExif(t *testing.T) { + t.Parallel() + + // Create a simple test image + img := image.NewRGBA(image.Rect(0, 0, 100, 100)) + // Fill with red color + for y := 0; y < 100; y++ { + for x := 0; x < 100; x++ { + img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255}) + } + } + + // Encode as JPEG + var buf bytes.Buffer + err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90}) + require.NoError(t, err) + originalData := buf.Bytes() + + t.Run("strip JPEG metadata", func(t *testing.T) { + t.Parallel() + + strippedData, err := stripImageExif(originalData, "image/jpeg") + require.NoError(t, err) + assert.NotEmpty(t, strippedData) + + // Verify it's still a valid image + decodedImg, err := imaging.Decode(bytes.NewReader(strippedData)) + require.NoError(t, err) + assert.Equal(t, 100, decodedImg.Bounds().Dx()) + assert.Equal(t, 100, decodedImg.Bounds().Dy()) + }) + + t.Run("strip JPG metadata (alternate extension)", func(t *testing.T) { + t.Parallel() + + strippedData, err := stripImageExif(originalData, "image/jpg") + require.NoError(t, err) + assert.NotEmpty(t, strippedData) + + // Verify it's still a valid image + decodedImg, err := imaging.Decode(bytes.NewReader(strippedData)) + require.NoError(t, err) + assert.NotNil(t, decodedImg) + }) + + t.Run("strip PNG metadata", func(t *testing.T) { + t.Parallel() + + // Encode as PNG first + var pngBuf bytes.Buffer + err := imaging.Encode(&pngBuf, img, imaging.PNG) + require.NoError(t, err) + + strippedData, err := stripImageExif(pngBuf.Bytes(), "image/png") + require.NoError(t, err) + assert.NotEmpty(t, strippedData) + + // Verify it's still a valid image + decodedImg, err := imaging.Decode(bytes.NewReader(strippedData)) + require.NoError(t, err) + assert.Equal(t, 100, decodedImg.Bounds().Dx()) + assert.Equal(t, 100, decodedImg.Bounds().Dy()) + }) + + t.Run("handle WebP format by converting to JPEG", func(t *testing.T) { + t.Parallel() + + // WebP format will be converted to JPEG + strippedData, err := stripImageExif(originalData, "image/webp") + require.NoError(t, err) + assert.NotEmpty(t, strippedData) + + // Verify it's a valid image + decodedImg, err := imaging.Decode(bytes.NewReader(strippedData)) + require.NoError(t, err) + assert.NotNil(t, decodedImg) + }) + + t.Run("handle HEIC format by converting to JPEG", func(t *testing.T) { + t.Parallel() + + strippedData, err := stripImageExif(originalData, "image/heic") + require.NoError(t, err) + assert.NotEmpty(t, strippedData) + + // Verify it's a valid image + decodedImg, err := imaging.Decode(bytes.NewReader(strippedData)) + require.NoError(t, err) + assert.NotNil(t, decodedImg) + }) + + t.Run("return error for invalid image data", func(t *testing.T) { + t.Parallel() + + invalidData := []byte("not an image") + _, err := stripImageExif(invalidData, "image/jpeg") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode image") + }) + + t.Run("return error for empty image data", func(t *testing.T) { + t.Parallel() + + emptyData := []byte{} + _, err := stripImageExif(emptyData, "image/jpeg") + assert.Error(t, err) + }) +} diff --git a/server/router/api/v1/attachment_service.go b/server/router/api/v1/attachment_service.go index a87140f75..f2218da41 100644 --- a/server/router/api/v1/attachment_service.go +++ b/server/router/api/v1/attachment_service.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "fmt" "io" + "log/slog" "mime" "net/http" "os" @@ -14,6 +15,7 @@ import ( "strings" "time" + "github.com/disintegration/imaging" "github.com/lithammer/shortuuid/v4" "github.com/pkg/errors" "google.golang.org/grpc/codes" @@ -38,6 +40,10 @@ const ( MebiByte = 1024 * 1024 // ThumbnailCacheFolder is the folder name where the thumbnail images are stored. ThumbnailCacheFolder = ".thumbnail_cache" + + // defaultJPEGQuality is the JPEG quality used when re-encoding images for EXIF stripping. + // Quality 95 maintains visual quality while ensuring metadata is removed. + defaultJPEGQuality = 95 ) var SupportedThumbnailMimeTypes = []string{ @@ -45,6 +51,17 @@ var SupportedThumbnailMimeTypes = []string{ "image/jpeg", } +// exifCapableImageTypes defines image formats that may contain EXIF metadata. +// These formats will have their EXIF metadata stripped on upload for privacy. +var exifCapableImageTypes = map[string]bool{ + "image/jpeg": true, + "image/jpg": true, + "image/tiff": true, + "image/webp": true, + "image/heic": true, + "image/heif": true, +} + func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) { user, err := s.fetchCurrentUser(ctx) if err != nil { @@ -111,6 +128,21 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat create.Size = int64(size) create.Blob = request.Attachment.Content + // Strip EXIF metadata from images for privacy protection. + // This removes sensitive information like GPS location, device details, etc. + if shouldStripExif(create.Type) { + if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil { + // Log warning but continue with original image to ensure uploads don't fail. + slog.Warn("failed to strip EXIF metadata from image", + slog.String("type", create.Type), + slog.String("filename", create.Filename), + slog.String("error", err.Error())) + } else { + create.Blob = strippedBlob + create.Size = int64(len(strippedBlob)) + } + } + if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil { return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err) } @@ -214,6 +246,12 @@ func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttac if attachment == nil { return nil, status.Errorf(codes.NotFound, "attachment not found") } + + // Check access permission based on linked memo visibility. + if err := s.checkAttachmentAccess(ctx, attachment); err != nil { + return nil, err + } + return convertAttachmentFromStore(attachment), nil } @@ -225,10 +263,24 @@ func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.Updat if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 { return nil, status.Errorf(codes.InvalidArgument, "update mask is required") } + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err) } + if attachment == nil { + return nil, status.Errorf(codes.NotFound, "attachment not found") + } + // Only the creator or admin can update the attachment. + if attachment.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } currentTs := time.Now().Unix() update := &store.UpdateAttachment{ @@ -516,3 +568,89 @@ func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr s } return nil } + +// checkAttachmentAccess verifies the user has permission to access the attachment. +// For unlinked attachments (no memo), only the creator can access. +// For linked attachments, access follows the memo's visibility rules. +func (s *APIV1Service) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment) error { + user, _ := s.fetchCurrentUser(ctx) + + // For unlinked attachments, only the creator can access. + if attachment.MemoID == nil { + if user == nil { + return status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if attachment.CreatorID != user.ID && !isSuperUser(user) { + return status.Errorf(codes.PermissionDenied, "permission denied") + } + return nil + } + + // For linked attachments, check memo visibility. + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID}) + if err != nil { + return status.Errorf(codes.Internal, "failed to get memo: %v", err) + } + if memo == nil { + return status.Errorf(codes.NotFound, "memo not found") + } + + if memo.Visibility == store.Public { + return nil + } + if user == nil { + return status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) { + return status.Errorf(codes.PermissionDenied, "permission denied") + } + return nil +} + +// shouldStripExif checks if the MIME type is an image format that may contain EXIF metadata. +// Returns true for formats like JPEG, TIFF, WebP, HEIC, and HEIF which commonly contain +// privacy-sensitive metadata such as GPS coordinates, camera settings, and device information. +func shouldStripExif(mimeType string) bool { + return exifCapableImageTypes[mimeType] +} + +// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them. +// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps. +// +// The function preserves the correct image orientation by applying EXIF orientation tags +// during decoding before stripping all metadata. Images are re-encoded with high quality +// to minimize visual degradation. +// +// Supported formats: +// - JPEG/JPG: Re-encoded as JPEG with quality 95 +// - PNG: Re-encoded as PNG (lossless) +// - TIFF/WebP/HEIC/HEIF: Re-encoded as JPEG with quality 95 +// +// Returns the cleaned image data without any EXIF metadata, or an error if processing fails. +func stripImageExif(imageData []byte, mimeType string) ([]byte, error) { + // Decode image with automatic EXIF orientation correction. + // This ensures the image displays correctly after metadata removal. + img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true)) + if err != nil { + return nil, errors.Wrap(err, "failed to decode image") + } + + // Re-encode the image without EXIF metadata. + var buf bytes.Buffer + var encodeErr error + + if mimeType == "image/png" { + // Preserve PNG format for lossless encoding + encodeErr = imaging.Encode(&buf, img, imaging.PNG) + } else { + // For JPEG, TIFF, WebP, HEIC, HEIF - re-encode as JPEG. + // This ensures EXIF is stripped and provides good compression. + encodeErr = imaging.Encode(&buf, img, imaging.JPEG, imaging.JPEGQuality(defaultJPEGQuality)) + } + + if encodeErr != nil { + return nil, errors.Wrap(encodeErr, "failed to encode image") + } + + return buf.Bytes(), nil +} diff --git a/server/router/api/v1/common.go b/server/router/api/v1/common.go index 66fc032e5..be7bfa292 100644 --- a/server/router/api/v1/common.go +++ b/server/router/api/v1/common.go @@ -64,5 +64,5 @@ func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error { } func isSuperUser(user *store.User) bool { - return user.Role == store.RoleAdmin || user.Role == store.RoleHost + return user.Role == store.RoleAdmin } diff --git a/server/router/api/v1/connect_interceptors.go b/server/router/api/v1/connect_interceptors.go index 03eb35de4..dab7150d9 100644 --- a/server/router/api/v1/connect_interceptors.go +++ b/server/router/api/v1/connect_interceptors.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "reflect" "runtime/debug" "connectrpc.com/connect" @@ -50,10 +51,30 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc // Set metadata in context so services can use metadata.FromIncomingContext() ctx = metadata.NewIncomingContext(ctx, md) - return next(ctx, req) + + // Execute the request + resp, err := next(ctx, req) + + // Prevent browser caching of API responses to avoid stale data issues + // See: https://github.com/usememos/memos/issues/5470 + if !isNilAnyResponse(resp) && resp.Header() != nil { + resp.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + resp.Header().Set("Pragma", "no-cache") + resp.Header().Set("Expires", "0") + } + + return resp, err } } +func isNilAnyResponse(resp connect.AnyResponse) bool { + if resp == nil { + return true + } + val := reflect.ValueOf(resp) + return val.Kind() == reflect.Ptr && val.IsNil() +} + func (*MetadataInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { return next } diff --git a/server/router/api/v1/idp_service.go b/server/router/api/v1/idp_service.go index 2b48d2c10..d257a49b5 100644 --- a/server/router/api/v1/idp_service.go +++ b/server/router/api/v1/idp_service.go @@ -18,7 +18,10 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } - if currentUser == nil || currentUser.Role != store.RoleHost { + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -84,7 +87,10 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } - if currentUser == nil || currentUser.Role != store.RoleHost { + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -125,7 +131,10 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } - if currentUser == nil || currentUser.Role != store.RoleHost { + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -219,7 +228,7 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv } func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider { - if userRole != store.RoleHost { + if userRole != store.RoleAdmin { if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 { identityProvider.Config.GetOauth2Config().ClientSecret = "" } diff --git a/server/router/api/v1/instance_service.go b/server/router/api/v1/instance_service.go index 82830112e..520862771 100644 --- a/server/router/api/v1/instance_service.go +++ b/server/router/api/v1/instance_service.go @@ -15,17 +15,16 @@ import ( // GetInstanceProfile returns the instance profile. func (s *APIV1Service) GetInstanceProfile(ctx context.Context, _ *v1pb.GetInstanceProfileRequest) (*v1pb.InstanceProfile, error) { + admin, err := s.GetInstanceAdmin(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get instance admin: %v", err) + } + instanceProfile := &v1pb.InstanceProfile{ Version: s.Profile.Version, - Mode: s.Profile.Mode, + Demo: s.Profile.Demo, InstanceUrl: s.Profile.InstanceURL, - } - owner, err := s.GetInstanceOwner(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get instance owner: %v", err) - } - if owner != nil { - instanceProfile.Owner = owner.Name + Admin: admin, // nil when not initialized } return instanceProfile, nil } @@ -64,13 +63,16 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get return nil, status.Errorf(codes.NotFound, "instance setting not found") } - // For storage setting, only host can get it. + // For storage setting, only admin can get it. if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE { user, err := s.fetchCurrentUser(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } - if user == nil || user.Role != store.RoleHost { + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if user.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } } @@ -86,7 +88,7 @@ func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb. if user == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } - if user.Role != store.RoleHost { + if user.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -267,24 +269,17 @@ func convertInstanceMemoRelatedSettingToStore(setting *v1pb.InstanceSetting_Memo } } -var ownerCache *v1pb.User - -func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) { - if ownerCache != nil { - return ownerCache, nil - } - - hostUserType := store.RoleHost +func (s *APIV1Service) GetInstanceAdmin(ctx context.Context) (*v1pb.User, error) { + adminUserType := store.RoleAdmin user, err := s.Store.GetUser(ctx, &store.FindUser{ - Role: &hostUserType, + Role: &adminUserType, }) if err != nil { - return nil, errors.Wrapf(err, "failed to find owner") + return nil, errors.Wrapf(err, "failed to find admin") } if user == nil { return nil, nil } - ownerCache = convertUserFromStore(user) - return ownerCache, nil + return convertUserFromStore(user), nil } diff --git a/server/router/api/v1/memo_attachment_service.go b/server/router/api/v1/memo_attachment_service.go index fead95d7d..153f1aa80 100644 --- a/server/router/api/v1/memo_attachment_service.go +++ b/server/router/api/v1/memo_attachment_service.go @@ -98,6 +98,24 @@ func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.Li if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err) } + if memo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } + + // Check memo visibility. + if memo.Visibility != store.Public { + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + } + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ MemoID: &memo.ID, }) diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index f407b009e..f5d250a16 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -45,10 +45,35 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR Content: request.Memo.Content, Visibility: convertVisibilityToStore(request.Memo.Visibility), } + instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting") } + + // Handle display_time first: if provided, use it to set the appropriate timestamp + // based on the instance setting (similar to UpdateMemo logic) + // Note: explicit create_time/update_time below will override this if provided + if request.Memo.DisplayTime != nil && request.Memo.DisplayTime.IsValid() { + displayTs := request.Memo.DisplayTime.AsTime().Unix() + if instanceMemoRelatedSetting.DisplayWithUpdateTime { + create.UpdatedTs = displayTs + } else { + create.CreatedTs = displayTs + } + } + + // Set custom timestamps if provided in the request + // These take precedence over display_time + if request.Memo.CreateTime != nil && request.Memo.CreateTime.IsValid() { + createdTs := request.Memo.CreateTime.AsTime().Unix() + create.CreatedTs = createdTs + } + if request.Memo.UpdateTime != nil && request.Memo.UpdateTime.IsValid() { + updatedTs := request.Memo.UpdateTime.AsTime().Unix() + create.UpdatedTs = updatedTs + } + if instanceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public { return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled") } @@ -281,7 +306,7 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest return nil, status.Errorf(codes.Internal, "failed to get user") } if user == nil { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } if memo.Visibility == store.Private && memo.CreatorID != user.ID { return nil, status.Errorf(codes.PermissionDenied, "permission denied") @@ -497,23 +522,7 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR } } - if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete memo") - } - - // Delete memo relation - if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{MemoID: &memo.ID}); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete memo relations") - } - - // Delete related attachments. - for _, attachment := range attachments { - if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete attachment") - } - } - - // Delete memo comments + // Delete memo comments first (store.DeleteMemo handles their relations and attachments) commentType := store.MemoRelationComment relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType}) if err != nil { @@ -525,10 +534,9 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR } } - // Delete memo references - referenceType := store.MemoRelationReference - if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{RelatedMemoID: &memo.ID, Type: &referenceType}); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete memo references") + // Delete the memo (store.DeleteMemo handles relation and attachment cleanup) + if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil { + return nil, status.Errorf(codes.Internal, "failed to delete memo") } return &emptypb.Empty{}, nil @@ -543,6 +551,21 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo") } + if relatedMemo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } + + // Check memo visibility before allowing comment. + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user") + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } // Create the memo comment first. memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{ diff --git a/server/router/api/v1/reaction_service.go b/server/router/api/v1/reaction_service.go index 872ececde..a7c7cc3bd 100644 --- a/server/router/api/v1/reaction_service.go +++ b/server/router/api/v1/reaction_service.go @@ -15,6 +15,33 @@ import ( ) func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) { + // Extract memo UID and check visibility. + memoUID, err := ExtractMemoUIDFromName(request.Name) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) + } + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err) + } + if memo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } + + // Check memo visibility. + if memo.Visibility != store.Public { + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user") + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + } + reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ ContentID: &request.Name, }) @@ -40,6 +67,25 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups if user == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } + + // Extract memo UID and check visibility before allowing reaction. + memoUID, err := ExtractMemoUIDFromName(request.Reaction.ContentId) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) + } + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err) + } + if memo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } + + // Check memo visibility. + if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{ CreatorID: user.ID, ContentID: request.Reaction.ContentId, diff --git a/server/router/api/v1/test/idp_service_test.go b/server/router/api/v1/test/idp_service_test.go index d60d42de0..302a2737e 100644 --- a/server/router/api/v1/test/idp_service_test.go +++ b/server/router/api/v1/test/idp_service_test.go @@ -97,7 +97,7 @@ func TestCreateIdentityProvider(t *testing.T) { _, err := ts.Service.CreateIdentityProvider(ctx, req) require.Error(t, err) - require.Contains(t, err.Error(), "permission denied") + require.Contains(t, err.Error(), "user not authenticated") }) } @@ -547,6 +547,6 @@ func TestIdentityProviderPermissions(t *testing.T) { _, err := ts.Service.CreateIdentityProvider(ctx, req) require.Error(t, err) - require.Contains(t, err.Error(), "permission denied") + require.Contains(t, err.Error(), "user not authenticated") }) } diff --git a/server/router/api/v1/test/instance_admin_cache_test.go b/server/router/api/v1/test/instance_admin_cache_test.go new file mode 100644 index 000000000..5fa217160 --- /dev/null +++ b/server/router/api/v1/test/instance_admin_cache_test.go @@ -0,0 +1,54 @@ +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" +) + +func TestInstanceAdminRetrieval(t *testing.T) { + ctx := context.Background() + + t.Run("Instance becomes initialized after first admin user is created", func(t *testing.T) { + // Create test service + ts := NewTestService(t) + defer ts.Cleanup() + + // Verify instance is not initialized initially + profile1, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{}) + require.NoError(t, err) + require.Nil(t, profile1.Admin, "Instance should not be initialized before first admin user") + + // Create the first admin user + user, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + require.NotNil(t, user) + + // Verify instance is now initialized + profile2, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{}) + require.NoError(t, err) + require.NotNil(t, profile2.Admin, "Instance should be initialized after first admin user is created") + require.Equal(t, user.Username, profile2.Admin.Username) + }) + + t.Run("Admin retrieval is cached by Store layer", func(t *testing.T) { + // Create test service + ts := NewTestService(t) + defer ts.Cleanup() + + // Create admin user + user, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Multiple calls should return consistent admin user (from cache) + for i := 0; i < 5; i++ { + profile, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{}) + require.NoError(t, err) + require.NotNil(t, profile.Admin) + require.Equal(t, user.Username, profile.Admin.Username) + } + }) +} diff --git a/server/router/api/v1/test/instance_service_test.go b/server/router/api/v1/test/instance_service_test.go index 1422907a3..2043cf8b6 100644 --- a/server/router/api/v1/test/instance_service_test.go +++ b/server/router/api/v1/test/instance_service_test.go @@ -2,7 +2,6 @@ package test import ( "context" - "fmt" "testing" "github.com/stretchr/testify/require" @@ -28,14 +27,14 @@ func TestGetInstanceProfile(t *testing.T) { // Verify the response contains expected data require.Equal(t, "test-1.0.0", resp.Version) - require.Equal(t, "dev", resp.Mode) + require.True(t, resp.Demo) require.Equal(t, "http://localhost:8080", resp.InstanceUrl) - // Owner should be empty since no users are created - require.Empty(t, resp.Owner) + // Instance should not be initialized since no admin users are created + require.Nil(t, resp.Admin) }) - t.Run("GetInstanceProfile with owner", func(t *testing.T) { + t.Run("GetInstanceProfile with initialized instance", func(t *testing.T) { // Create test service for this specific test ts := NewTestService(t) defer ts.Cleanup() @@ -53,14 +52,14 @@ func TestGetInstanceProfile(t *testing.T) { require.NoError(t, err) require.NotNil(t, resp) - // Verify the response contains expected data including owner + // Verify the response contains expected data with initialized flag require.Equal(t, "test-1.0.0", resp.Version) - require.Equal(t, "dev", resp.Mode) + require.True(t, resp.Demo) require.Equal(t, "http://localhost:8080", resp.InstanceUrl) - // User name should be "users/{id}" format where id is the user's ID - expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID) - require.Equal(t, expectedOwnerName, resp.Owner) + // Instance should be initialized since an admin user exists + require.NotNil(t, resp.Admin) + require.Equal(t, hostUser.Username, resp.Admin.Username) }) } @@ -73,9 +72,8 @@ func TestGetInstanceProfile_Concurrency(t *testing.T) { defer ts.Cleanup() // Create a host user - hostUser, err := ts.CreateHostUser(ctx, "admin") + _, err := ts.CreateHostUser(ctx, "admin") require.NoError(t, err) - expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID) // Make concurrent requests numGoroutines := 10 @@ -102,9 +100,9 @@ func TestGetInstanceProfile_Concurrency(t *testing.T) { case resp := <-results: require.NotNil(t, resp) require.Equal(t, "test-1.0.0", resp.Version) - require.Equal(t, "dev", resp.Mode) + require.True(t, resp.Demo) require.Equal(t, "http://localhost:8080", resp.InstanceUrl) - require.Equal(t, expectedOwnerName, resp.Owner) + require.NotNil(t, resp.Admin) } } }) diff --git a/server/router/api/v1/test/memo_service_test.go b/server/router/api/v1/test/memo_service_test.go index e9cb1d3bd..a88eb0258 100644 --- a/server/router/api/v1/test/memo_service_test.go +++ b/server/router/api/v1/test/memo_service_test.go @@ -5,8 +5,10 @@ import ( "fmt" "slices" "testing" + "time" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" apiv1 "github.com/usememos/memos/proto/gen/api/v1" ) @@ -250,3 +252,118 @@ func TestListMemos(t *testing.T) { require.NotNil(t, userTwoReaction) require.Equal(t, "👍", userTwoReaction.ReactionType) } + +// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments. +// This addresses issue #5483: https://github.com/usememos/memos/issues/5483 +func TestCreateMemoWithCustomTimestamps(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + // Create a test user + user, err := ts.CreateRegularUser(ctx, "test-user-timestamps") + require.NoError(t, err) + require.NotNil(t, user) + + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Define custom timestamps (January 1, 2020) + customCreateTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC) + customUpdateTime := time.Date(2020, 1, 2, 12, 0, 0, 0, time.UTC) + customDisplayTime := time.Date(2020, 1, 3, 12, 0, 0, 0, time.UTC) + + // Test 1: Create a memo with custom create_time + memoWithCreateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "This memo has a custom creation time", + Visibility: apiv1.Visibility_PRIVATE, + CreateTime: timestamppb.New(customCreateTime), + }, + }) + require.NoError(t, err) + require.NotNil(t, memoWithCreateTime) + require.Equal(t, customCreateTime.Unix(), memoWithCreateTime.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp") + + // Test 2: Create a memo with custom update_time + memoWithUpdateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "This memo has a custom update time", + Visibility: apiv1.Visibility_PRIVATE, + UpdateTime: timestamppb.New(customUpdateTime), + }, + }) + require.NoError(t, err) + require.NotNil(t, memoWithUpdateTime) + require.Equal(t, customUpdateTime.Unix(), memoWithUpdateTime.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp") + + // Test 3: Create a memo with custom display_time + // Note: display_time is computed from either created_ts or updated_ts based on instance setting + // Since DisplayWithUpdateTime defaults to false, display_time maps to created_ts + memoWithDisplayTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "This memo has a custom display time", + Visibility: apiv1.Visibility_PRIVATE, + DisplayTime: timestamppb.New(customDisplayTime), + }, + }) + require.NoError(t, err) + require.NotNil(t, memoWithDisplayTime) + // Since DisplayWithUpdateTime is false by default, display_time sets created_ts + require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.DisplayTime.AsTime().Unix(), "display_time should match the custom timestamp") + require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.CreateTime.AsTime().Unix(), "create_time should also match since display_time maps to created_ts") + + // Test 4: Create a memo with all custom timestamps + // When both display_time and create_time are provided, create_time takes precedence + memoWithAllTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "This memo has all custom timestamps", + Visibility: apiv1.Visibility_PRIVATE, + CreateTime: timestamppb.New(customCreateTime), + UpdateTime: timestamppb.New(customUpdateTime), + DisplayTime: timestamppb.New(customDisplayTime), + }, + }) + require.NoError(t, err) + require.NotNil(t, memoWithAllTimestamps) + require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp") + require.Equal(t, customUpdateTime.Unix(), memoWithAllTimestamps.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp") + // display_time is computed from created_ts when DisplayWithUpdateTime is false + require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.DisplayTime.AsTime().Unix(), "display_time should be derived from create_time") + + // Test 5: Create a comment (memo relation) with custom timestamps + parentMemo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "This is the parent memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, parentMemo) + + customCommentCreateTime := time.Date(2021, 6, 15, 10, 30, 0, 0, time.UTC) + comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{ + Name: parentMemo.Name, + Comment: &apiv1.Memo{ + Content: "This is a comment with custom create time", + Visibility: apiv1.Visibility_PRIVATE, + CreateTime: timestamppb.New(customCommentCreateTime), + }, + }) + require.NoError(t, err) + require.NotNil(t, comment) + require.Equal(t, customCommentCreateTime.Unix(), comment.CreateTime.AsTime().Unix(), "comment create_time should match the custom timestamp") + + // Test 6: Verify that memos without custom timestamps still get auto-generated ones + memoWithoutTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "This memo has auto-generated timestamps", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memoWithoutTimestamps) + require.NotNil(t, memoWithoutTimestamps.CreateTime, "create_time should be auto-generated") + require.NotNil(t, memoWithoutTimestamps.UpdateTime, "update_time should be auto-generated") + require.True(t, time.Now().Unix()-memoWithoutTimestamps.CreateTime.AsTime().Unix() < 5, "create_time should be recent (within 5 seconds)") +} diff --git a/server/router/api/v1/test/test_helper.go b/server/router/api/v1/test/test_helper.go index e14fab738..779ad2eea 100644 --- a/server/router/api/v1/test/test_helper.go +++ b/server/router/api/v1/test/test_helper.go @@ -29,7 +29,7 @@ func NewTestService(t *testing.T) *TestService { // Create a test profile testProfile := &profile.Profile{ - Mode: "dev", + Demo: true, Version: "test-1.0.0", InstanceURL: "http://localhost:8080", Driver: "sqlite", @@ -56,17 +56,16 @@ func NewTestService(t *testing.T) *TestService { } } -// Cleanup clears caches and closes resources after test. +// Cleanup closes resources after test. func (ts *TestService) Cleanup() { ts.Store.Close() - // Note: Owner cache is package-level in parent package, cannot clear from test package } -// CreateHostUser creates a host user for testing. +// CreateHostUser creates an admin user for testing. func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) { return ts.Store.CreateUser(ctx, &store.User{ Username: username, - Role: store.RoleHost, + Role: store.RoleAdmin, Email: username + "@example.com", }) } diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index 6260a9547..794949cdd 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -37,7 +37,7 @@ func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersReq if currentUser == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } - if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin { + if currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -132,17 +132,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR // Determine the role to assign var roleToAssign store.Role if isFirstUser { - // First-time setup: create the first user as HOST (no authentication required) - roleToAssign = store.RoleHost - } else if currentUser != nil && currentUser.Role == store.RoleHost { - // Authenticated HOST user can create users with any role specified in request + // First-time setup: create the first user as ADMIN (no authentication required) + roleToAssign = store.RoleAdmin + } else if currentUser != nil && currentUser.Role == store.RoleAdmin { + // Authenticated ADMIN user can create users with any role specified in request if request.User.Role != v1pb.User_ROLE_UNSPECIFIED { roleToAssign = convertUserRoleToStore(request.User.Role) } else { roleToAssign = store.RoleUser } } else { - // Unauthenticated or non-HOST users can only create normal users + // Unauthenticated or non-ADMIN users can only create normal users roleToAssign = store.RoleUser } @@ -192,9 +192,12 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } // Check permission. // Only allow admin or self to update user. - if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost { + if currentUser.ID != userID && currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -261,7 +264,7 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR update.Description = &request.User.Description case "role": // Only allow admin to update role. - if currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost { + if currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } role := convertUserRoleToStore(request.User.Role) @@ -298,7 +301,10 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } - if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost { + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if currentUser.ID != userID && currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -539,7 +545,7 @@ func (s *APIV1Service) ListPersonalAccessTokens(ctx context.Context, request *v1 claims := auth.GetUserClaims(ctx) if claims == nil || claims.UserID != userID { currentUser, _ := s.fetchCurrentUser(ctx) - if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin) { + if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } } @@ -686,7 +692,7 @@ func (s *APIV1Service) ListUserWebhooks(ctx context.Context, request *v1pb.ListU if currentUser == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } - if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin { + if currentUser.ID != userID && currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -718,7 +724,7 @@ func (s *APIV1Service) CreateUserWebhook(ctx context.Context, request *v1pb.Crea if currentUser == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } - if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin { + if currentUser.ID != userID && currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -758,7 +764,7 @@ func (s *APIV1Service) UpdateUserWebhook(ctx context.Context, request *v1pb.Upda if currentUser == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } - if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin { + if currentUser.ID != userID && currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -830,7 +836,7 @@ func (s *APIV1Service) DeleteUserWebhook(ctx context.Context, request *v1pb.Dele if currentUser == nil { return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } - if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin { + if currentUser.ID != userID && currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -925,8 +931,6 @@ func convertUserFromStore(user *store.User) *v1pb.User { func convertUserRoleFromStore(role store.Role) v1pb.User_Role { switch role { - case store.RoleHost: - return v1pb.User_HOST case store.RoleAdmin: return v1pb.User_ADMIN case store.RoleUser: @@ -938,8 +942,6 @@ func convertUserRoleFromStore(role store.Role) v1pb.User_Role { func convertUserRoleToStore(role v1pb.User_Role) store.Role { switch role { - case v1pb.User_HOST: - return store.RoleHost case v1pb.User_ADMIN: return store.RoleAdmin default: @@ -1240,6 +1242,9 @@ func (s *APIV1Service) ListUserNotifications(ctx context.Context, request *v1pb. if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } if currentUser.ID != userID { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -1287,6 +1292,9 @@ func (s *APIV1Service) UpdateUserNotification(ctx context.Context, request *v1pb return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } // Verify ownership before updating inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{ ID: ¬ificationID, @@ -1352,6 +1360,9 @@ func (s *APIV1Service) DeleteUserNotification(ctx context.Context, request *v1pb return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } // Verify ownership before deletion inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{ ID: ¬ificationID, diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index 74b342fa2..694b5bbc3 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -59,7 +59,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech ctx := r.Context() // Get the RPC method name from context (set by grpc-gateway after routing) - rpcMethod, _ := runtime.RPCMethod(ctx) + rpcMethod, ok := runtime.RPCMethod(ctx) // Extract credentials from HTTP headers authHeader := r.Header.Get("Authorization") @@ -67,7 +67,8 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech result := authenticator.Authenticate(ctx, authHeader) // Enforce authentication for non-public methods - if result == nil && !IsPublicMethod(rpcMethod) { + // If rpcMethod cannot be determined, allow through, service layer will handle visibility checks + if result == nil && ok && !IsPublicMethod(rpcMethod) { http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized) return } @@ -125,7 +126,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech gwGroup.Any("/file/*", handler) // Connect handlers for browser clients (replaces grpc-web). - logStacktraces := s.Profile.IsDev() + logStacktraces := s.Profile.Demo connectInterceptors := connect.WithInterceptors( NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first NewLoggingInterceptor(logStacktraces), diff --git a/server/router/fileserver/fileserver.go b/server/router/fileserver/fileserver.go index c49d92e3a..0aaffd42f 100644 --- a/server/router/fileserver/fileserver.go +++ b/server/router/fileserver/fileserver.go @@ -26,13 +26,55 @@ import ( "github.com/usememos/memos/store" ) +// Constants for file serving configuration. const ( - // ThumbnailCacheFolder is the folder name where the thumbnail images are stored. + // ThumbnailCacheFolder is the folder name where thumbnail images are stored. ThumbnailCacheFolder = ".thumbnail_cache" - // thumbnailMaxSize is the maximum size in pixels for the largest dimension of the thumbnail image. + + // thumbnailMaxSize is the maximum dimension (width or height) for thumbnails. thumbnailMaxSize = 600 + + // maxConcurrentThumbnails limits concurrent thumbnail generation to prevent memory exhaustion. + maxConcurrentThumbnails = 3 + + // cacheMaxAge is the max-age value for Cache-Control headers (1 hour). + cacheMaxAge = "public, max-age=3600" ) +// xssUnsafeTypes contains MIME types that could execute scripts if served directly. +// These are served as application/octet-stream to prevent XSS attacks. +var xssUnsafeTypes = map[string]bool{ + "text/html": true, + "text/javascript": true, + "application/javascript": true, + "application/x-javascript": true, + "text/xml": true, + "application/xml": true, + "application/xhtml+xml": true, + "image/svg+xml": true, +} + +// thumbnailSupportedTypes contains image MIME types that support thumbnail generation. +var thumbnailSupportedTypes = map[string]bool{ + "image/png": true, + "image/jpeg": true, + "image/heic": true, + "image/heif": true, + "image/webp": true, +} + +// avatarAllowedTypes contains MIME types allowed for user avatars. +var avatarAllowedTypes = map[string]bool{ + "image/png": true, + "image/jpeg": true, + "image/jpg": true, + "image/gif": true, + "image/webp": true, + "image/heic": true, + "image/heif": true, +} + +// SupportedThumbnailMimeTypes is the exported list of thumbnail-supported MIME types. var SupportedThumbnailMimeTypes = []string{ "image/png", "image/jpeg", @@ -41,15 +83,16 @@ var SupportedThumbnailMimeTypes = []string{ "image/webp", } +// dataURIRegex parses data URI format: data:image/png;base64,iVBORw0KGgo... +var dataURIRegex = regexp.MustCompile(`^data:(?P[^;]+);base64,(?P.+)`) + // FileServerService handles HTTP file serving with proper range request support. -// This service bypasses gRPC-Gateway to use native HTTP serving via http.ServeContent(), -// which is required for Safari video/audio playback. type FileServerService struct { Profile *profile.Profile Store *store.Store authenticator *auth.Authenticator - // thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion + // thumbnailSemaphore limits concurrent thumbnail generation. thumbnailSemaphore *semaphore.Weighted } @@ -59,29 +102,27 @@ func NewFileServerService(profile *profile.Profile, store *store.Store, secret s Profile: profile, Store: store, authenticator: auth.NewAuthenticator(store, secret), - thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations + thumbnailSemaphore: semaphore.NewWeighted(maxConcurrentThumbnails), } } // RegisterRoutes registers HTTP file serving routes. func (s *FileServerService) RegisterRoutes(echoServer *echo.Echo) { fileGroup := echoServer.Group("/file") - - // Serve attachment binary files fileGroup.GET("/attachments/:uid/:filename", s.serveAttachmentFile) - - // Serve user avatar images fileGroup.GET("/users/:identifier/avatar", s.serveUserAvatar) } +// ============================================================================= +// HTTP Handlers +// ============================================================================= + // serveAttachmentFile serves attachment binary content using native HTTP. -// This properly handles range requests required by Safari for video/audio playback. func (s *FileServerService) serveAttachmentFile(c echo.Context) error { ctx := c.Request().Context() uid := c.Param("uid") - thumbnail := c.QueryParam("thumbnail") == "true" + wantThumbnail := c.QueryParam("thumbnail") == "true" - // Get attachment from database attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{ UID: &uid, GetBlob: true, @@ -93,96 +134,25 @@ func (s *FileServerService) serveAttachmentFile(c echo.Context) error { return echo.NewHTTPError(http.StatusNotFound, "attachment not found") } - // Check permissions - verify memo visibility if attachment belongs to a memo if err := s.checkAttachmentPermission(ctx, c, attachment); err != nil { return err } - // Get the binary content - blob, err := s.getAttachmentBlob(attachment) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").SetInternal(err) + contentType := s.sanitizeContentType(attachment.Type) + + // Stream video/audio to avoid loading entire file into memory. + if isMediaType(attachment.Type) { + return s.serveMediaStream(c, attachment, contentType) } - // Handle thumbnail requests for images - if thumbnail && s.isImageType(attachment.Type) { - thumbnailBlob, err := s.getOrGenerateThumbnail(ctx, attachment) - if err != nil { - // Log warning but fall back to original image - c.Logger().Warnf("failed to get thumbnail: %v", err) - } else { - blob = thumbnailBlob - } - } - - // Determine content type - contentType := attachment.Type - if strings.HasPrefix(contentType, "text/") { - contentType += "; charset=utf-8" - } - // Prevent XSS attacks by serving potentially unsafe files as octet-stream - unsafeTypes := []string{ - "text/html", - "text/javascript", - "application/javascript", - "application/x-javascript", - "text/xml", - "application/xml", - "application/xhtml+xml", - "image/svg+xml", - } - for _, unsafeType := range unsafeTypes { - if strings.EqualFold(contentType, unsafeType) { - contentType = "application/octet-stream" - break - } - } - - // Set common headers - c.Response().Header().Set("Content-Type", contentType) - c.Response().Header().Set("Cache-Control", "public, max-age=3600") - // Prevent MIME-type sniffing which could lead to XSS - c.Response().Header().Set("X-Content-Type-Options", "nosniff") - // Defense-in-depth: prevent embedding in frames and restrict content loading - c.Response().Header().Set("X-Frame-Options", "DENY") - c.Response().Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';") - // Support HDR/wide color gamut display for capable browsers - if strings.HasPrefix(contentType, "image/") || strings.HasPrefix(contentType, "video/") { - c.Response().Header().Set("Color-Gamut", "srgb, p3, rec2020") - } - - // Force download for non-media files to prevent XSS execution - if !strings.HasPrefix(contentType, "image/") && - !strings.HasPrefix(contentType, "video/") && - !strings.HasPrefix(contentType, "audio/") && - contentType != "application/pdf" { - c.Response().Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", attachment.Filename)) - } - - // For video/audio: Use http.ServeContent for automatic range request support - // This is critical for Safari which REQUIRES range request support - if strings.HasPrefix(contentType, "video/") || strings.HasPrefix(contentType, "audio/") { - // ServeContent automatically handles: - // - Range request parsing - // - HTTP 206 Partial Content responses - // - Content-Range headers - // - Accept-Ranges: bytes header - modTime := time.Unix(attachment.UpdatedTs, 0) - http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(blob)) - return nil - } - - // For other files: Simple blob response - return c.Blob(http.StatusOK, contentType, blob) + return s.serveStaticFile(c, attachment, contentType, wantThumbnail) } // serveUserAvatar serves user avatar images. -// Supports both user ID and username as identifier. func (s *FileServerService) serveUserAvatar(c echo.Context) error { ctx := c.Request().Context() identifier := c.Param("identifier") - // Try to find user by ID or username user, err := s.getUserByIdentifier(ctx, identifier) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user").SetInternal(err) @@ -194,79 +164,327 @@ func (s *FileServerService) serveUserAvatar(c echo.Context) error { return echo.NewHTTPError(http.StatusNotFound, "avatar not found") } - // Extract image info from data URI - imageType, base64Data, err := s.extractImageInfo(user.AvatarURL) + imageType, imageData, err := s.parseDataURI(user.AvatarURL) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "failed to extract image info").SetInternal(err) + return echo.NewHTTPError(http.StatusInternalServerError, "failed to parse avatar data").SetInternal(err) } - // Validate avatar MIME type to prevent XSS - // Supports standard formats and HDR-capable formats - allowedAvatarTypes := map[string]bool{ - "image/png": true, - "image/jpeg": true, - "image/jpg": true, - "image/gif": true, - "image/webp": true, - "image/heic": true, - "image/heif": true, - } - if !allowedAvatarTypes[imageType] { + if !avatarAllowedTypes[imageType] { return echo.NewHTTPError(http.StatusBadRequest, "invalid avatar image type") } - // Decode base64 data - imageData, err := base64.StdEncoding.DecodeString(base64Data) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "failed to decode image data").SetInternal(err) - } - - // Set cache headers for avatars - c.Response().Header().Set("Content-Type", imageType) - c.Response().Header().Set("Cache-Control", "public, max-age=3600") - c.Response().Header().Set("X-Content-Type-Options", "nosniff") - // Defense-in-depth: prevent embedding in frames - c.Response().Header().Set("X-Frame-Options", "DENY") - c.Response().Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';") + setSecurityHeaders(c) + c.Response().Header().Set(echo.HeaderContentType, imageType) + c.Response().Header().Set(echo.HeaderCacheControl, cacheMaxAge) return c.Blob(http.StatusOK, imageType, imageData) } -// getUserByIdentifier finds a user by either ID or username. -func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) { - // Try to parse as ID first - if userID, err := util.ConvertStringToInt32(identifier); err == nil { - return s.Store.GetUser(ctx, &store.FindUser{ID: &userID}) - } +// ============================================================================= +// File Serving Methods +// ============================================================================= - // Otherwise, treat as username - return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier}) +// serveMediaStream serves video/audio files using streaming to avoid memory exhaustion. +func (s *FileServerService) serveMediaStream(c echo.Context, attachment *store.Attachment, contentType string) error { + setSecurityHeaders(c) + setMediaHeaders(c, contentType, attachment.Type) + + switch attachment.StorageType { + case storepb.AttachmentStorageType_LOCAL: + filePath, err := s.resolveLocalPath(attachment.Reference) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to resolve file path").SetInternal(err) + } + http.ServeFile(c.Response(), c.Request(), filePath) + return nil + + case storepb.AttachmentStorageType_S3: + presignURL, err := s.getS3PresignedURL(c.Request().Context(), attachment) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate presigned URL").SetInternal(err) + } + return c.Redirect(http.StatusTemporaryRedirect, presignURL) + + default: + // Database storage fallback. + modTime := time.Unix(attachment.UpdatedTs, 0) + http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(attachment.Blob)) + return nil + } } -// extractImageInfo extracts image type and base64 data from a data URI. -// Data URI format: data:image/png;base64,iVBORw0KGgo... -func (*FileServerService) extractImageInfo(dataURI string) (string, string, error) { - dataURIRegex := regexp.MustCompile(`^data:(?P.+);base64,(?P.+)`) - matches := dataURIRegex.FindStringSubmatch(dataURI) - if len(matches) != 3 { - return "", "", errors.New("invalid data URI format") +// serveStaticFile serves non-streaming files (images, documents, etc.). +func (s *FileServerService) serveStaticFile(c echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error { + blob, err := s.getAttachmentBlob(attachment) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").SetInternal(err) } - imageType := matches[1] - base64Data := matches[2] - return imageType, base64Data, nil + + // Generate thumbnail for supported image types. + if wantThumbnail && thumbnailSupportedTypes[attachment.Type] { + if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil { + c.Logger().Warnf("failed to get thumbnail: %v", err) + } else { + blob = thumbnailBlob + } + } + + setSecurityHeaders(c) + setMediaHeaders(c, contentType, attachment.Type) + + // Force download for non-media files to prevent XSS execution. + if !strings.HasPrefix(contentType, "image/") && contentType != "application/pdf" { + c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename)) + } + + return c.Blob(http.StatusOK, contentType, blob) } +// ============================================================================= +// Storage Operations +// ============================================================================= + +// getAttachmentBlob retrieves the binary content of an attachment from storage. +func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) { + switch attachment.StorageType { + case storepb.AttachmentStorageType_LOCAL: + return s.readLocalFile(attachment.Reference) + + case storepb.AttachmentStorageType_S3: + return s.downloadFromS3(attachment) + + default: + return attachment.Blob, nil + } +} + +// getAttachmentReader returns a reader for streaming attachment content. +func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) { + switch attachment.StorageType { + case storepb.AttachmentStorageType_LOCAL: + filePath, err := s.resolveLocalPath(attachment.Reference) + if err != nil { + return nil, err + } + file, err := os.Open(filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, errors.Wrap(err, "file not found") + } + return nil, errors.Wrap(err, "failed to open file") + } + return file, nil + + case storepb.AttachmentStorageType_S3: + s3Client, s3Object, err := s.createS3Client(attachment) + if err != nil { + return nil, err + } + reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key) + if err != nil { + return nil, errors.Wrap(err, "failed to stream from S3") + } + return reader, nil + + default: + return io.NopCloser(bytes.NewReader(attachment.Blob)), nil + } +} + +// resolveLocalPath converts a storage reference to an absolute file path. +func (s *FileServerService) resolveLocalPath(reference string) (string, error) { + filePath := filepath.FromSlash(reference) + if !filepath.IsAbs(filePath) { + filePath = filepath.Join(s.Profile.Data, filePath) + } + return filePath, nil +} + +// readLocalFile reads the entire contents of a local file. +func (s *FileServerService) readLocalFile(reference string) ([]byte, error) { + filePath, err := s.resolveLocalPath(reference) + if err != nil { + return nil, err + } + + file, err := os.Open(filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, errors.Wrap(err, "file not found") + } + return nil, errors.Wrap(err, "failed to open file") + } + defer file.Close() + + blob, err := io.ReadAll(file) + if err != nil { + return nil, errors.Wrap(err, "failed to read file") + } + return blob, nil +} + +// createS3Client creates an S3 client from attachment payload. +func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Client, *storepb.AttachmentPayload_S3Object, error) { + if attachment.Payload == nil { + return nil, nil, errors.New("attachment payload is missing") + } + s3Object := attachment.Payload.GetS3Object() + if s3Object == nil { + return nil, nil, errors.New("S3 object payload is missing") + } + if s3Object.S3Config == nil { + return nil, nil, errors.New("S3 config is missing") + } + if s3Object.Key == "" { + return nil, nil, errors.New("S3 object key is missing") + } + + client, err := s3.NewClient(context.Background(), s3Object.S3Config) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to create S3 client") + } + return client, s3Object, nil +} + +// downloadFromS3 downloads the entire object from S3. +func (s *FileServerService) downloadFromS3(attachment *store.Attachment) ([]byte, error) { + client, s3Object, err := s.createS3Client(attachment) + if err != nil { + return nil, err + } + + blob, err := client.GetObject(context.Background(), s3Object.Key) + if err != nil { + return nil, errors.Wrap(err, "failed to download from S3") + } + return blob, nil +} + +// getS3PresignedURL generates a presigned URL for direct S3 access. +func (s *FileServerService) getS3PresignedURL(ctx context.Context, attachment *store.Attachment) (string, error) { + client, s3Object, err := s.createS3Client(attachment) + if err != nil { + return "", err + } + + url, err := client.PresignGetObject(ctx, s3Object.Key) + if err != nil { + return "", errors.Wrap(err, "failed to presign URL") + } + return url, nil +} + +// ============================================================================= +// Thumbnail Generation +// ============================================================================= + +// getOrGenerateThumbnail returns the thumbnail image of the attachment. +// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion. +func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) { + thumbnailPath, err := s.getThumbnailPath(attachment) + if err != nil { + return nil, err + } + + // Fast path: return cached thumbnail if exists. + if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil { + return blob, nil + } + + // Acquire semaphore to limit concurrent generation. + if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil { + return nil, errors.Wrap(err, "failed to acquire semaphore") + } + defer s.thumbnailSemaphore.Release(1) + + // Double-check after acquiring semaphore (another goroutine may have generated it). + if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil { + return blob, nil + } + + return s.generateThumbnail(attachment, thumbnailPath) +} + +// getThumbnailPath returns the file path for a cached thumbnail. +func (s *FileServerService) getThumbnailPath(attachment *store.Attachment) (string, error) { + cacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder) + if err := os.MkdirAll(cacheFolder, os.ModePerm); err != nil { + return "", errors.Wrap(err, "failed to create thumbnail cache folder") + } + filename := fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename)) + return filepath.Join(cacheFolder, filename), nil +} + +// readCachedThumbnail reads a thumbnail from the cache directory. +func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + return io.ReadAll(file) +} + +// generateThumbnail creates a new thumbnail and saves it to disk. +func (s *FileServerService) generateThumbnail(attachment *store.Attachment, thumbnailPath string) ([]byte, error) { + reader, err := s.getAttachmentReader(attachment) + if err != nil { + return nil, errors.Wrap(err, "failed to get attachment reader") + } + defer reader.Close() + + img, err := imaging.Decode(reader, imaging.AutoOrientation(true)) + if err != nil { + return nil, errors.Wrap(err, "failed to decode image") + } + + width, height := img.Bounds().Dx(), img.Bounds().Dy() + thumbnailWidth, thumbnailHeight := calculateThumbnailDimensions(width, height) + + thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos) + + if err := imaging.Save(thumbnailImage, thumbnailPath); err != nil { + return nil, errors.Wrap(err, "failed to save thumbnail") + } + + return s.readCachedThumbnail(thumbnailPath) +} + +// calculateThumbnailDimensions calculates the target dimensions for a thumbnail. +// The largest dimension is constrained to thumbnailMaxSize while maintaining aspect ratio. +// Small images are not enlarged. +func calculateThumbnailDimensions(width, height int) (int, int) { + if max(width, height) <= thumbnailMaxSize { + return width, height + } + if width >= height { + return thumbnailMaxSize, 0 // Landscape: constrain width. + } + return 0, thumbnailMaxSize // Portrait: constrain height. +} + +// ============================================================================= +// Authentication & Authorization +// ============================================================================= + // checkAttachmentPermission verifies the user has permission to access the attachment. func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c echo.Context, attachment *store.Attachment) error { - // If attachment is not linked to a memo, allow access + // For unlinked attachments, only the creator can access. if attachment.MemoID == nil { + user, err := s.getCurrentUser(ctx, c) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").SetInternal(err) + } + if user == nil { + return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access") + } + if user.ID != attachment.CreatorID && user.Role != store.RoleAdmin { + return echo.NewHTTPError(http.StatusForbidden, "forbidden access") + } return nil } - // Check memo visibility - memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ - ID: attachment.MemoID, - }) + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID}) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "failed to find memo").SetInternal(err) } @@ -274,12 +492,10 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech return echo.NewHTTPError(http.StatusNotFound, "memo not found") } - // Public memos are accessible to everyone if memo.Visibility == store.Public { return nil } - // For non-public memos, check authentication user, err := s.getCurrentUser(ctx, c) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").SetInternal(err) @@ -288,278 +504,132 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access") } - // Private memos can only be accessed by the creator - if memo.Visibility == store.Private && user.ID != attachment.CreatorID { + if memo.Visibility == store.Private && user.ID != memo.CreatorID && user.Role != store.RoleAdmin { return echo.NewHTTPError(http.StatusForbidden, "forbidden access") } return nil } -// getCurrentUser retrieves the current authenticated user from the Echo context. +// getCurrentUser retrieves the current authenticated user from the request. // Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie. -// Uses the shared Authenticator for consistent authentication logic. func (s *FileServerService) getCurrentUser(ctx context.Context, c echo.Context) (*store.User, error) { - // Try Bearer token authentication first - authHeader := c.Request().Header.Get("Authorization") - if authHeader != "" { - token := auth.ExtractBearerToken(authHeader) - if token != "" { - // Try Access Token V2 (stateless) - if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { - claims, err := s.authenticator.AuthenticateByAccessTokenV2(token) - if err == nil && claims != nil { - // Get user from claims - user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) - if err == nil && user != nil { - return user, nil - } - } - } - - // Try PAT - if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { - user, _, err := s.authenticator.AuthenticateByPAT(ctx, token) - if err == nil && user != nil { - return user, nil - } - } + // Try Bearer token authentication. + if authHeader := c.Request().Header.Get(echo.HeaderAuthorization); authHeader != "" { + if user, err := s.authenticateByBearerToken(ctx, authHeader); err == nil && user != nil { + return user, nil } } - // Fallback: Try refresh token cookie authentication - // This allows protected attachments to load even when access token has expired, - // as long as the user has a valid refresh token cookie. - cookieHeader := c.Request().Header.Get("Cookie") - if cookieHeader != "" { - refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader) - if refreshToken != "" { - user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken) - if err == nil && user != nil { - return user, nil - } + // Fallback: Try refresh token cookie. + if cookieHeader := c.Request().Header.Get("Cookie"); cookieHeader != "" { + if user, err := s.authenticateByRefreshToken(ctx, cookieHeader); err == nil && user != nil { + return user, nil } } - // No valid authentication found return nil, nil } -// isImageType checks if the mime type is an image that supports thumbnails. -// Supports standard formats (PNG, JPEG) and HDR-capable formats (HEIC, HEIF, WebP). -func (*FileServerService) isImageType(mimeType string) bool { - supportedTypes := map[string]bool{ - "image/png": true, - "image/jpeg": true, - "image/heic": true, - "image/heif": true, - "image/webp": true, +// authenticateByBearerToken authenticates using Authorization header. +func (s *FileServerService) authenticateByBearerToken(ctx context.Context, authHeader string) (*store.User, error) { + token := auth.ExtractBearerToken(authHeader) + if token == "" { + return nil, nil } - return supportedTypes[mimeType] + + // Try Access Token V2 (stateless JWT). + if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { + claims, err := s.authenticator.AuthenticateByAccessTokenV2(token) + if err == nil && claims != nil { + return s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) + } + } + + // Try Personal Access Token (stateful). + if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { + user, _, err := s.authenticator.AuthenticateByPAT(ctx, token) + if err == nil { + return user, nil + } + } + + return nil, nil } -// getAttachmentReader returns a reader for the attachment content. -func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) { - // For local storage, read the file from the local disk. - if attachment.StorageType == storepb.AttachmentStorageType_LOCAL { - attachmentPath := filepath.FromSlash(attachment.Reference) - if !filepath.IsAbs(attachmentPath) { - attachmentPath = filepath.Join(s.Profile.Data, attachmentPath) - } - - file, err := os.Open(attachmentPath) - if err != nil { - if os.IsNotExist(err) { - return nil, errors.Wrap(err, "file not found") - } - return nil, errors.Wrap(err, "failed to open the file") - } - return file, nil +// authenticateByRefreshToken authenticates using refresh token cookie. +func (s *FileServerService) authenticateByRefreshToken(ctx context.Context, cookieHeader string) (*store.User, error) { + refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader) + if refreshToken == "" { + return nil, nil } - // For S3 storage, download the file from S3. - if attachment.StorageType == storepb.AttachmentStorageType_S3 { - if attachment.Payload == nil { - return nil, errors.New("attachment payload is missing") - } - s3Object := attachment.Payload.GetS3Object() - if s3Object == nil { - return nil, errors.New("S3 object payload is missing") - } - if s3Object.S3Config == nil { - return nil, errors.New("S3 config is missing") - } - if s3Object.Key == "" { - return nil, errors.New("S3 object key is missing") - } - s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config) - if err != nil { - return nil, errors.Wrap(err, "failed to create S3 client") - } - - reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key) - if err != nil { - return nil, errors.Wrap(err, "failed to get object from S3") - } - return reader, nil - } - // For database storage, return the blob from the database. - return io.NopCloser(bytes.NewReader(attachment.Blob)), nil + user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken) + return user, err } -// getAttachmentBlob retrieves the binary content of an attachment from storage. -func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) { - // For local storage, read the file from the local disk. - if attachment.StorageType == storepb.AttachmentStorageType_LOCAL { - attachmentPath := filepath.FromSlash(attachment.Reference) - if !filepath.IsAbs(attachmentPath) { - attachmentPath = filepath.Join(s.Profile.Data, attachmentPath) - } - - file, err := os.Open(attachmentPath) - if err != nil { - if os.IsNotExist(err) { - return nil, errors.Wrap(err, "file not found") - } - return nil, errors.Wrap(err, "failed to open the file") - } - defer file.Close() - blob, err := io.ReadAll(file) - if err != nil { - return nil, errors.Wrap(err, "failed to read the file") - } - return blob, nil +// getUserByIdentifier finds a user by either ID or username. +func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) { + if userID, err := util.ConvertStringToInt32(identifier); err == nil { + return s.Store.GetUser(ctx, &store.FindUser{ID: &userID}) } - // For S3 storage, download the file from S3. - if attachment.StorageType == storepb.AttachmentStorageType_S3 { - if attachment.Payload == nil { - return nil, errors.New("attachment payload is missing") - } - s3Object := attachment.Payload.GetS3Object() - if s3Object == nil { - return nil, errors.New("S3 object payload is missing") - } - if s3Object.S3Config == nil { - return nil, errors.New("S3 config is missing") - } - if s3Object.Key == "" { - return nil, errors.New("S3 object key is missing") - } - - s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config) - if err != nil { - return nil, errors.Wrap(err, "failed to create S3 client") - } - - blob, err := s3Client.GetObject(context.Background(), s3Object.Key) - if err != nil { - return nil, errors.Wrap(err, "failed to get object from S3") - } - return blob, nil - } - // For database storage, return the blob from the database. - return attachment.Blob, nil + return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier}) } -// getOrGenerateThumbnail returns the thumbnail image of the attachment. -// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion. -func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) { - thumbnailCacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder) - if err := os.MkdirAll(thumbnailCacheFolder, os.ModePerm); err != nil { - return nil, errors.Wrap(err, "failed to create thumbnail cache folder") - } - filePath := filepath.Join(thumbnailCacheFolder, fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename))) +// ============================================================================= +// Helper Functions +// ============================================================================= - // Check if thumbnail already exists - if _, err := os.Stat(filePath); err == nil { - // Thumbnail exists, read and return it - thumbnailFile, err := os.Open(filePath) - if err != nil { - return nil, errors.Wrap(err, "failed to open thumbnail file") - } - defer thumbnailFile.Close() - blob, err := io.ReadAll(thumbnailFile) - if err != nil { - return nil, errors.Wrap(err, "failed to read thumbnail file") - } - return blob, nil - } else if !os.IsNotExist(err) { - return nil, errors.Wrap(err, "failed to check thumbnail image stat") +// sanitizeContentType converts potentially dangerous MIME types to safe alternatives. +func (*FileServerService) sanitizeContentType(mimeType string) string { + contentType := mimeType + if strings.HasPrefix(contentType, "text/") { + contentType += "; charset=utf-8" } - - // Thumbnail doesn't exist, acquire semaphore to limit concurrent generation - if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil { - return nil, errors.Wrap(err, "failed to acquire thumbnail generation semaphore") + // Normalize for case-insensitive lookup. + if xssUnsafeTypes[strings.ToLower(mimeType)] { + return "application/octet-stream" + } + return contentType +} + +// parseDataURI extracts MIME type and decoded data from a data URI. +func (*FileServerService) parseDataURI(dataURI string) (string, []byte, error) { + matches := dataURIRegex.FindStringSubmatch(dataURI) + if len(matches) != 3 { + return "", nil, errors.New("invalid data URI format") + } + + imageType := matches[1] + imageData, err := base64.StdEncoding.DecodeString(matches[2]) + if err != nil { + return "", nil, errors.Wrap(err, "failed to decode base64 data") + } + + return imageType, imageData, nil +} + +// isMediaType checks if the MIME type is video or audio. +func isMediaType(mimeType string) bool { + return strings.HasPrefix(mimeType, "video/") || strings.HasPrefix(mimeType, "audio/") +} + +// setSecurityHeaders sets common security headers for all responses. +func setSecurityHeaders(c echo.Context) { + h := c.Response().Header() + h.Set("X-Content-Type-Options", "nosniff") + h.Set("X-Frame-Options", "DENY") + h.Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';") +} + +// setMediaHeaders sets headers for media file responses. +func setMediaHeaders(c echo.Context, contentType, originalType string) { + h := c.Response().Header() + h.Set(echo.HeaderContentType, contentType) + h.Set(echo.HeaderCacheControl, cacheMaxAge) + + // Support HDR/wide color gamut for images and videos. + if strings.HasPrefix(originalType, "image/") || strings.HasPrefix(originalType, "video/") { + h.Set("Color-Gamut", "srgb, p3, rec2020") } - defer s.thumbnailSemaphore.Release(1) - - // Double-check if thumbnail was created while waiting for semaphore - if _, err := os.Stat(filePath); err == nil { - thumbnailFile, err := os.Open(filePath) - if err != nil { - return nil, errors.Wrap(err, "failed to open thumbnail file") - } - defer thumbnailFile.Close() - blob, err := io.ReadAll(thumbnailFile) - if err != nil { - return nil, errors.Wrap(err, "failed to read thumbnail file") - } - return blob, nil - } - - // Generate the thumbnail - reader, err := s.getAttachmentReader(attachment) - if err != nil { - return nil, errors.Wrap(err, "failed to get attachment reader") - } - defer reader.Close() - - // Decode image - this is memory intensive - img, err := imaging.Decode(reader, imaging.AutoOrientation(true)) - if err != nil { - return nil, errors.Wrap(err, "failed to decode thumbnail image") - } - - // The largest dimension is set to thumbnailMaxSize and the smaller dimension is scaled proportionally. - // Small images are not enlarged. - width := img.Bounds().Dx() - height := img.Bounds().Dy() - var thumbnailWidth, thumbnailHeight int - - // Only resize if the image is larger than thumbnailMaxSize - if max(width, height) > thumbnailMaxSize { - if width >= height { - // Landscape or square - constrain width, maintain aspect ratio for height - thumbnailWidth = thumbnailMaxSize - thumbnailHeight = 0 - } else { - // Portrait - constrain height, maintain aspect ratio for width - thumbnailWidth = 0 - thumbnailHeight = thumbnailMaxSize - } - } else { - // Keep original dimensions for small images - thumbnailWidth = width - thumbnailHeight = height - } - - // Resize the image to the calculated dimensions. - thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos) - - // Save thumbnail to disk - if err := imaging.Save(thumbnailImage, filePath); err != nil { - return nil, errors.Wrap(err, "failed to save thumbnail file") - } - - // Read the saved thumbnail and return it - thumbnailFile, err := os.Open(filePath) - if err != nil { - return nil, errors.Wrap(err, "failed to open thumbnail file") - } - defer thumbnailFile.Close() - thumbnailBlob, err := io.ReadAll(thumbnailFile) - if err != nil { - return nil, errors.Wrap(err, "failed to read thumbnail file") - } - return thumbnailBlob, nil } diff --git a/server/server.go b/server/server.go index a1fbc1ffe..af09c4bcd 100644 --- a/server/server.go +++ b/server/server.go @@ -52,7 +52,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store } secret := "usememos" - if profile.Mode == "prod" { + if !profile.Demo { secret = instanceBasicSetting.SecretKey } s.Secret = secret diff --git a/store/cache/cache.go b/store/cache/cache.go index 8760b3ad5..102f8add3 100644 --- a/store/cache/cache.go +++ b/store/cache/cache.go @@ -161,20 +161,20 @@ func (c *Cache) Delete(_ context.Context, key string) { // Clear removes all values from the cache. func (c *Cache) Clear(_ context.Context) { - if c.config.OnEviction != nil { - c.data.Range(func(key, value any) bool { + count := 0 + c.data.Range(func(key, value any) bool { + if c.config.OnEviction != nil { itm, ok := value.(item) - if !ok { - return true + if ok { + if keyStr, ok := key.(string); ok { + c.config.OnEviction(keyStr, itm.value) + } } - if keyStr, ok := key.(string); ok { - c.config.OnEviction(keyStr, itm.value) - } - return true - }) - } - - c.data = sync.Map{} + } + c.data.Delete(key) + count++ + return true + }) (&c.itemCount).Store(0) } diff --git a/store/db/mysql/attachment.go b/store/db/mysql/attachment.go index ead254d88..b313d34af 100644 --- a/store/db/mysql/attachment.go +++ b/store/db/mysql/attachment.go @@ -31,7 +31,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s } args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString} - stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")" + stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")" result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { return nil, err @@ -50,38 +50,38 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - where, args = append(where, "`resource`.`id` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`id` = ?"), append(args, *v) } if v := find.UID; v != nil { - where, args = append(where, "`resource`.`uid` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v) } if v := find.Filename; v != nil { - where, args = append(where, "`resource`.`filename` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v) } if v := find.FilenameSearch; v != nil { - where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, "%"+*v+"%") + where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, "%"+*v+"%") } if v := find.MemoID; v != nil { - where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v) } if len(find.MemoIDList) > 0 { placeholders := make([]string, 0, len(find.MemoIDList)) for range find.MemoIDList { placeholders = append(placeholders, "?") } - where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")") + where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")") for _, id := range find.MemoIDList { args = append(args, id) } } if find.HasRelatedMemo { - where = append(where, "`resource`.`memo_id` IS NOT NULL") + where = append(where, "`attachment`.`memo_id` IS NOT NULL") } if find.StorageType != nil { - where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String()) + where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String()) } if len(find.Filters) > 0 { @@ -95,26 +95,26 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ } fields := []string{ - "`resource`.`id` AS `id`", - "`resource`.`uid` AS `uid`", - "`resource`.`filename` AS `filename`", - "`resource`.`type` AS `type`", - "`resource`.`size` AS `size`", - "`resource`.`creator_id` AS `creator_id`", - "UNIX_TIMESTAMP(`resource`.`created_ts`) AS `created_ts`", - "UNIX_TIMESTAMP(`resource`.`updated_ts`) AS `updated_ts`", - "`resource`.`memo_id` AS `memo_id`", - "`resource`.`storage_type` AS `storage_type`", - "`resource`.`reference` AS `reference`", - "`resource`.`payload` AS `payload`", + "`attachment`.`id` AS `id`", + "`attachment`.`uid` AS `uid`", + "`attachment`.`filename` AS `filename`", + "`attachment`.`type` AS `type`", + "`attachment`.`size` AS `size`", + "`attachment`.`creator_id` AS `creator_id`", + "UNIX_TIMESTAMP(`attachment`.`created_ts`) AS `created_ts`", + "UNIX_TIMESTAMP(`attachment`.`updated_ts`) AS `updated_ts`", + "`attachment`.`memo_id` AS `memo_id`", + "`attachment`.`storage_type` AS `storage_type`", + "`attachment`.`reference` AS `reference`", + "`attachment`.`payload` AS `payload`", "CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`", } if find.GetBlob { - fields = append(fields, "`resource`.`blob` AS `blob`") + fields = append(fields, "`attachment`.`blob` AS `blob`") } - query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " + - "LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " + + query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " + + "LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " + "WHERE " + strings.Join(where, " AND ") + " " + "ORDER BY `updated_ts` DESC" if find.Limit != nil { @@ -216,7 +216,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen } args = append(args, update.ID) - stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?" + stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?" result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { return err @@ -228,7 +228,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen } func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error { - stmt := "DELETE FROM `resource` WHERE `id` = ?" + stmt := "DELETE FROM `attachment` WHERE `id` = ?" result, err := d.db.ExecContext(ctx, stmt, delete.ID) if err != nil { return err diff --git a/store/db/mysql/inbox.go b/store/db/mysql/inbox.go index ec20a8ebe..9964bf9c2 100644 --- a/store/db/mysql/inbox.go +++ b/store/db/mysql/inbox.go @@ -63,7 +63,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I if find.MessageType != nil { // Filter by message type using JSON extraction // Note: The type field in JSON is stored as string representation of the enum name - where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String()) + if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED { + where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String()) + } else { + where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String()) + } } query := "SELECT `id`, UNIX_TIMESTAMP(`created_ts`), `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC" diff --git a/store/db/mysql/memo.go b/store/db/mysql/memo.go index 2f9bc2ebb..05d45ea27 100644 --- a/store/db/mysql/memo.go +++ b/store/db/mysql/memo.go @@ -26,6 +26,18 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e } args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload} + // Add custom timestamps if provided + if create.CreatedTs != 0 { + fields = append(fields, "`created_ts`") + placeholder = append(placeholder, "?") + args = append(args, create.CreatedTs) + } + if create.UpdatedTs != 0 { + fields = append(fields, "`updated_ts`") + placeholder = append(placeholder, "?") + args = append(args, create.UpdatedTs) + } + stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")" result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go deleted file mode 100644 index 020f2a28c..000000000 --- a/store/db/mysql/memo_filter_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package mysql - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `tag in ["tag1", "tag2"]`, - want: "((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))", - args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`}, - }, - { - filter: `!(tag in ["tag1", "tag2"])`, - want: "NOT (((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)))", - args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`}, - }, - { - filter: `content.contains("memos")`, - want: "`memo`.`content` LIKE ?", - args: []any{"%memos%"}, - }, - { - filter: `visibility in ["PUBLIC"]`, - want: "`memo`.`visibility` IN (?)", - args: []any{"PUBLIC"}, - }, - { - filter: `visibility in ["PUBLIC", "PRIVATE"]`, - want: "`memo`.`visibility` IN (?,?)", - args: []any{"PUBLIC", "PRIVATE"}, - }, - { - filter: `tag in ['tag1'] || content.contains('hello')`, - want: "((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR `memo`.`content` LIKE ?)", - args: []any{`"tag1"`, `%"tag1/%`, "%hello%"}, - }, - { - filter: `1`, - want: "", - args: []any{}, - }, - { - filter: `pinned`, - want: "`memo`.`pinned` IS TRUE", - args: []any{}, - }, - { - filter: `has_task_list`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)", - args: []any{}, - }, - { - filter: `has_task_list == true`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)", - args: []any{}, - }, - { - filter: `has_task_list != false`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)", - args: []any{}, - }, - { - filter: `has_task_list == false`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)", - args: []any{}, - }, - { - filter: `!has_task_list`, - want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON))", - args: []any{}, - }, - { - filter: `has_task_list && pinned`, - want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`pinned` IS TRUE)", - args: []any{}, - }, - { - filter: `has_task_list && content.contains("todo")`, - want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`content` LIKE ?)", - args: []any{"%todo%"}, - }, - { - filter: `created_ts > now() - 60 * 60 * 24`, - want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?", - args: []any{time.Now().Unix() - 60*60*24}, - }, - { - filter: `size(tags) == 0`, - want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", - args: []any{int64(0)}, - }, - { - filter: `size(tags) > 0`, - want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?", - args: []any{int64(0)}, - }, - { - filter: `"work" in tags`, - want: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)", - args: []any{`"work"`}, - }, - { - filter: `size(tags) == 2`, - want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", - args: []any{int64(2)}, - }, - { - filter: `has_link == true`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)", - args: []any{}, - }, - { - filter: `has_code == false`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('false' AS JSON)", - args: []any{}, - }, - { - filter: `has_incomplete_tasks != false`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') != CAST('false' AS JSON)", - args: []any{}, - }, - { - filter: `has_link`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)", - args: []any{}, - }, - { - filter: `has_code`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)", - args: []any{}, - }, - { - filter: `has_incomplete_tasks`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)", - args: []any{}, - }, - } - - engine, err := filter.DefaultEngine() - require.NoError(t, err) - - for _, tt := range tests { - stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{ - Dialect: filter.DialectMySQL, - }) - require.NoError(t, err) - require.Equal(t, tt.want, stmt.SQL) - require.Equal(t, tt.args, stmt.Args) - } -} diff --git a/store/db/mysql/memo_relation.go b/store/db/mysql/memo_relation.go index 3116903e0..71b73be6f 100644 --- a/store/db/mysql/memo_relation.go +++ b/store/db/mysql/memo_relation.go @@ -10,7 +10,7 @@ import ( ) func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) { - stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?)" + stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `type` = `type`" _, err := d.db.ExecContext( ctx, stmt, diff --git a/store/db/postgres/attachment.go b/store/db/postgres/attachment.go index 9ee970fbd..3d51acd2d 100644 --- a/store/db/postgres/attachment.go +++ b/store/db/postgres/attachment.go @@ -30,7 +30,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s } args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString} - stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts" + stmt := "INSERT INTO attachment (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts" if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil { return nil, err } @@ -41,22 +41,22 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - where, args = append(where, "resource.id = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "attachment.id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.UID; v != nil { - where, args = append(where, "resource.uid = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "attachment.uid = "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "resource.creator_id = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "attachment.creator_id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Filename; v != nil { - where, args = append(where, "resource.filename = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "attachment.filename = "+placeholder(len(args)+1)), append(args, *v) } if v := find.FilenameSearch; v != nil { - where, args = append(where, "resource.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v)) + where, args = append(where, "attachment.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v)) } if v := find.MemoID; v != nil { - where, args = append(where, "resource.memo_id = "+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "attachment.memo_id = "+placeholder(len(args)+1)), append(args, *v) } if len(find.MemoIDList) > 0 { holders := make([]string, 0, len(find.MemoIDList)) @@ -64,13 +64,13 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ holders = append(holders, placeholder(len(args)+1)) args = append(args, id) } - where = append(where, "resource.memo_id IN ("+strings.Join(holders, ", ")+")") + where = append(where, "attachment.memo_id IN ("+strings.Join(holders, ", ")+")") } if find.HasRelatedMemo { - where = append(where, "resource.memo_id IS NOT NULL") + where = append(where, "attachment.memo_id IS NOT NULL") } if v := find.StorageType; v != nil { - where, args = append(where, "resource.storage_type = "+placeholder(len(args)+1)), append(args, v.String()) + where, args = append(where, "attachment.storage_type = "+placeholder(len(args)+1)), append(args, v.String()) } if len(find.Filters) > 0 { @@ -84,31 +84,31 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ } fields := []string{ - "resource.id AS id", - "resource.uid AS uid", - "resource.filename AS filename", - "resource.type AS type", - "resource.size AS size", - "resource.creator_id AS creator_id", - "resource.created_ts AS created_ts", - "resource.updated_ts AS updated_ts", - "resource.memo_id AS memo_id", - "resource.storage_type AS storage_type", - "resource.reference AS reference", - "resource.payload AS payload", + "attachment.id AS id", + "attachment.uid AS uid", + "attachment.filename AS filename", + "attachment.type AS type", + "attachment.size AS size", + "attachment.creator_id AS creator_id", + "attachment.created_ts AS created_ts", + "attachment.updated_ts AS updated_ts", + "attachment.memo_id AS memo_id", + "attachment.storage_type AS storage_type", + "attachment.reference AS reference", + "attachment.payload AS payload", "CASE WHEN memo.uid IS NOT NULL THEN memo.uid ELSE NULL END AS memo_uid", } if find.GetBlob { - fields = append(fields, "resource.blob AS blob") + fields = append(fields, "attachment.blob AS blob") } query := fmt.Sprintf(` SELECT %s - FROM resource - LEFT JOIN memo ON resource.memo_id = memo.id + FROM attachment + LEFT JOIN memo ON attachment.memo_id = memo.id WHERE %s - ORDER BY resource.updated_ts DESC + ORDER BY attachment.updated_ts DESC `, strings.Join(fields, ", "), strings.Join(where, " AND ")) if find.Limit != nil { query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) @@ -196,7 +196,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(bytes)) } - stmt := `UPDATE resource SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1) + stmt := `UPDATE attachment SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1) args = append(args, update.ID) result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { @@ -209,7 +209,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen } func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error { - stmt := `DELETE FROM resource WHERE id = $1` + stmt := `DELETE FROM attachment WHERE id = $1` result, err := d.db.ExecContext(ctx, stmt, delete.ID) if err != nil { return err diff --git a/store/db/postgres/inbox.go b/store/db/postgres/inbox.go index 7df32e287..40be94b3b 100644 --- a/store/db/postgres/inbox.go +++ b/store/db/postgres/inbox.go @@ -53,7 +53,12 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I if find.MessageType != nil { // Filter by message type using PostgreSQL JSON extraction // Note: The type field in JSON is stored as string representation of the enum name - where, args = append(where, "message->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String()) + // Cast to JSONB since the column is TEXT + if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED { + where, args = append(where, "(message::JSONB->>'type' IS NULL OR message::JSONB->>'type' = "+placeholder(len(args)+1)+")"), append(args, find.MessageType.String()) + } else { + where, args = append(where, "message::JSONB->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String()) + } } query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC" diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go index 3fa3abd4b..fd25a13ed 100644 --- a/store/db/postgres/memo.go +++ b/store/db/postgres/memo.go @@ -25,6 +25,16 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e } args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload} + // Add custom timestamps if provided + if create.CreatedTs != 0 { + fields = append(fields, "created_ts") + args = append(args, create.CreatedTs) + } + if create.UpdatedTs != 0 { + fields = append(fields, "updated_ts") + args = append(args, create.UpdatedTs) + } + stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status" if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( &create.ID, diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go deleted file mode 100644 index b6ac67309..000000000 --- a/store/db/postgres/memo_filter_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package postgres - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `tag in ["tag1", "tag2"]`, - want: "((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR (memo.payload->'tags' @> jsonb_build_array($3::json) OR (memo.payload->'tags')::text LIKE $4))", - args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`}, - }, - { - filter: `!(tag in ["tag1", "tag2"])`, - want: "NOT (((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR (memo.payload->'tags' @> jsonb_build_array($3::json) OR (memo.payload->'tags')::text LIKE $4)))", - args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`}, - }, - { - filter: `content.contains("memos")`, - want: "memo.content ILIKE $1", - args: []any{"%memos%"}, - }, - { - filter: `visibility in ["PUBLIC"]`, - want: "memo.visibility IN ($1)", - args: []any{"PUBLIC"}, - }, - { - filter: `visibility in ["PUBLIC", "PRIVATE"]`, - want: "memo.visibility IN ($1,$2)", - args: []any{"PUBLIC", "PRIVATE"}, - }, - { - filter: `tag in ['tag1'] || content.contains('hello')`, - want: "((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR memo.content ILIKE $3)", - args: []any{`"tag1"`, `%"tag1/%`, "%hello%"}, - }, - { - filter: `1`, - want: "", - args: []any{}, - }, - { - filter: `pinned`, - want: "memo.pinned IS TRUE", - args: []any{}, - }, - { - filter: `has_task_list`, - want: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE", - args: []any{}, - }, - { - filter: `has_task_list == true`, - want: "(memo.payload->'property'->>'hasTaskList')::boolean = $1", - args: []any{true}, - }, - { - filter: `has_task_list != false`, - want: "(memo.payload->'property'->>'hasTaskList')::boolean != $1", - args: []any{false}, - }, - { - filter: `has_task_list == false`, - want: "(memo.payload->'property'->>'hasTaskList')::boolean = $1", - args: []any{false}, - }, - { - filter: `!has_task_list`, - want: "NOT ((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE)", - args: []any{}, - }, - { - filter: `has_task_list && pinned`, - want: "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.pinned IS TRUE)", - args: []any{}, - }, - { - filter: `has_task_list && content.contains("todo")`, - want: "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.content ILIKE $1)", - args: []any{"%todo%"}, - }, - { - filter: `created_ts > now() - 60 * 60 * 24`, - want: "memo.created_ts > $1", - args: []any{time.Now().Unix() - 60*60*24}, - }, - { - filter: `size(tags) == 0`, - want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1", - args: []any{int64(0)}, - }, - { - filter: `size(tags) > 0`, - want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) > $1", - args: []any{int64(0)}, - }, - { - filter: `"work" in tags`, - want: "memo.payload->'tags' @> jsonb_build_array($1::json)", - args: []any{`"work"`}, - }, - { - filter: `size(tags) == 2`, - want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1", - args: []any{int64(2)}, - }, - { - filter: `has_link == true`, - want: "(memo.payload->'property'->>'hasLink')::boolean = $1", - args: []any{true}, - }, - { - filter: `has_code == false`, - want: "(memo.payload->'property'->>'hasCode')::boolean = $1", - args: []any{false}, - }, - { - filter: `has_incomplete_tasks != false`, - want: "(memo.payload->'property'->>'hasIncompleteTasks')::boolean != $1", - args: []any{false}, - }, - { - filter: `has_link`, - want: "(memo.payload->'property'->>'hasLink')::boolean IS TRUE", - args: []any{}, - }, - { - filter: `has_code`, - want: "(memo.payload->'property'->>'hasCode')::boolean IS TRUE", - args: []any{}, - }, - { - filter: `has_incomplete_tasks`, - want: "(memo.payload->'property'->>'hasIncompleteTasks')::boolean IS TRUE", - args: []any{}, - }, - } - - engine, err := filter.DefaultEngine() - require.NoError(t, err) - - for _, tt := range tests { - stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectPostgres}) - require.NoError(t, err) - require.Equal(t, tt.want, stmt.SQL) - require.Equal(t, tt.args, stmt.Args) - } -} diff --git a/store/db/postgres/memo_relation.go b/store/db/postgres/memo_relation.go index 881291b8a..a2f2817c7 100644 --- a/store/db/postgres/memo_relation.go +++ b/store/db/postgres/memo_relation.go @@ -17,6 +17,7 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) type ) VALUES (` + placeholders(3) + `) + ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET type = EXCLUDED.type RETURNING memo_id, related_memo_id, type ` memoRelation := &store.MemoRelation{} diff --git a/store/db/sqlite/attachment.go b/store/db/sqlite/attachment.go index 04653b185..3ac8afd6f 100644 --- a/store/db/sqlite/attachment.go +++ b/store/db/sqlite/attachment.go @@ -31,7 +31,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s } args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString} - stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`" + stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`" if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil { return nil, err } @@ -43,38 +43,38 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - where, args = append(where, "`resource`.`id` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`id` = ?"), append(args, *v) } if v := find.UID; v != nil { - where, args = append(where, "`resource`.`uid` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v) } if v := find.Filename; v != nil { - where, args = append(where, "`resource`.`filename` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v) } if v := find.FilenameSearch; v != nil { - where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v)) + where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v)) } if v := find.MemoID; v != nil { - where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v) + where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v) } if len(find.MemoIDList) > 0 { placeholders := make([]string, 0, len(find.MemoIDList)) for range find.MemoIDList { placeholders = append(placeholders, "?") } - where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")") + where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")") for _, id := range find.MemoIDList { args = append(args, id) } } if find.HasRelatedMemo { - where = append(where, "`resource`.`memo_id` IS NOT NULL") + where = append(where, "`attachment`.`memo_id` IS NOT NULL") } if find.StorageType != nil { - where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String()) + where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String()) } if len(find.Filters) > 0 { @@ -88,28 +88,28 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([ } fields := []string{ - "`resource`.`id` AS `id`", - "`resource`.`uid` AS `uid`", - "`resource`.`filename` AS `filename`", - "`resource`.`type` AS `type`", - "`resource`.`size` AS `size`", - "`resource`.`creator_id` AS `creator_id`", - "`resource`.`created_ts` AS `created_ts`", - "`resource`.`updated_ts` AS `updated_ts`", - "`resource`.`memo_id` AS `memo_id`", - "`resource`.`storage_type` AS `storage_type`", - "`resource`.`reference` AS `reference`", - "`resource`.`payload` AS `payload`", + "`attachment`.`id` AS `id`", + "`attachment`.`uid` AS `uid`", + "`attachment`.`filename` AS `filename`", + "`attachment`.`type` AS `type`", + "`attachment`.`size` AS `size`", + "`attachment`.`creator_id` AS `creator_id`", + "`attachment`.`created_ts` AS `created_ts`", + "`attachment`.`updated_ts` AS `updated_ts`", + "`attachment`.`memo_id` AS `memo_id`", + "`attachment`.`storage_type` AS `storage_type`", + "`attachment`.`reference` AS `reference`", + "`attachment`.`payload` AS `payload`", "CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`", } if find.GetBlob { - fields = append(fields, "`resource`.`blob` AS `blob`") + fields = append(fields, "`attachment`.`blob` AS `blob`") } - query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " + - "LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " + + query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " + + "LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " + "WHERE " + strings.Join(where, " AND ") + " " + - "ORDER BY `resource`.`updated_ts` DESC" + "ORDER BY `attachment`.`updated_ts` DESC" if find.Limit != nil { query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) if find.Offset != nil { @@ -197,7 +197,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen } args = append(args, update.ID) - stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?" + stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?" result, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { return errors.Wrap(err, "failed to update attachment") @@ -209,7 +209,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen } func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error { - stmt := "DELETE FROM `resource` WHERE `id` = ?" + stmt := "DELETE FROM `attachment` WHERE `id` = ?" result, err := d.db.ExecContext(ctx, stmt, delete.ID) if err != nil { return err diff --git a/store/db/sqlite/inbox.go b/store/db/sqlite/inbox.go index 2ab8e68d0..bb8decbc4 100644 --- a/store/db/sqlite/inbox.go +++ b/store/db/sqlite/inbox.go @@ -55,7 +55,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I if find.MessageType != nil { // Filter by message type using JSON extraction // Note: The type field in JSON is stored as string representation of the enum name - where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String()) + if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED { + where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String()) + } else { + where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String()) + } } query := "SELECT `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC" diff --git a/store/db/sqlite/memo.go b/store/db/sqlite/memo.go index f3bc2f54d..461d45df9 100644 --- a/store/db/sqlite/memo.go +++ b/store/db/sqlite/memo.go @@ -26,6 +26,18 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e } args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload} + // Add custom timestamps if provided + if create.CreatedTs != 0 { + fields = append(fields, "`created_ts`") + placeholder = append(placeholder, "?") + args = append(args, create.CreatedTs) + } + if create.UpdatedTs != 0 { + fields = append(fields, "`updated_ts`") + placeholder = append(placeholder, "?") + args = append(args, create.UpdatedTs) + } + stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`, `row_status`" if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( &create.ID, diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go deleted file mode 100644 index 70581b938..000000000 --- a/store/db/sqlite/memo_filter_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package sqlite - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/usememos/memos/plugin/filter" -) - -func TestConvertExprToSQL(t *testing.T) { - tests := []struct { - filter string - want string - args []any - }{ - { - filter: `tag in ["tag1", "tag2"]`, - want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))", - args: []any{`%"tag1"%`, `%"tag1/%`, `%"tag2"%`, `%"tag2/%`}, - }, - { - filter: `!(tag in ["tag1", "tag2"])`, - want: "NOT (((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)))", - args: []any{`%"tag1"%`, `%"tag1/%`, `%"tag2"%`, `%"tag2/%`}, - }, - { - filter: `content.contains("memos")`, - want: "`memo`.`content` LIKE ?", - args: []any{"%memos%"}, - }, - { - filter: `visibility in ["PUBLIC"]`, - want: "`memo`.`visibility` IN (?)", - args: []any{"PUBLIC"}, - }, - { - filter: `visibility in ["PUBLIC", "PRIVATE"]`, - want: "`memo`.`visibility` IN (?,?)", - args: []any{"PUBLIC", "PRIVATE"}, - }, - { - filter: `tag in ['tag1'] || content.contains('hello')`, - want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR `memo`.`content` LIKE ?)", - args: []any{`%"tag1"%`, `%"tag1/%`, "%hello%"}, - }, - { - filter: `1`, - want: "", - args: []any{}, - }, - { - filter: `pinned`, - want: "`memo`.`pinned` IS TRUE", - args: []any{}, - }, - { - filter: `has_task_list`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE", - args: []any{}, - }, - { - filter: `has_code`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE", - args: []any{}, - }, - { - filter: `has_task_list == true`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1", - args: []any{}, - }, - { - filter: `has_task_list != false`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0", - args: []any{}, - }, - { - filter: `has_task_list == false`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0", - args: []any{}, - }, - { - filter: `!has_task_list`, - want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE)", - args: []any{}, - }, - { - filter: `has_task_list && pinned`, - want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`pinned` IS TRUE)", - args: []any{}, - }, - { - filter: `has_task_list && content.contains("todo")`, - want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`content` LIKE ?)", - args: []any{"%todo%"}, - }, - { - filter: `created_ts > now() - 60 * 60 * 24`, - want: "`memo`.`created_ts` > ?", - args: []any{time.Now().Unix() - 60*60*24}, - }, - { - filter: `size(tags) == 0`, - want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", - args: []any{int64(0)}, - }, - { - filter: `size(tags) > 0`, - want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?", - args: []any{int64(0)}, - }, - { - filter: `"work" in tags`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?", - args: []any{`%"work"%`}, - }, - { - filter: `size(tags) == 2`, - want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", - args: []any{int64(2)}, - }, - { - filter: `has_link == true`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE", - args: []any{}, - }, - { - filter: `has_code == false`, - want: "NOT(JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE)", - args: []any{}, - }, - { - filter: `has_incomplete_tasks != false`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE", - args: []any{}, - }, - { - filter: `has_link`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE", - args: []any{}, - }, - { - filter: `has_code`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE", - args: []any{}, - }, - { - filter: `has_incomplete_tasks`, - want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE", - args: []any{}, - }, - } - - engine, err := filter.DefaultEngine() - require.NoError(t, err) - - for _, tt := range tests { - stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectSQLite}) - require.NoError(t, err) - require.Equal(t, tt.want, stmt.SQL) - require.Equal(t, tt.args, stmt.Args) - } -} diff --git a/store/db/sqlite/memo_relation.go b/store/db/sqlite/memo_relation.go index 3e63c7002..5eed62e74 100644 --- a/store/db/sqlite/memo_relation.go +++ b/store/db/sqlite/memo_relation.go @@ -17,6 +17,7 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) type ) VALUES (?, ?, ?) + ON CONFLICT(memo_id, related_memo_id, type) DO UPDATE SET type = excluded.type RETURNING memo_id, related_memo_id, type ` memoRelation := &store.MemoRelation{} diff --git a/store/db/sqlite/sqlite.go b/store/db/sqlite/sqlite.go index 3b4a30f8d..642e728cf 100644 --- a/store/db/sqlite/sqlite.go +++ b/store/db/sqlite/sqlite.go @@ -33,6 +33,7 @@ func NewDB(profile *profile.Profile) (store.Driver, error) { // good practice to be explicit and prevent future surprises on SQLite upgrades. // - Journal mode set to WAL: it's the recommended journal mode for most applications // as it prevents locking issues. + // - mmap size set to 0: it disables memory mapping, which can cause OOM errors on some systems. // // Notes: // - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`. @@ -41,7 +42,7 @@ func NewDB(profile *profile.Profile) (store.Driver, error) { // - https://pkg.go.dev/modernc.org/sqlite#Driver.Open // - https://www.sqlite.org/sharedcache.html // - https://www.sqlite.org/pragma.html - sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)") + sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)&_pragma=mmap_size(0)") if err != nil { return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN) } diff --git a/store/memo.go b/store/memo.go index afd71e29a..ce6cde28d 100644 --- a/store/memo.go +++ b/store/memo.go @@ -138,5 +138,22 @@ func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error { } func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { + // Clean up memo_relation records where this memo is either the source or target. + if err := s.driver.DeleteMemoRelation(ctx, &DeleteMemoRelation{MemoID: &delete.ID}); err != nil { + return err + } + if err := s.driver.DeleteMemoRelation(ctx, &DeleteMemoRelation{RelatedMemoID: &delete.ID}); err != nil { + return err + } + // Clean up attachments linked to this memo. + attachments, err := s.ListAttachments(ctx, &FindAttachment{MemoID: &delete.ID}) + if err != nil { + return err + } + for _, attachment := range attachments { + if err := s.DeleteAttachment(ctx, &DeleteAttachment{ID: attachment.ID}); err != nil { + return err + } + } return s.driver.DeleteMemo(ctx, delete) } diff --git a/store/migration/mysql/0.26/00__rename_resource_to_attachment.sql b/store/migration/mysql/0.26/00__rename_resource_to_attachment.sql new file mode 100644 index 000000000..703234fe3 --- /dev/null +++ b/store/migration/mysql/0.26/00__rename_resource_to_attachment.sql @@ -0,0 +1 @@ +RENAME TABLE resource TO attachment; diff --git a/store/migration/mysql/0.26/01__drop_memo_organizer.sql b/store/migration/mysql/0.26/01__drop_memo_organizer.sql new file mode 100644 index 000000000..17c3579f2 --- /dev/null +++ b/store/migration/mysql/0.26/01__drop_memo_organizer.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS memo_organizer; diff --git a/store/migration/mysql/0.26/02__migrate_host_to_admin.sql b/store/migration/mysql/0.26/02__migrate_host_to_admin.sql new file mode 100644 index 000000000..24ec82ac9 --- /dev/null +++ b/store/migration/mysql/0.26/02__migrate_host_to_admin.sql @@ -0,0 +1 @@ +UPDATE `user` SET `role` = 'ADMIN' WHERE `role` = 'HOST'; diff --git a/store/migration/mysql/LATEST.sql b/store/migration/mysql/LATEST.sql index adc86a9eb..a76d7111b 100644 --- a/store/migration/mysql/LATEST.sql +++ b/store/migration/mysql/LATEST.sql @@ -42,14 +42,6 @@ CREATE TABLE `memo` ( `payload` JSON NOT NULL ); --- memo_organizer -CREATE TABLE `memo_organizer` ( - `memo_id` INT NOT NULL, - `user_id` INT NOT NULL, - `pinned` INT NOT NULL DEFAULT '0', - UNIQUE(`memo_id`,`user_id`) -); - -- memo_relation CREATE TABLE `memo_relation` ( `memo_id` INT NOT NULL, @@ -58,8 +50,8 @@ CREATE TABLE `memo_relation` ( UNIQUE(`memo_id`,`related_memo_id`,`type`) ); --- resource -CREATE TABLE `resource` ( +-- attachment +CREATE TABLE `attachment` ( `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY, `uid` VARCHAR(256) NOT NULL UNIQUE, `creator_id` INT NOT NULL, diff --git a/store/migration/postgres/0.26/00__rename_resource_to_attachment.sql b/store/migration/postgres/0.26/00__rename_resource_to_attachment.sql new file mode 100644 index 000000000..9e0e4396e --- /dev/null +++ b/store/migration/postgres/0.26/00__rename_resource_to_attachment.sql @@ -0,0 +1 @@ +ALTER TABLE resource RENAME TO attachment; diff --git a/store/migration/postgres/0.26/01__drop_memo_organizer.sql b/store/migration/postgres/0.26/01__drop_memo_organizer.sql new file mode 100644 index 000000000..17c3579f2 --- /dev/null +++ b/store/migration/postgres/0.26/01__drop_memo_organizer.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS memo_organizer; diff --git a/store/migration/postgres/0.26/02__migrate_host_to_admin.sql b/store/migration/postgres/0.26/02__migrate_host_to_admin.sql new file mode 100644 index 000000000..bd6db8024 --- /dev/null +++ b/store/migration/postgres/0.26/02__migrate_host_to_admin.sql @@ -0,0 +1 @@ +UPDATE "user" SET role = 'ADMIN' WHERE role = 'HOST'; diff --git a/store/migration/postgres/LATEST.sql b/store/migration/postgres/LATEST.sql index b5b70a9ec..cbde126cd 100644 --- a/store/migration/postgres/LATEST.sql +++ b/store/migration/postgres/LATEST.sql @@ -42,14 +42,6 @@ CREATE TABLE memo ( payload JSONB NOT NULL DEFAULT '{}' ); --- memo_organizer -CREATE TABLE memo_organizer ( - memo_id INTEGER NOT NULL, - user_id INTEGER NOT NULL, - pinned INTEGER NOT NULL DEFAULT 0, - UNIQUE(memo_id, user_id) -); - -- memo_relation CREATE TABLE memo_relation ( memo_id INTEGER NOT NULL, @@ -58,8 +50,8 @@ CREATE TABLE memo_relation ( UNIQUE(memo_id, related_memo_id, type) ); --- resource -CREATE TABLE resource ( +-- attachment +CREATE TABLE attachment ( id SERIAL PRIMARY KEY, uid TEXT NOT NULL UNIQUE, creator_id INTEGER NOT NULL, diff --git a/store/migration/sqlite/0.26/00__rename_resource_to_attachment.sql b/store/migration/sqlite/0.26/00__rename_resource_to_attachment.sql new file mode 100644 index 000000000..151cd6d3a --- /dev/null +++ b/store/migration/sqlite/0.26/00__rename_resource_to_attachment.sql @@ -0,0 +1,5 @@ +ALTER TABLE `resource` RENAME TO `attachment`; +DROP INDEX IF EXISTS `idx_resource_creator_id`; +CREATE INDEX `idx_attachment_creator_id` ON `attachment` (`creator_id`); +DROP INDEX IF EXISTS `idx_resource_memo_id`; +CREATE INDEX `idx_attachment_memo_id` ON `attachment` (`memo_id`); diff --git a/store/migration/sqlite/0.26/01__drop_memo_organizer.sql b/store/migration/sqlite/0.26/01__drop_memo_organizer.sql new file mode 100644 index 000000000..17c3579f2 --- /dev/null +++ b/store/migration/sqlite/0.26/01__drop_memo_organizer.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS memo_organizer; diff --git a/store/migration/sqlite/0.26/02__drop_indexes.sql b/store/migration/sqlite/0.26/02__drop_indexes.sql new file mode 100644 index 000000000..2923ba4fa --- /dev/null +++ b/store/migration/sqlite/0.26/02__drop_indexes.sql @@ -0,0 +1,4 @@ +DROP INDEX IF EXISTS idx_user_username; +DROP INDEX IF EXISTS idx_memo_creator_id; +DROP INDEX IF EXISTS idx_attachment_creator_id; +DROP INDEX IF EXISTS idx_attachment_memo_id; diff --git a/store/migration/sqlite/0.26/03__alter_user_role.sql b/store/migration/sqlite/0.26/03__alter_user_role.sql new file mode 100644 index 000000000..863097541 --- /dev/null +++ b/store/migration/sqlite/0.26/03__alter_user_role.sql @@ -0,0 +1,24 @@ +ALTER TABLE user RENAME TO user_old; + +CREATE TABLE user ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), + updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + username TEXT NOT NULL UNIQUE, + role TEXT NOT NULL DEFAULT 'USER', + email TEXT NOT NULL DEFAULT '', + nickname TEXT NOT NULL DEFAULT '', + password_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '' +); + +INSERT INTO user ( + id, created_ts, updated_ts, row_status, username, role, email, nickname, password_hash, avatar_url, description +) +SELECT + id, created_ts, updated_ts, row_status, username, role, email, nickname, password_hash, avatar_url, description +FROM user_old; + +DROP TABLE user_old; diff --git a/store/migration/sqlite/0.26/04__migrate_host_to_admin.sql b/store/migration/sqlite/0.26/04__migrate_host_to_admin.sql new file mode 100644 index 000000000..3e3f850ed --- /dev/null +++ b/store/migration/sqlite/0.26/04__migrate_host_to_admin.sql @@ -0,0 +1 @@ +UPDATE user SET role = 'ADMIN' WHERE role = 'HOST'; diff --git a/store/migration/sqlite/LATEST.sql b/store/migration/sqlite/LATEST.sql index 6a36e9338..8b70fa68c 100644 --- a/store/migration/sqlite/LATEST.sql +++ b/store/migration/sqlite/LATEST.sql @@ -13,7 +13,7 @@ CREATE TABLE user ( updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', username TEXT NOT NULL UNIQUE, - role TEXT NOT NULL CHECK (role IN ('HOST', 'ADMIN', 'USER')) DEFAULT 'USER', + role TEXT NOT NULL DEFAULT 'USER', email TEXT NOT NULL DEFAULT '', nickname TEXT NOT NULL DEFAULT '', password_hash TEXT NOT NULL, @@ -21,8 +21,6 @@ CREATE TABLE user ( description TEXT NOT NULL DEFAULT '' ); -CREATE INDEX idx_user_username ON user (username); - -- user_setting CREATE TABLE user_setting ( user_id INTEGER NOT NULL, @@ -45,16 +43,6 @@ CREATE TABLE memo ( payload TEXT NOT NULL DEFAULT '{}' ); -CREATE INDEX idx_memo_creator_id ON memo (creator_id); - --- memo_organizer -CREATE TABLE memo_organizer ( - memo_id INTEGER NOT NULL, - user_id INTEGER NOT NULL, - pinned INTEGER NOT NULL CHECK (pinned IN (0, 1)) DEFAULT 0, - UNIQUE(memo_id, user_id) -); - -- memo_relation CREATE TABLE memo_relation ( memo_id INTEGER NOT NULL, @@ -63,8 +51,8 @@ CREATE TABLE memo_relation ( UNIQUE(memo_id, related_memo_id, type) ); --- resource -CREATE TABLE resource ( +-- attachment +CREATE TABLE attachment ( id INTEGER PRIMARY KEY AUTOINCREMENT, uid TEXT NOT NULL UNIQUE, creator_id INTEGER NOT NULL, @@ -80,10 +68,6 @@ CREATE TABLE resource ( payload TEXT NOT NULL DEFAULT '{}' ); -CREATE INDEX idx_resource_creator_id ON resource (creator_id); - -CREATE INDEX idx_resource_memo_id ON resource (memo_id); - -- activity CREATE TABLE activity ( id INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/store/migrator.go b/store/migrator.go index d5446fcab..1f2151209 100644 --- a/store/migrator.go +++ b/store/migrator.go @@ -21,7 +21,7 @@ import ( // Migration System Overview: // // The migration system handles database schema versioning and upgrades. -// Schema version is stored in instance_setting (formerly system_setting). +// Schema version is stored in system_setting. // // Migration Flow: // 1. preMigrate: Check if DB is initialized. If not, apply LATEST.sql @@ -30,9 +30,9 @@ import ( // 4. Migrate (demo mode): Seed database with demo data // // Version Tracking: -// - New installations: Schema version set in instance_setting immediately -// - Existing v0.22+ installations: Schema version tracked in instance_setting -// - Pre-v0.22 installations: Must upgrade to v0.25.x first (migration_history → instance_setting migration) +// - New installations: Schema version set in system_setting immediately +// - Existing v0.22+ installations: Schema version tracked in system_setting +// - Pre-v0.22 installations: Must upgrade to v0.25.x first (migration_history → system_setting migration) // // Migration Files: // - Location: store/migration/{driver}/{version}/NN__description.sql @@ -57,10 +57,6 @@ const ( // defaultSchemaVersion is used when schema version is empty or not set. // This handles edge cases for old installations without version tracking. defaultSchemaVersion = "0.0.0" - - // Mode constants for profile mode. - modeProd = "prod" - modeDemo = "demo" ) // getSchemaVersionOrDefault returns the schema version or default if empty. @@ -110,38 +106,36 @@ func (s *Store) Migrate(ctx context.Context) error { return errors.Wrap(err, "failed to pre-migrate") } - switch s.profile.Mode { - case modeProd: - instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx) - if err != nil { - return errors.Wrap(err, "failed to get instance basic setting") + instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx) + if err != nil { + return errors.Wrap(err, "failed to get instance basic setting") + } + currentSchemaVersion, err := s.GetCurrentSchemaVersion() + if err != nil { + return errors.Wrap(err, "failed to get current schema version") + } + // Check for downgrade (but skip if schema version is empty - that means fresh/old installation) + if !isVersionEmpty(instanceBasicSetting.SchemaVersion) && version.IsVersionGreaterThan(instanceBasicSetting.SchemaVersion, currentSchemaVersion) { + slog.Error("cannot downgrade schema version", + slog.String("databaseVersion", instanceBasicSetting.SchemaVersion), + slog.String("currentVersion", currentSchemaVersion), + ) + return errors.Errorf("cannot downgrade schema version from %s to %s", instanceBasicSetting.SchemaVersion, currentSchemaVersion) + } + // Apply migrations if needed (including when schema version is empty) + if isVersionEmpty(instanceBasicSetting.SchemaVersion) || version.IsVersionGreaterThan(currentSchemaVersion, instanceBasicSetting.SchemaVersion) { + if err := s.applyMigrations(ctx, instanceBasicSetting.SchemaVersion, currentSchemaVersion); err != nil { + return errors.Wrap(err, "failed to apply migrations") } - currentSchemaVersion, err := s.GetCurrentSchemaVersion() - if err != nil { - return errors.Wrap(err, "failed to get current schema version") - } - // Check for downgrade (but skip if schema version is empty - that means fresh/old installation) - if !isVersionEmpty(instanceBasicSetting.SchemaVersion) && version.IsVersionGreaterThan(instanceBasicSetting.SchemaVersion, currentSchemaVersion) { - slog.Error("cannot downgrade schema version", - slog.String("databaseVersion", instanceBasicSetting.SchemaVersion), - slog.String("currentVersion", currentSchemaVersion), - ) - return errors.Errorf("cannot downgrade schema version from %s to %s", instanceBasicSetting.SchemaVersion, currentSchemaVersion) - } - // Apply migrations if needed (including when schema version is empty) - if isVersionEmpty(instanceBasicSetting.SchemaVersion) || version.IsVersionGreaterThan(currentSchemaVersion, instanceBasicSetting.SchemaVersion) { - if err := s.applyMigrations(ctx, instanceBasicSetting.SchemaVersion, currentSchemaVersion); err != nil { - return errors.Wrap(err, "failed to apply migrations") - } - } - case modeDemo: + } + + if s.profile.Demo { // In demo mode, we should seed the database. if err := s.seed(ctx); err != nil { return errors.Wrap(err, "failed to seed") } - default: - // For other modes (like dev), no special migration handling needed } + return nil } @@ -255,14 +249,11 @@ func (s *Store) preMigrate(ctx context.Context) error { } } - if s.profile.Mode == modeProd { - if err := s.checkMinimumUpgradeVersion(ctx); err != nil { - return err // Error message is already descriptive, don't wrap it - } + if err := s.checkMinimumUpgradeVersion(ctx); err != nil { + return err // Error message is already descriptive, don't wrap it } return nil } - func (s *Store) getMigrationBasePath() string { return fmt.Sprintf("migration/%s/", s.profile.Driver) } @@ -308,7 +299,7 @@ func (s *Store) seed(ctx context.Context) error { } func (s *Store) GetCurrentSchemaVersion() (string, error) { - currentVersion := version.GetCurrentVersion(s.profile.Mode) + currentVersion := version.GetCurrentVersion() minorVersion := version.GetMinorVersion(currentVersion) filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion)) if err != nil { @@ -373,7 +364,7 @@ func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion st // checkMinimumUpgradeVersion verifies the installation meets minimum version requirements for upgrade. // For very old installations (< v0.22.0), users must upgrade to v0.25.x first before upgrading to current version. -// This is necessary because schema version tracking was moved from migration_history to instance_setting in v0.22.0. +// This is necessary because schema version tracking was moved from migration_history to system_setting in v0.22.0. func (s *Store) checkMinimumUpgradeVersion(ctx context.Context) error { instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx) if err != nil { @@ -401,7 +392,7 @@ func (s *Store) checkMinimumUpgradeVersion(ctx context.Context) error { "2. Start the server and verify it works\n"+ "3. Then upgrade to the latest version\n\n"+ "This is required because schema version tracking was moved from migration_history\n"+ - "to instance_setting in v0.22.0. The intermediate upgrade handles this migration safely.", + "to system_setting in v0.22.0. The intermediate upgrade handles this migration safely.", schemaVersion, currentVersion, ) diff --git a/store/seed/DEMO_DATA_GUIDE.md b/store/seed/DEMO_DATA_GUIDE.md index 81b46aef7..afd0ea4de 100644 --- a/store/seed/DEMO_DATA_GUIDE.md +++ b/store/seed/DEMO_DATA_GUIDE.md @@ -10,7 +10,7 @@ The demo data includes **6 carefully selected memos** that showcase the key feat - **Username**: `demo` - **Password**: `secret` (default password) -- **Role**: HOST +- **Role**: ADMIN - **Nickname**: Demo User ## Demo Memos (6 total) @@ -174,10 +174,10 @@ To run with demo data: ```bash # Start in demo mode -go run ./cmd/memos --mode demo --port 8081 +go run ./cmd/memos --demo --port 8081 # Or use the binary -./memos --mode demo +./memos --demo # Demo database location ./build/memos_demo.db @@ -198,7 +198,7 @@ Login with: - All memos are set to PUBLIC visibility - **Two memos are pinned**: Welcome (#1) and Sponsor (#6) -- User has HOST role to showcase all features +- User has ADMIN role to showcase all features - Reactions are distributed across memos - One memo relation demonstrates linking - Content is optimized for the compact markdown styles diff --git a/store/seed/sqlite/00__reset.sql b/store/seed/sqlite/00__reset.sql deleted file mode 100644 index 65e32e7b6..000000000 --- a/store/seed/sqlite/00__reset.sql +++ /dev/null @@ -1,11 +0,0 @@ -DELETE FROM system_setting; -DELETE FROM user; -DELETE FROM user_setting; -DELETE FROM memo; -DELETE FROM memo_organizer; -DELETE FROM memo_relation; -DELETE FROM resource; -DELETE FROM activity; -DELETE FROM idp; -DELETE FROM inbox; -DELETE FROM reaction; diff --git a/store/seed/sqlite/01__dump.sql b/store/seed/sqlite/01__dump.sql index 457703998..0fa8fe41d 100644 --- a/store/seed/sqlite/01__dump.sql +++ b/store/seed/sqlite/01__dump.sql @@ -1,5 +1,5 @@ -- Demo User -INSERT INTO user (id,username,role,nickname,password_hash) VALUES(1,'demo','HOST','Demo User','$2a$10$c.slEVgf5b/3BnAWlLb/vOu7VVSOKJ4ljwMe9xzlx9IhKnvAsJYM6'); +INSERT INTO user (id,username,role,nickname,password_hash) VALUES(1,'demo','ADMIN','Demo User','$2a$10$c.slEVgf5b/3BnAWlLb/vOu7VVSOKJ4ljwMe9xzlx9IhKnvAsJYM6'); -- Welcome Memo (Pinned) INSERT INTO memo (id,uid,creator_id,content,visibility,pinned,payload) VALUES(1,'welcome2memos001',1,replace('# Welcome to Memos!\\n\\nA privacy-first, lightweight note-taking service. Easily capture and share your great thoughts.\\n\\n## Key Features\\n\\n- **Privacy First**: Your data stays with you\\n- **Markdown Support**: Full CommonMark + GFM syntax\\n- **Quick Capture**: Jot down thoughts instantly\\n- **Organize with Tags**: Use #tags to categorize\\n- **Open Source**: Free and open source software\\n\\n---\\n\\nStart exploring the demo memos below to see what you can do! #welcome #getting-started','\\n',char(10)),'PUBLIC',1,'{"tags":["welcome","getting-started"],"property":{"hasLink":false}}'); diff --git a/store/test/activity_test.go b/store/test/activity_test.go index 1328199e8..20eecca52 100644 --- a/store/test/activity_test.go +++ b/store/test/activity_test.go @@ -11,6 +11,7 @@ import ( ) func TestActivityStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -34,6 +35,7 @@ func TestActivityStore(t *testing.T) { } func TestActivityGetByID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -63,6 +65,7 @@ func TestActivityGetByID(t *testing.T) { } func TestActivityListMultiple(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -99,3 +102,279 @@ func TestActivityListMultiple(t *testing.T) { ts.Close() } + +func TestActivityListByType(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create activities with MEMO_COMMENT type + _, err = ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + + _, err = ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + + // List by type + activityType := store.ActivityTypeMemoComment + activities, err := ts.ListActivities(ctx, &store.FindActivity{Type: &activityType}) + require.NoError(t, err) + require.Len(t, activities, 2) + for _, activity := range activities { + require.Equal(t, store.ActivityTypeMemoComment, activity.Type) + } + + ts.Close() +} + +func TestActivityPayloadMemoComment(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create activity with MemoComment payload + memoID := int32(123) + relatedMemoID := int32(456) + activity, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{ + MemoComment: &storepb.ActivityMemoCommentPayload{ + MemoId: memoID, + RelatedMemoId: relatedMemoID, + }, + }, + }) + require.NoError(t, err) + require.NotNil(t, activity.Payload) + require.NotNil(t, activity.Payload.MemoComment) + require.Equal(t, memoID, activity.Payload.MemoComment.MemoId) + require.Equal(t, relatedMemoID, activity.Payload.MemoComment.RelatedMemoId) + + // Verify payload is preserved when listing + found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID}) + require.NoError(t, err) + require.NotNil(t, found.Payload.MemoComment) + require.Equal(t, memoID, found.Payload.MemoComment.MemoId) + require.Equal(t, relatedMemoID, found.Payload.MemoComment.RelatedMemoId) + + ts.Close() +} + +func TestActivityEmptyPayload(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create activity with empty payload + activity, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + require.NotNil(t, activity.Payload) + + // Verify empty payload is handled correctly + found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID}) + require.NoError(t, err) + require.NotNil(t, found.Payload) + require.Nil(t, found.Payload.MemoComment) + + ts.Close() +} + +func TestActivityLevel(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create activity with INFO level + activity, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + require.Equal(t, store.ActivityLevelInfo, activity.Level) + + // Verify level is preserved when listing + found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID}) + require.NoError(t, err) + require.Equal(t, store.ActivityLevelInfo, found.Level) + + ts.Close() +} + +func TestActivityCreatorID(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user1, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser) + require.NoError(t, err) + + // Create activity for user1 + activity1, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user1.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + require.Equal(t, user1.ID, activity1.CreatorID) + + // Create activity for user2 + activity2, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user2.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + require.Equal(t, user2.ID, activity2.CreatorID) + + // List all and verify creator IDs + activities, err := ts.ListActivities(ctx, &store.FindActivity{}) + require.NoError(t, err) + require.Len(t, activities, 2) + + ts.Close() +} + +func TestActivityCreatedTs(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + activity, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + require.NotZero(t, activity.CreatedTs) + + // Verify timestamp is preserved when listing + found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID}) + require.NoError(t, err) + require.Equal(t, activity.CreatedTs, found.CreatedTs) + + ts.Close() +} + +func TestActivityListEmpty(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // List activities when none exist + activities, err := ts.ListActivities(ctx, &store.FindActivity{}) + require.NoError(t, err) + require.Len(t, activities, 0) + + ts.Close() +} + +func TestActivityListWithIDAndType(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + activity, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + }) + require.NoError(t, err) + + // List with both ID and Type filters + activityType := store.ActivityTypeMemoComment + activities, err := ts.ListActivities(ctx, &store.FindActivity{ + ID: &activity.ID, + Type: &activityType, + }) + require.NoError(t, err) + require.Len(t, activities, 1) + require.Equal(t, activity.ID, activities[0].ID) + + ts.Close() +} + +func TestActivityPayloadComplexMemoComment(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create a memo first to use its ID + memo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "test-memo-for-activity", + CreatorID: user.ID, + Content: "Test memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create comment memo + commentMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "comment-memo", + CreatorID: user.ID, + Content: "This is a comment", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create activity with real memo IDs + activity, err := ts.CreateActivity(ctx, &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{ + MemoComment: &storepb.ActivityMemoCommentPayload{ + MemoId: memo.ID, + RelatedMemoId: commentMemo.ID, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, memo.ID, activity.Payload.MemoComment.MemoId) + require.Equal(t, commentMemo.ID, activity.Payload.MemoComment.RelatedMemoId) + + // Verify payload is preserved + found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID}) + require.NoError(t, err) + require.Equal(t, memo.ID, found.Payload.MemoComment.MemoId) + require.Equal(t, commentMemo.ID, found.Payload.MemoComment.RelatedMemoId) + + ts.Close() +} diff --git a/store/test/attachment_filter_test.go b/store/test/attachment_filter_test.go index 3ae64b024..a2f6c6af3 100644 --- a/store/test/attachment_filter_test.go +++ b/store/test/attachment_filter_test.go @@ -13,6 +13,7 @@ import ( // ============================================================================= func TestAttachmentFilterFilenameContains(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -35,6 +36,7 @@ func TestAttachmentFilterFilenameContains(t *testing.T) { } func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -51,6 +53,7 @@ func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) { } func TestAttachmentFilterFilenameUnicode(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -67,6 +70,7 @@ func TestAttachmentFilterFilenameUnicode(t *testing.T) { // ============================================================================= func TestAttachmentFilterMimeTypeEquals(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -86,6 +90,7 @@ func TestAttachmentFilterMimeTypeEquals(t *testing.T) { } func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -98,6 +103,7 @@ func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) { } func TestAttachmentFilterMimeTypeInList(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -121,6 +127,7 @@ func TestAttachmentFilterMimeTypeInList(t *testing.T) { // ============================================================================= func TestAttachmentFilterCreateTimeComparison(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -141,6 +148,7 @@ func TestAttachmentFilterCreateTimeComparison(t *testing.T) { } func TestAttachmentFilterCreateTimeWithNow(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -156,6 +164,7 @@ func TestAttachmentFilterCreateTimeWithNow(t *testing.T) { } func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -175,6 +184,7 @@ func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) { } func TestAttachmentFilterAllComparisonOperators(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -203,6 +213,7 @@ func TestAttachmentFilterAllComparisonOperators(t *testing.T) { // ============================================================================= func TestAttachmentFilterMemoIdEquals(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContextWithUser(t) defer tc.Close() @@ -218,6 +229,7 @@ func TestAttachmentFilterMemoIdEquals(t *testing.T) { } func TestAttachmentFilterMemoIdNotEquals(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContextWithUser(t) defer tc.Close() @@ -238,6 +250,7 @@ func TestAttachmentFilterMemoIdNotEquals(t *testing.T) { // ============================================================================= func TestAttachmentFilterLogicalAnd(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -251,6 +264,7 @@ func TestAttachmentFilterLogicalAnd(t *testing.T) { } func TestAttachmentFilterLogicalOr(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -263,6 +277,7 @@ func TestAttachmentFilterLogicalOr(t *testing.T) { } func TestAttachmentFilterLogicalNot(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -275,6 +290,7 @@ func TestAttachmentFilterLogicalNot(t *testing.T) { } func TestAttachmentFilterComplexLogical(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -291,6 +307,7 @@ func TestAttachmentFilterComplexLogical(t *testing.T) { // ============================================================================= func TestAttachmentFilterMultipleFilters(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -310,6 +327,7 @@ func TestAttachmentFilterMultipleFilters(t *testing.T) { // ============================================================================= func TestAttachmentFilterNoMatches(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() @@ -320,6 +338,7 @@ func TestAttachmentFilterNoMatches(t *testing.T) { } func TestAttachmentFilterNullMemoId(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContextWithUser(t) defer tc.Close() @@ -343,6 +362,7 @@ func TestAttachmentFilterNullMemoId(t *testing.T) { } func TestAttachmentFilterEmptyFilename(t *testing.T) { + t.Parallel() tc := NewAttachmentFilterTestContext(t) defer tc.Close() diff --git a/store/test/attachment_test.go b/store/test/attachment_test.go index 1886f75d5..12cb23be2 100644 --- a/store/test/attachment_test.go +++ b/store/test/attachment_test.go @@ -12,6 +12,7 @@ import ( ) func TestAttachmentStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) _, err := ts.CreateAttachment(ctx, &store.Attachment{ @@ -64,6 +65,7 @@ func TestAttachmentStore(t *testing.T) { } func TestAttachmentStoreWithFilter(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -123,6 +125,7 @@ func TestAttachmentStoreWithFilter(t *testing.T) { } func TestAttachmentUpdate(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -153,6 +156,7 @@ func TestAttachmentUpdate(t *testing.T) { } func TestAttachmentGetByUID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -183,6 +187,7 @@ func TestAttachmentGetByUID(t *testing.T) { } func TestAttachmentListWithPagination(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -222,6 +227,7 @@ func TestAttachmentListWithPagination(t *testing.T) { } func TestAttachmentInvalidUID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) diff --git a/store/test/containers.go b/store/test/containers.go index bd65ea40e..8ad5d20dc 100644 --- a/store/test/containers.go +++ b/store/test/containers.go @@ -4,16 +4,19 @@ import ( "context" "database/sql" "fmt" + "os" "strings" "sync" "sync/atomic" "testing" "time" + "github.com/docker/docker/api/types/container" "github.com/pkg/errors" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/network" "github.com/testcontainers/testcontainers-go/wait" // Database drivers for connection verification. @@ -24,23 +27,51 @@ import ( const ( testUser = "root" testPassword = "test" + + // Memos container settings for migration testing. + MemosDockerImage = "neosmemo/memos" + StableMemosVersion = "stable" // Always points to the latest stable release ) var ( - mysqlContainer *mysql.MySQLContainer - postgresContainer *postgres.PostgresContainer + mysqlContainer atomic.Pointer[mysql.MySQLContainer] + postgresContainer atomic.Pointer[postgres.PostgresContainer] mysqlOnce sync.Once postgresOnce sync.Once - mysqlBaseDSN string - postgresBaseDSN string + mysqlBaseDSN atomic.Value // stores string + postgresBaseDSN atomic.Value // stores string dbCounter atomic.Int64 + dbCreationMutex sync.Mutex // Protects database creation operations + + // Network for container communication. + testDockerNetwork atomic.Pointer[testcontainers.DockerNetwork] + testNetworkOnce sync.Once ) +// getTestNetwork creates or returns the shared Docker network for container communication. +func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) { + var networkErr error + testNetworkOnce.Do(func() { + nw, err := network.New(ctx, network.WithDriver("bridge")) + if err != nil { + networkErr = err + return + } + testDockerNetwork.Store(nw) + }) + return testDockerNetwork.Load(), networkErr +} + // GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test. func GetMySQLDSN(t *testing.T) string { ctx := context.Background() mysqlOnce.Do(func() { + nw, err := getTestNetwork(ctx) + if err != nil { + t.Fatalf("failed to create test network: %v", err) + } + container, err := mysql.Run(ctx, "mysql:8", mysql.WithDatabase("init_db"), @@ -55,11 +86,12 @@ func GetMySQLDSN(t *testing.T) string { wait.ForListeningPort("3306/tcp"), ).WithDeadline(120*time.Second), ), + network.WithNetwork(nil, nw), ) if err != nil { t.Fatalf("failed to start MySQL container: %v", err) } - mysqlContainer = container + mysqlContainer.Store(container) dsn, err := container.ConnectionString(ctx, "multiStatements=true") if err != nil { @@ -70,16 +102,21 @@ func GetMySQLDSN(t *testing.T) string { t.Fatalf("MySQL not ready for connections: %v", err) } - mysqlBaseDSN = dsn + mysqlBaseDSN.Store(dsn) }) - if mysqlBaseDSN == "" { + dsn, ok := mysqlBaseDSN.Load().(string) + if !ok || dsn == "" { t.Fatal("MySQL container failed to start in a previous test") } + // Serialize database creation to avoid "table already exists" race conditions + dbCreationMutex.Lock() + defer dbCreationMutex.Unlock() + // Create a fresh database for this test dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1)) - db, err := sql.Open("mysql", mysqlBaseDSN) + db, err := sql.Open("mysql", dsn) if err != nil { t.Fatalf("failed to connect to MySQL: %v", err) } @@ -90,7 +127,7 @@ func GetMySQLDSN(t *testing.T) string { } // Return DSN pointing to the new database - return strings.Replace(mysqlBaseDSN, "/init_db?", "/"+dbName+"?", 1) + return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1) } // waitForDB polls the database until it's ready or timeout is reached. @@ -130,6 +167,11 @@ func GetPostgresDSN(t *testing.T) string { ctx := context.Background() postgresOnce.Do(func() { + nw, err := getTestNetwork(ctx) + if err != nil { + t.Fatalf("failed to create test network: %v", err) + } + container, err := postgres.Run(ctx, "postgres:18", postgres.WithDatabase("init_db"), @@ -141,11 +183,12 @@ func GetPostgresDSN(t *testing.T) string { wait.ForListeningPort("5432/tcp"), ).WithDeadline(120*time.Second), ), + network.WithNetwork(nil, nw), ) if err != nil { t.Fatalf("failed to start PostgreSQL container: %v", err) } - postgresContainer = container + postgresContainer.Store(container) dsn, err := container.ConnectionString(ctx, "sslmode=disable") if err != nil { @@ -156,16 +199,21 @@ func GetPostgresDSN(t *testing.T) string { t.Fatalf("PostgreSQL not ready for connections: %v", err) } - postgresBaseDSN = dsn + postgresBaseDSN.Store(dsn) }) - if postgresBaseDSN == "" { + dsn, ok := postgresBaseDSN.Load().(string) + if !ok || dsn == "" { t.Fatal("PostgreSQL container failed to start in a previous test") } + // Serialize database creation to avoid "table already exists" race conditions + dbCreationMutex.Lock() + defer dbCreationMutex.Unlock() + // Create a fresh database for this test dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1)) - db, err := sql.Open("postgres", postgresBaseDSN) + db, err := sql.Open("postgres", dsn) if err != nil { t.Fatalf("failed to connect to PostgreSQL: %v", err) } @@ -176,17 +224,94 @@ func GetPostgresDSN(t *testing.T) string { } // Return DSN pointing to the new database - return strings.Replace(postgresBaseDSN, "/init_db?", "/"+dbName+"?", 1) + return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1) } -// TerminateContainers cleans up all running containers. +// TerminateContainers cleans up all running containers and network. // This is typically called from TestMain. func TerminateContainers() { ctx := context.Background() - if mysqlContainer != nil { - _ = mysqlContainer.Terminate(ctx) + if container := mysqlContainer.Load(); container != nil { + _ = container.Terminate(ctx) } - if postgresContainer != nil { - _ = postgresContainer.Terminate(ctx) + if container := postgresContainer.Load(); container != nil { + _ = container.Terminate(ctx) + } + if network := testDockerNetwork.Load(); network != nil { + _ = network.Remove(ctx) } } + +// MemosContainerConfig holds configuration for starting a Memos container. +type MemosContainerConfig struct { + Version string // Memos version tag (e.g., "0.24.0") + Driver string // Database driver: sqlite, mysql, postgres + DSN string // Database DSN (for mysql/postgres) + DataDir string // Host directory to mount for SQLite data +} + +// MemosStartupWaitStrategy defines the wait strategy for Memos container startup. +// Uses regex to match various log message formats across versions. +var MemosStartupWaitStrategy = wait.ForAll( + wait.ForLog("(started successfully|has been started on port)").AsRegexp(), + wait.ForListeningPort("5230/tcp"), +).WithDeadline(180 * time.Second) + +// StartMemosContainer starts a Memos container for migration testing. +// For SQLite, it mounts the dataDir to /var/opt/memos. +func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcontainers.Container, error) { + env := map[string]string{ + "MEMOS_MODE": "prod", + } + + var opts []testcontainers.ContainerCustomizer + + switch cfg.Driver { + case "sqlite": + env["MEMOS_DRIVER"] = "sqlite" + opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) { + hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos")) + })) + default: + return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver) + } + + req := testcontainers.ContainerRequest{ + Image: fmt.Sprintf("%s:%s", MemosDockerImage, cfg.Version), + Env: env, + ExposedPorts: []string{"5230/tcp"}, + WaitingFor: MemosStartupWaitStrategy, + User: fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), + } + + // Use local image if specified + if cfg.Version == "local" { + if os.Getenv("MEMOS_TEST_IMAGE_BUILT") == "1" { + req.Image = "memos-test:local" + } else { + req.FromDockerfile = testcontainers.FromDockerfile{ + Context: "../../", + Dockerfile: "Dockerfile", + } + } + } + + genericReq := testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + } + + // Apply options + for _, opt := range opts { + if err := opt.Customize(&genericReq); err != nil { + return nil, errors.Wrap(err, "failed to apply container option") + } + } + + ctr, err := testcontainers.GenericContainer(ctx, genericReq) + if err != nil { + return nil, errors.Wrap(err, "failed to start memos container") + } + + return ctr, nil +} diff --git a/store/test/idp_test.go b/store/test/idp_test.go index 0522454f8..8f2f1958b 100644 --- a/store/test/idp_test.go +++ b/store/test/idp_test.go @@ -11,6 +11,7 @@ import ( ) func TestIdentityProviderStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ @@ -58,3 +59,396 @@ func TestIdentityProviderStore(t *testing.T) { require.Equal(t, 0, len(idpList)) ts.Close() } + +func TestIdentityProviderGetByID(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Create IDP + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + require.NoError(t, err) + + // Get by ID + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.NotNil(t, found) + require.Equal(t, idp.Id, found.Id) + require.Equal(t, idp.Name, found.Name) + + // Get by non-existent ID + nonExistentID := int32(99999) + notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID}) + require.NoError(t, err) + require.Nil(t, notFound) + + ts.Close() +} + +func TestIdentityProviderListMultiple(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Create multiple IDPs + _, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth")) + require.NoError(t, err) + _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth")) + require.NoError(t, err) + _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth")) + require.NoError(t, err) + + // List all + idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) + require.NoError(t, err) + require.Len(t, idpList, 3) + + ts.Close() +} + +func TestIdentityProviderListByID(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Create multiple IDPs + idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth")) + require.NoError(t, err) + _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth")) + require.NoError(t, err) + + // List by specific ID + idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{ID: &idp1.Id}) + require.NoError(t, err) + require.Len(t, idpList, 1) + require.Equal(t, "GitHub OAuth", idpList[0].Name) + + ts.Close() +} + +func TestIdentityProviderUpdateName(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name")) + require.NoError(t, err) + require.Equal(t, "Original Name", idp.Name) + + // Update name + newName := "Updated Name" + updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{ + ID: idp.Id, + Type: storepb.IdentityProvider_OAUTH2, + Name: &newName, + }) + require.NoError(t, err) + require.Equal(t, "Updated Name", updated.Name) + + // Verify update persisted + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.Equal(t, "Updated Name", found.Name) + + ts.Close() +} + +func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + require.NoError(t, err) + require.Equal(t, "", idp.IdentifierFilter) + + // Update identifier filter + newFilter := "@example.com$" + updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{ + ID: idp.Id, + Type: storepb.IdentityProvider_OAUTH2, + IdentifierFilter: &newFilter, + }) + require.NoError(t, err) + require.Equal(t, "@example.com$", updated.IdentifierFilter) + + // Verify update persisted + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.Equal(t, "@example.com$", found.IdentifierFilter) + + ts.Close() +} + +func TestIdentityProviderUpdateConfig(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + require.NoError(t, err) + + // Update config + newConfig := &storepb.IdentityProviderConfig{ + Config: &storepb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &storepb.OAuth2Config{ + ClientId: "new_client_id", + ClientSecret: "new_client_secret", + AuthUrl: "https://newprovider.com/auth", + TokenUrl: "https://newprovider.com/token", + UserInfoUrl: "https://newprovider.com/user", + Scopes: []string{"openid", "profile", "email"}, + FieldMapping: &storepb.FieldMapping{ + Identifier: "sub", + DisplayName: "name", + Email: "email", + }, + }, + }, + } + updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{ + ID: idp.Id, + Type: storepb.IdentityProvider_OAUTH2, + Config: newConfig, + }) + require.NoError(t, err) + require.Equal(t, "new_client_id", updated.Config.GetOauth2Config().ClientId) + require.Equal(t, "new_client_secret", updated.Config.GetOauth2Config().ClientSecret) + require.Contains(t, updated.Config.GetOauth2Config().Scopes, "openid") + + // Verify update persisted + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.Equal(t, "new_client_id", found.Config.GetOauth2Config().ClientId) + + ts.Close() +} + +func TestIdentityProviderUpdateMultipleFields(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original")) + require.NoError(t, err) + + // Update multiple fields at once + newName := "Updated IDP" + newFilter := "^admin@" + updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{ + ID: idp.Id, + Type: storepb.IdentityProvider_OAUTH2, + Name: &newName, + IdentifierFilter: &newFilter, + }) + require.NoError(t, err) + require.Equal(t, "Updated IDP", updated.Name) + require.Equal(t, "^admin@", updated.IdentifierFilter) + + ts.Close() +} + +func TestIdentityProviderDelete(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + require.NoError(t, err) + + // Delete + err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id}) + require.NoError(t, err) + + // Verify deletion + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.Nil(t, found) + + ts.Close() +} + +func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Create multiple IDPs + idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1")) + require.NoError(t, err) + idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2")) + require.NoError(t, err) + + // Delete first one + err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp1.Id}) + require.NoError(t, err) + + // Verify second still exists + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp2.Id}) + require.NoError(t, err) + require.NotNil(t, found) + require.Equal(t, "IDP 2", found.Name) + + // Verify list only contains second + idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) + require.NoError(t, err) + require.Len(t, idpList, 1) + + ts.Close() +} + +func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Create IDP with multiple scopes + idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ + Name: "Multi-Scope OAuth", + Type: storepb.IdentityProvider_OAUTH2, + Config: &storepb.IdentityProviderConfig{ + Config: &storepb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &storepb.OAuth2Config{ + ClientId: "client_id", + ClientSecret: "client_secret", + AuthUrl: "https://provider.com/auth", + TokenUrl: "https://provider.com/token", + UserInfoUrl: "https://provider.com/userinfo", + Scopes: []string{"openid", "profile", "email", "groups"}, + FieldMapping: &storepb.FieldMapping{ + Identifier: "sub", + DisplayName: "name", + Email: "email", + }, + }, + }, + }, + }) + require.NoError(t, err) + + // Verify scopes are preserved + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.Len(t, found.Config.GetOauth2Config().Scopes, 4) + require.Contains(t, found.Config.GetOauth2Config().Scopes, "openid") + require.Contains(t, found.Config.GetOauth2Config().Scopes, "groups") + + ts.Close() +} + +func TestIdentityProviderFieldMapping(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Create IDP with custom field mapping + idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ + Name: "Custom Field Mapping", + Type: storepb.IdentityProvider_OAUTH2, + Config: &storepb.IdentityProviderConfig{ + Config: &storepb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &storepb.OAuth2Config{ + ClientId: "client_id", + ClientSecret: "client_secret", + AuthUrl: "https://provider.com/auth", + TokenUrl: "https://provider.com/token", + UserInfoUrl: "https://provider.com/userinfo", + Scopes: []string{"login"}, + FieldMapping: &storepb.FieldMapping{ + Identifier: "preferred_username", + DisplayName: "full_name", + Email: "email_address", + }, + }, + }, + }, + }) + require.NoError(t, err) + + // Verify field mapping is preserved + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.Equal(t, "preferred_username", found.Config.GetOauth2Config().FieldMapping.Identifier) + require.Equal(t, "full_name", found.Config.GetOauth2Config().FieldMapping.DisplayName) + require.Equal(t, "email_address", found.Config.GetOauth2Config().FieldMapping.Email) + + ts.Close() +} + +func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + testCases := []struct { + name string + filter string + }{ + {"Domain filter", "@company\\.com$"}, + {"Prefix filter", "^admin_"}, + {"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"}, + {"Empty filter", ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ + Name: tc.name, + Type: storepb.IdentityProvider_OAUTH2, + IdentifierFilter: tc.filter, + Config: &storepb.IdentityProviderConfig{ + Config: &storepb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &storepb.OAuth2Config{ + ClientId: "client_id", + ClientSecret: "client_secret", + AuthUrl: "https://provider.com/auth", + TokenUrl: "https://provider.com/token", + UserInfoUrl: "https://provider.com/userinfo", + Scopes: []string{"login"}, + FieldMapping: &storepb.FieldMapping{ + Identifier: "sub", + }, + }, + }, + }, + }) + require.NoError(t, err) + + found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id}) + require.NoError(t, err) + require.Equal(t, tc.filter, found.IdentifierFilter) + + // Cleanup + err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id}) + require.NoError(t, err) + }) + } + + ts.Close() +} + +// Helper function to create a test OAuth2 IDP. +func createTestOAuth2IDP(name string) *storepb.IdentityProvider { + return &storepb.IdentityProvider{ + Name: name, + Type: storepb.IdentityProvider_OAUTH2, + IdentifierFilter: "", + Config: &storepb.IdentityProviderConfig{ + Config: &storepb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &storepb.OAuth2Config{ + ClientId: "client_id", + ClientSecret: "client_secret", + AuthUrl: "https://provider.com/auth", + TokenUrl: "https://provider.com/token", + UserInfoUrl: "https://provider.com/userinfo", + Scopes: []string{"login"}, + FieldMapping: &storepb.FieldMapping{ + Identifier: "login", + DisplayName: "name", + Email: "email", + }, + }, + }, + }, + } +} diff --git a/store/test/inbox_test.go b/store/test/inbox_test.go index 0c74bc104..8af2c7082 100644 --- a/store/test/inbox_test.go +++ b/store/test/inbox_test.go @@ -11,6 +11,7 @@ import ( ) func TestInboxStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -52,3 +53,540 @@ func TestInboxStore(t *testing.T) { require.Equal(t, 0, len(inboxes)) ts.Close() } + +func TestInboxListByID(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + inbox, err := ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // List by ID + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, inbox.ID, inboxes[0].ID) + + // List by non-existent ID + nonExistentID := int32(99999) + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ID: &nonExistentID}) + require.NoError(t, err) + require.Len(t, inboxes, 0) + + ts.Close() +} + +func TestInboxListBySenderID(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user1, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser) + require.NoError(t, err) + + // Create inbox from system bot (senderID = 0) + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user1.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // Create inbox from user2 + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: user2.ID, + ReceiverID: user1.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // List by sender ID = user2 + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{SenderID: &user2.ID}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, user2.ID, inboxes[0].SenderID) + + // List by sender ID = 0 (system bot) + systemBotID := int32(0) + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{SenderID: &systemBotID}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, int32(0), inboxes[0].SenderID) + + ts.Close() +} + +func TestInboxListByStatus(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create UNREAD inbox + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // Create another inbox and archive it + inbox2, err := ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + _, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox2.ID, Status: store.ARCHIVED}) + require.NoError(t, err) + + // List by UNREAD status + unreadStatus := store.UNREAD + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{Status: &unreadStatus}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, store.UNREAD, inboxes[0].Status) + + // List by ARCHIVED status + archivedStatus := store.ARCHIVED + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{Status: &archivedStatus}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, store.ARCHIVED, inboxes[0].Status) + + ts.Close() +} + +func TestInboxListByMessageType(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create MEMO_COMMENT inboxes + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // List by MEMO_COMMENT type + memoCommentType := storepb.InboxMessage_MEMO_COMMENT + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{MessageType: &memoCommentType}) + require.NoError(t, err) + require.Len(t, inboxes, 2) + for _, inbox := range inboxes { + require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type) + } + + ts.Close() +} + +func TestInboxListPagination(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create 5 inboxes + for i := 0; i < 5; i++ { + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + } + + // Test Limit only + limit := 3 + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + Limit: &limit, + }) + require.NoError(t, err) + require.Len(t, inboxes, 3) + + // Test Limit + Offset (offset requires limit in the implementation) + limit = 2 + offset := 2 + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + Limit: &limit, + Offset: &offset, + }) + require.NoError(t, err) + require.Len(t, inboxes, 2) + + // Test Limit + Offset skipping to end + limit = 10 + offset = 3 + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + Limit: &limit, + Offset: &offset, + }) + require.NoError(t, err) + require.Len(t, inboxes, 2) // Only 2 remaining after offset of 3 + + ts.Close() +} + +func TestInboxListCombinedFilters(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user1, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser) + require.NoError(t, err) + + // Create various inboxes + // user2 -> user1, MEMO_COMMENT, UNREAD + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: user2.ID, + ReceiverID: user1.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // user2 -> user1, TYPE_UNSPECIFIED, UNREAD + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: user2.ID, + ReceiverID: user1.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED}, + }) + require.NoError(t, err) + + // system -> user1, MEMO_COMMENT, ARCHIVED + inbox3, err := ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user1.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + _, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox3.ID, Status: store.ARCHIVED}) + require.NoError(t, err) + + // Combined filter: ReceiverID + SenderID + Status + unreadStatus := store.UNREAD + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user1.ID, + SenderID: &user2.ID, + Status: &unreadStatus, + }) + require.NoError(t, err) + require.Len(t, inboxes, 2) + + // Combined filter: ReceiverID + MessageType + Status + memoCommentType := storepb.InboxMessage_MEMO_COMMENT + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user1.ID, + MessageType: &memoCommentType, + Status: &unreadStatus, + }) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, user2.ID, inboxes[0].SenderID) + + ts.Close() +} + +func TestInboxMessagePayload(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create inbox with message payload containing activity ID + activityID := int32(123) + inbox, err := ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{ + Type: storepb.InboxMessage_MEMO_COMMENT, + ActivityId: &activityID, + }, + }) + require.NoError(t, err) + require.NotNil(t, inbox.Message) + require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type) + require.Equal(t, activityID, *inbox.Message.ActivityId) + + // List and verify payload is preserved + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, activityID, *inboxes[0].Message.ActivityId) + + ts.Close() +} + +func TestInboxUpdateStatus(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + inbox, err := ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + require.Equal(t, store.UNREAD, inbox.Status) + + // Update to ARCHIVED + updated, err := ts.UpdateInbox(ctx, &store.UpdateInbox{ + ID: inbox.ID, + Status: store.ARCHIVED, + }) + require.NoError(t, err) + require.Equal(t, store.ARCHIVED, updated.Status) + require.Equal(t, inbox.ID, updated.ID) + + // Verify the update persisted + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, store.ARCHIVED, inboxes[0].Status) + + ts.Close() +} + +func TestInboxListByMessageTypeMultipleTypes(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create inboxes with different message types + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED}, + }) + require.NoError(t, err) + + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // Filter by MEMO_COMMENT - should get 2 + memoCommentType := storepb.InboxMessage_MEMO_COMMENT + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + MessageType: &memoCommentType, + }) + require.NoError(t, err) + require.Len(t, inboxes, 2) + for _, inbox := range inboxes { + require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type) + } + + // Filter by TYPE_UNSPECIFIED - should get 1 + unspecifiedType := storepb.InboxMessage_TYPE_UNSPECIFIED + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + MessageType: &unspecifiedType, + }) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, storepb.InboxMessage_TYPE_UNSPECIFIED, inboxes[0].Message.Type) + + ts.Close() +} + +func TestInboxMessageTypeFilterWithPayload(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create inbox with full payload + activityID := int32(456) + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{ + Type: storepb.InboxMessage_MEMO_COMMENT, + ActivityId: &activityID, + }, + }) + require.NoError(t, err) + + // Create inbox with different type but also has payload + otherActivityID := int32(789) + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{ + Type: storepb.InboxMessage_TYPE_UNSPECIFIED, + ActivityId: &otherActivityID, + }, + }) + require.NoError(t, err) + + // Filter by type should work correctly even with complex JSON payload + memoCommentType := storepb.InboxMessage_MEMO_COMMENT + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + MessageType: &memoCommentType, + }) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, activityID, *inboxes[0].Message.ActivityId) + + ts.Close() +} + +func TestInboxMessageTypeFilterWithStatusAndPagination(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create multiple inboxes with various combinations + for i := 0; i < 5; i++ { + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + } + + // Archive 2 of them + allInboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID}) + require.NoError(t, err) + for i := 0; i < 2; i++ { + _, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: allInboxes[i].ID, Status: store.ARCHIVED}) + require.NoError(t, err) + } + + // Filter by type + status + pagination + memoCommentType := storepb.InboxMessage_MEMO_COMMENT + unreadStatus := store.UNREAD + limit := 2 + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + MessageType: &memoCommentType, + Status: &unreadStatus, + Limit: &limit, + }) + require.NoError(t, err) + require.Len(t, inboxes, 2) + for _, inbox := range inboxes { + require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type) + require.Equal(t, store.UNREAD, inbox.Status) + } + + // Get next page + offset := 2 + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ + ReceiverID: &user.ID, + MessageType: &memoCommentType, + Status: &unreadStatus, + Limit: &limit, + Offset: &offset, + }) + require.NoError(t, err) + require.Len(t, inboxes, 1) // Only 1 remaining (3 unread total, got 2, now 1 left) + + ts.Close() +} + +func TestInboxMultipleReceivers(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user1, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser) + require.NoError(t, err) + + // Create inbox for user1 + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user1.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // Create inbox for user2 + _, err = ts.CreateInbox(ctx, &store.Inbox{ + SenderID: 0, + ReceiverID: user2.ID, + Status: store.UNREAD, + Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT}, + }) + require.NoError(t, err) + + // User1 should only see their inbox + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user1.ID}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, user1.ID, inboxes[0].ReceiverID) + + // User2 should only see their inbox + inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user2.ID}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.Equal(t, user2.ID, inboxes[0].ReceiverID) + + ts.Close() +} diff --git a/store/test/instance_setting_test.go b/store/test/instance_setting_test.go index 9ab072753..0f0c4cfb0 100644 --- a/store/test/instance_setting_test.go +++ b/store/test/instance_setting_test.go @@ -11,6 +11,7 @@ import ( ) func TestInstanceSettingV1Store(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) instanceSetting, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ @@ -31,6 +32,7 @@ func TestInstanceSettingV1Store(t *testing.T) { } func TestInstanceSettingGetNonExistent(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -45,6 +47,7 @@ func TestInstanceSettingGetNonExistent(t *testing.T) { } func TestInstanceSettingUpsertUpdate(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -88,6 +91,7 @@ func TestInstanceSettingUpsertUpdate(t *testing.T) { } func TestInstanceSettingBasicSetting(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -116,6 +120,7 @@ func TestInstanceSettingBasicSetting(t *testing.T) { } func TestInstanceSettingGeneralSetting(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -146,6 +151,7 @@ func TestInstanceSettingGeneralSetting(t *testing.T) { } func TestInstanceSettingMemoRelatedSetting(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -179,6 +185,7 @@ func TestInstanceSettingMemoRelatedSetting(t *testing.T) { } func TestInstanceSettingStorageSetting(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -214,6 +221,7 @@ func TestInstanceSettingStorageSetting(t *testing.T) { } func TestInstanceSettingListAll(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -246,3 +254,47 @@ func TestInstanceSettingListAll(t *testing.T) { ts.Close() } + +func TestInstanceSettingEdgeCases(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Case 1: General Setting with special characters and Unicode + specialScript := `` + specialStyle := `body { font-family: "Noto Sans SC", sans-serif; content: "\u2764"; }` + _, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_GENERAL, + Value: &storepb.InstanceSetting_GeneralSetting{ + GeneralSetting: &storepb.InstanceGeneralSetting{ + AdditionalScript: specialScript, + AdditionalStyle: specialStyle, + }, + }, + }) + require.NoError(t, err) + + generalSetting, err := ts.GetInstanceGeneralSetting(ctx) + require.NoError(t, err) + require.Equal(t, specialScript, generalSetting.AdditionalScript) + require.Equal(t, specialStyle, generalSetting.AdditionalStyle) + + // Case 2: Memo Related Setting with Unicode reactions + unicodeReactions := []string{"🐱", "🐶", "🦊", "🦄"} + _, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_MEMO_RELATED, + Value: &storepb.InstanceSetting_MemoRelatedSetting{ + MemoRelatedSetting: &storepb.InstanceMemoRelatedSetting{ + ContentLengthLimit: 1000, + Reactions: unicodeReactions, + }, + }, + }) + require.NoError(t, err) + + memoSetting, err := ts.GetInstanceMemoRelatedSetting(ctx) + require.NoError(t, err) + require.Equal(t, unicodeReactions, memoSetting.Reactions) + + ts.Close() +} diff --git a/store/test/main_test.go b/store/test/main_test.go index 97a632765..1a1139f19 100644 --- a/store/test/main_test.go +++ b/store/test/main_test.go @@ -13,7 +13,7 @@ func TestMain(m *testing.M) { // If DRIVER is set, run tests for that driver only if os.Getenv("DRIVER") != "" { defer TerminateContainers() - m.Run() + m.Run() //nolint:revive // Exit code is handled by test runner return } diff --git a/store/test/memo_filter_test.go b/store/test/memo_filter_test.go index 9f2520211..572086e79 100644 --- a/store/test/memo_filter_test.go +++ b/store/test/memo_filter_test.go @@ -16,6 +16,7 @@ import ( // ============================================================================= func TestMemoFilterContentContains(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -39,6 +40,7 @@ func TestMemoFilterContentContains(t *testing.T) { } func TestMemoFilterContentSpecialCharacters(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -49,6 +51,7 @@ func TestMemoFilterContentSpecialCharacters(t *testing.T) { } func TestMemoFilterContentUnicode(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -58,12 +61,35 @@ func TestMemoFilterContentUnicode(t *testing.T) { require.Len(t, memos, 1) } +func TestMemoFilterContentCaseSensitivity(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + tc.CreateMemo(NewMemoBuilder("memo-case", tc.User.ID).Content("MixedCase Content")) + + // Exact match + memos := tc.ListWithFilter(`content.contains("MixedCase")`) + require.Len(t, memos, 1) + + // Lowercase match (depends on DB collation, usually case-insensitive in default installs but good to verify behavior) + // SQLite default LIKE is case-insensitive for ASCII. + memosLower := tc.ListWithFilter(`content.contains("mixedcase")`) + // We just verify it doesn't crash; strict case sensitivity expectation depends on DB config. + // For standard Memos setup (SQLite), it's often case-insensitive. + // Let's check if we get a result or not to characterize current behavior. + if len(memosLower) > 0 { + require.Equal(t, "MixedCase Content", memosLower[0].Content) + } +} + // ============================================================================= // Visibility Field Tests // Schema: visibility (string, ==, !=) // ============================================================================= func TestMemoFilterVisibilityEquals(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -83,6 +109,7 @@ func TestMemoFilterVisibilityEquals(t *testing.T) { } func TestMemoFilterVisibilityNotEquals(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -95,6 +122,7 @@ func TestMemoFilterVisibilityNotEquals(t *testing.T) { } func TestMemoFilterVisibilityInList(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -112,6 +140,7 @@ func TestMemoFilterVisibilityInList(t *testing.T) { // ============================================================================= func TestMemoFilterPinnedEquals(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -131,6 +160,7 @@ func TestMemoFilterPinnedEquals(t *testing.T) { } func TestMemoFilterPinnedPredicate(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -149,6 +179,7 @@ func TestMemoFilterPinnedPredicate(t *testing.T) { // ============================================================================= func TestMemoFilterCreatorIdEquals(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -169,6 +200,7 @@ func TestMemoFilterCreatorIdEquals(t *testing.T) { } func TestMemoFilterCreatorIdNotEquals(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -195,6 +227,7 @@ func TestMemoFilterCreatorIdNotEquals(t *testing.T) { // ============================================================================= func TestMemoFilterTagInList(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -213,6 +246,7 @@ func TestMemoFilterTagInList(t *testing.T) { } func TestMemoFilterElementInTags(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -229,6 +263,7 @@ func TestMemoFilterElementInTags(t *testing.T) { } func TestMemoFilterHierarchicalTags(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -241,6 +276,7 @@ func TestMemoFilterHierarchicalTags(t *testing.T) { } func TestMemoFilterEmptyTags(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -257,6 +293,7 @@ func TestMemoFilterEmptyTags(t *testing.T) { // ============================================================================= func TestMemoFilterHasTaskList(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -279,6 +316,7 @@ func TestMemoFilterHasTaskList(t *testing.T) { } func TestMemoFilterHasLink(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -293,6 +331,7 @@ func TestMemoFilterHasLink(t *testing.T) { } func TestMemoFilterHasCode(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -307,6 +346,7 @@ func TestMemoFilterHasCode(t *testing.T) { } func TestMemoFilterHasIncompleteTasks(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -329,6 +369,7 @@ func TestMemoFilterHasIncompleteTasks(t *testing.T) { } func TestMemoFilterCombinedJSONBool(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -367,6 +408,7 @@ func TestMemoFilterCombinedJSONBool(t *testing.T) { // ============================================================================= func TestMemoFilterCreatedTsComparison(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -387,6 +429,7 @@ func TestMemoFilterCreatedTsComparison(t *testing.T) { } func TestMemoFilterCreatedTsWithNow(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -402,6 +445,7 @@ func TestMemoFilterCreatedTsWithNow(t *testing.T) { } func TestMemoFilterCreatedTsArithmetic(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -421,6 +465,7 @@ func TestMemoFilterCreatedTsArithmetic(t *testing.T) { } func TestMemoFilterUpdatedTs(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -444,6 +489,7 @@ func TestMemoFilterUpdatedTs(t *testing.T) { } func TestMemoFilterAllComparisonOperators(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -472,6 +518,7 @@ func TestMemoFilterAllComparisonOperators(t *testing.T) { // ============================================================================= func TestMemoFilterLogicalAnd(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -485,6 +532,7 @@ func TestMemoFilterLogicalAnd(t *testing.T) { } func TestMemoFilterLogicalOr(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -497,6 +545,7 @@ func TestMemoFilterLogicalOr(t *testing.T) { } func TestMemoFilterLogicalNot(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -510,6 +559,7 @@ func TestMemoFilterLogicalNot(t *testing.T) { } func TestMemoFilterNegatedComparison(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -522,6 +572,7 @@ func TestMemoFilterNegatedComparison(t *testing.T) { } func TestMemoFilterComplexLogical(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -545,6 +596,244 @@ func TestMemoFilterComplexLogical(t *testing.T) { // Test: (pinned || tag in ["important"]) && visibility == "PUBLIC" memos = tc.ListWithFilter(`(pinned || tag in ["important"]) && visibility == "PUBLIC"`) require.Len(t, memos, 3) + + // Test: De Morgan's Law ! (A || B) == !A && !B + // ! (pinned || has_task_list) + tc.CreateMemo(NewMemoBuilder("memo-no-props", tc.User.ID).Content("Nothing special")) + memos = tc.ListWithFilter(`!(pinned || has_task_list)`) + require.Len(t, memos, 2) // Unpinned-tagged + Nothing special (pinned-untagged is pinned) +} + +// ============================================================================= +// Tag Comprehension Tests (exists macro) +// Schema: tags (list of strings, supports exists/all macros with predicates) +// ============================================================================= + +func TestMemoFilterTagsExistsStartsWith(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + // Create memos with different tags + tc.CreateMemo(NewMemoBuilder("memo-archive1", tc.User.ID). + Content("Archived project memo"). + Tags("archive/project", "done")) + + tc.CreateMemo(NewMemoBuilder("memo-archive2", tc.User.ID). + Content("Archived work memo"). + Tags("archive/work", "old")) + + tc.CreateMemo(NewMemoBuilder("memo-active", tc.User.ID). + Content("Active project memo"). + Tags("project/active", "todo")) + + tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID). + Content("Homelab memo"). + Tags("homelab/memos", "tech")) + + // Test: tags.exists(t, t.startsWith("archive")) - should match archived memos + memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`) + require.Len(t, memos, 2, "Should find 2 archived memos") + for _, memo := range memos { + hasArchiveTag := false + for _, tag := range memo.Payload.Tags { + if len(tag) >= 7 && tag[:7] == "archive" { + hasArchiveTag = true + break + } + } + require.True(t, hasArchiveTag, "Memo should have tag starting with 'archive'") + } + + // Test: !tags.exists(t, t.startsWith("archive")) - should match non-archived memos + memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`) + require.Len(t, memos, 2, "Should find 2 non-archived memos") + + // Test: tags.exists(t, t.startsWith("project")) - should match project memos + memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("project"))`) + require.Len(t, memos, 1, "Should find 1 project memo") + + // Test: tags.exists(t, t.startsWith("homelab")) - should match homelab memos + memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("homelab"))`) + require.Len(t, memos, 1, "Should find 1 homelab memo") + + // Test: tags.exists(t, t.startsWith("nonexistent")) - should match nothing + memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("nonexistent"))`) + require.Len(t, memos, 0, "Should find no memos") +} + +func TestMemoFilterTagsExistsContains(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + // Create memos with different tags + tc.CreateMemo(NewMemoBuilder("memo-todo1", tc.User.ID). + Content("Todo task 1"). + Tags("project/todo", "urgent")) + + tc.CreateMemo(NewMemoBuilder("memo-todo2", tc.User.ID). + Content("Todo task 2"). + Tags("work/todo-list", "pending")) + + tc.CreateMemo(NewMemoBuilder("memo-done", tc.User.ID). + Content("Done task"). + Tags("project/completed", "done")) + + // Test: tags.exists(t, t.contains("todo")) - should match todos + memos := tc.ListWithFilter(`tags.exists(t, t.contains("todo"))`) + require.Len(t, memos, 2, "Should find 2 todo memos") + + // Test: tags.exists(t, t.contains("done")) - should match done + memos = tc.ListWithFilter(`tags.exists(t, t.contains("done"))`) + require.Len(t, memos, 1, "Should find 1 done memo") + + // Test: !tags.exists(t, t.contains("todo")) - should exclude todos + memos = tc.ListWithFilter(`!tags.exists(t, t.contains("todo"))`) + require.Len(t, memos, 1, "Should find 1 non-todo memo") +} + +func TestMemoFilterTagsExistsEndsWith(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + // Create memos with different tag endings + tc.CreateMemo(NewMemoBuilder("memo-bug", tc.User.ID). + Content("Bug report"). + Tags("project/bug", "critical")) + + tc.CreateMemo(NewMemoBuilder("memo-debug", tc.User.ID). + Content("Debug session"). + Tags("work/debug", "dev")) + + tc.CreateMemo(NewMemoBuilder("memo-feature", tc.User.ID). + Content("New feature"). + Tags("project/feature", "new")) + + // Test: tags.exists(t, t.endsWith("bug")) - should match bug-related tags + memos := tc.ListWithFilter(`tags.exists(t, t.endsWith("bug"))`) + require.Len(t, memos, 2, "Should find 2 bug-related memos") + + // Test: tags.exists(t, t.endsWith("feature")) - should match feature + memos = tc.ListWithFilter(`tags.exists(t, t.endsWith("feature"))`) + require.Len(t, memos, 1, "Should find 1 feature memo") + + // Test: !tags.exists(t, t.endsWith("bug")) - should exclude bug-related + memos = tc.ListWithFilter(`!tags.exists(t, t.endsWith("bug"))`) + require.Len(t, memos, 1, "Should find 1 non-bug memo") +} + +func TestMemoFilterTagsExistsCombinedWithOtherFilters(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + // Create memos with tags and other properties + tc.CreateMemo(NewMemoBuilder("memo-archived-old", tc.User.ID). + Content("Old archived memo"). + Tags("archive/old", "done")) + + tc.CreateMemo(NewMemoBuilder("memo-archived-recent", tc.User.ID). + Content("Recent archived memo with TODO"). + Tags("archive/recent", "done")) + + tc.CreateMemo(NewMemoBuilder("memo-active-todo", tc.User.ID). + Content("Active TODO"). + Tags("project/active", "todo")) + + // Test: Combine tag filter with content filter + memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && content.contains("TODO")`) + require.Len(t, memos, 1, "Should find 1 archived memo with TODO in content") + + // Test: OR condition with tag filters + memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) || tags.exists(t, t.contains("todo"))`) + require.Len(t, memos, 3, "Should find all memos (archived or with todo tag)") + + // Test: Complex filter - archived but not containing "Recent" + memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && !content.contains("Recent")`) + require.Len(t, memos, 1, "Should find 1 old archived memo") +} + +func TestMemoFilterTagsExistsEmptyAndNullCases(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + // Create memo with no tags + tc.CreateMemo(NewMemoBuilder("memo-no-tags", tc.User.ID). + Content("Memo without tags")) + + // Create memo with tags + tc.CreateMemo(NewMemoBuilder("memo-with-tags", tc.User.ID). + Content("Memo with tags"). + Tags("tag1", "tag2")) + + // Test: tags.exists should not match memos without tags + memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("tag"))`) + require.Len(t, memos, 1, "Should only find memo with tags") + + // Test: Negation should match memos without matching tags + memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("tag"))`) + require.Len(t, memos, 1, "Should find memo without matching tags") +} + +func TestMemoFilterIssue5480_ArchiveWorkflow(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + // Create a realistic scenario as described in issue #5480 + // User has hierarchical tags and archives memos by prefixing with "archive" + + // Active memos + tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID). + Content("Setting up Memos"). + Tags("homelab/memos", "tech")) + + tc.CreateMemo(NewMemoBuilder("memo-project-alpha", tc.User.ID). + Content("Project Alpha notes"). + Tags("work/project-alpha", "active")) + + // Archived memos (user prefixed tags with "archive") + tc.CreateMemo(NewMemoBuilder("memo-old-homelab", tc.User.ID). + Content("Old homelab setup"). + Tags("archive/homelab/old-server", "done")) + + tc.CreateMemo(NewMemoBuilder("memo-old-project", tc.User.ID). + Content("Old project beta"). + Tags("archive/work/project-beta", "completed")) + + tc.CreateMemo(NewMemoBuilder("memo-archived-personal", tc.User.ID). + Content("Archived personal note"). + Tags("archive/personal/2024", "old")) + + // Test: Filter out ALL archived memos using startsWith + memos := tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`) + require.Len(t, memos, 2, "Should only show active memos (not archived)") + for _, memo := range memos { + for _, tag := range memo.Payload.Tags { + require.NotContains(t, tag, "archive", "Active memos should not have archive prefix") + } + } + + // Test: Show ONLY archived memos + memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`) + require.Len(t, memos, 3, "Should find all archived memos") + for _, memo := range memos { + hasArchiveTag := false + for _, tag := range memo.Payload.Tags { + if len(tag) >= 7 && tag[:7] == "archive" { + hasArchiveTag = true + break + } + } + require.True(t, hasArchiveTag, "All returned memos should have archive prefix") + } + + // Test: Filter archived homelab memos specifically + memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive/homelab"))`) + require.Len(t, memos, 1, "Should find only archived homelab memos") } // ============================================================================= @@ -552,6 +841,7 @@ func TestMemoFilterComplexLogical(t *testing.T) { // ============================================================================= func TestMemoFilterMultipleFilters(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -570,6 +860,7 @@ func TestMemoFilterMultipleFilters(t *testing.T) { // ============================================================================= func TestMemoFilterNullPayload(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -581,6 +872,7 @@ func TestMemoFilterNullPayload(t *testing.T) { } func TestMemoFilterNoMatches(t *testing.T) { + t.Parallel() tc := NewMemoFilterTestContext(t) defer tc.Close() @@ -589,3 +881,48 @@ func TestMemoFilterNoMatches(t *testing.T) { memos := tc.ListWithFilter(`content.contains("nonexistent12345")`) require.Len(t, memos, 0) } + +func TestMemoFilterJSONBooleanLogic(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + // 1. Memo with task list (true) and NO link (null) + tc.CreateMemo(NewMemoBuilder("memo-task-only", tc.User.ID). + Content("Task only"). + Property(func(p *storepb.MemoPayload_Property) { p.HasTaskList = true })) + + // 2. Memo with link (true) and NO task list (null) + tc.CreateMemo(NewMemoBuilder("memo-link-only", tc.User.ID). + Content("Link only"). + Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true })) + + // 3. Memo with both (true) + tc.CreateMemo(NewMemoBuilder("memo-both", tc.User.ID). + Content("Both"). + Property(func(p *storepb.MemoPayload_Property) { + p.HasTaskList = true + p.HasLink = true + })) + + // 4. Memo with neither (null) + tc.CreateMemo(NewMemoBuilder("memo-neither", tc.User.ID).Content("Neither")) + + // Test A: has_task_list || has_link + // Expected: 3 memos (task-only, link-only, both). Neither should be excluded. + // This specifically tests the NULL handling in OR logic (NULL || TRUE should be TRUE) + memos := tc.ListWithFilter(`has_task_list || has_link`) + require.Len(t, memos, 3, "Should find 3 memos with OR logic") + + // Test B: !has_task_list + // Expected: 2 memos (link-only, neither). Memos where has_task_list is NULL or FALSE. + // Note: If NULL is not handled, !NULL is still NULL (false-y in WHERE), so "neither" might be missed depending on logic. + // In our implementation, we want missing fields to behave as false. + memos = tc.ListWithFilter(`!has_task_list`) + require.Len(t, memos, 2, "Should find 2 memos where task list is false or missing") + + // Test C: has_task_list && !has_link + // Expected: 1 memo (task-only). + memos = tc.ListWithFilter(`has_task_list && !has_link`) + require.Len(t, memos, 1, "Should find 1 memo (task only)") +} diff --git a/store/test/memo_relation_test.go b/store/test/memo_relation_test.go index dd05134ef..9cfba6997 100644 --- a/store/test/memo_relation_test.go +++ b/store/test/memo_relation_test.go @@ -10,6 +10,7 @@ import ( ) func TestMemoRelationStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -62,6 +63,7 @@ func TestMemoRelationStore(t *testing.T) { } func TestMemoRelationListByMemoID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -136,6 +138,7 @@ func TestMemoRelationListByMemoID(t *testing.T) { } func TestMemoRelationDelete(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -193,6 +196,7 @@ func TestMemoRelationDelete(t *testing.T) { } func TestMemoRelationDifferentTypes(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -239,3 +243,440 @@ func TestMemoRelationDifferentTypes(t *testing.T) { ts.Close() } + +func TestMemoRelationUpsertSameRelation(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + mainMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "main-memo", + CreatorID: user.ID, + Content: "main memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "related-memo", + CreatorID: user.ID, + Content: "related memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create relation + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: mainMemo.ID, + RelatedMemoID: relatedMemo.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Upsert the same relation again (should not create duplicate) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: mainMemo.ID, + RelatedMemoID: relatedMemo.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Verify only one relation exists + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &mainMemo.ID, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + + ts.Close() +} + +func TestMemoRelationDeleteByType(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + mainMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "main-memo", + CreatorID: user.ID, + Content: "main memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "related-memo-1", + CreatorID: user.ID, + Content: "related memo 1 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "related-memo-2", + CreatorID: user.ID, + Content: "related memo 2 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create reference relations + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: mainMemo.ID, + RelatedMemoID: relatedMemo1.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Create comment relation + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: mainMemo.ID, + RelatedMemoID: relatedMemo2.ID, + Type: store.MemoRelationComment, + }) + require.NoError(t, err) + + // Delete only reference type relations + refType := store.MemoRelationReference + err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ + MemoID: &mainMemo.ID, + Type: &refType, + }) + require.NoError(t, err) + + // Verify only comment relation remains + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &mainMemo.ID, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + require.Equal(t, store.MemoRelationComment, relations[0].Type) + + ts.Close() +} + +func TestMemoRelationDeleteByMemoID(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + memo1, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-1", + CreatorID: user.ID, + Content: "memo 1 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memo2, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-2", + CreatorID: user.ID, + Content: "memo 2 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "related-memo", + CreatorID: user.ID, + Content: "related memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create relations for both memos + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memo1.ID, + RelatedMemoID: relatedMemo.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memo2.ID, + RelatedMemoID: relatedMemo.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Delete all relations for memo1 + err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ + MemoID: &memo1.ID, + }) + require.NoError(t, err) + + // Verify memo1's relations are gone + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &memo1.ID, + }) + require.NoError(t, err) + require.Len(t, relations, 0) + + // Verify memo2's relations still exist + relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &memo2.ID, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + + ts.Close() +} + +func TestMemoRelationListByRelatedMemoID(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create a memo that will be referenced by others + targetMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "target-memo", + CreatorID: user.ID, + Content: "target memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create memos that reference the target + referrer1, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "referrer-1", + CreatorID: user.ID, + Content: "referrer 1 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + referrer2, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "referrer-2", + CreatorID: user.ID, + Content: "referrer 2 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create relations pointing to target + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: referrer1.ID, + RelatedMemoID: targetMemo.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: referrer2.ID, + RelatedMemoID: targetMemo.ID, + Type: store.MemoRelationComment, + }) + require.NoError(t, err) + + // List by related memo ID (find all memos that reference the target) + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + RelatedMemoID: &targetMemo.ID, + }) + require.NoError(t, err) + require.Len(t, relations, 2) + + ts.Close() +} + +func TestMemoRelationListCombinedFilters(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + mainMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "main-memo", + CreatorID: user.ID, + Content: "main memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "related-memo-1", + CreatorID: user.ID, + Content: "related memo 1 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "related-memo-2", + CreatorID: user.ID, + Content: "related memo 2 content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create multiple relations + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: mainMemo.ID, + RelatedMemoID: relatedMemo1.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: mainMemo.ID, + RelatedMemoID: relatedMemo2.ID, + Type: store.MemoRelationComment, + }) + require.NoError(t, err) + + // List with MemoID and Type filter + refType := store.MemoRelationReference + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &mainMemo.ID, + Type: &refType, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + require.Equal(t, relatedMemo1.ID, relations[0].RelatedMemoID) + + // List with MemoID, RelatedMemoID, and Type filter + commentType := store.MemoRelationComment + relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &mainMemo.ID, + RelatedMemoID: &relatedMemo2.ID, + Type: &commentType, + }) + require.NoError(t, err) + require.Len(t, relations, 1) + + ts.Close() +} + +func TestMemoRelationListEmpty(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + memo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-no-relations", + CreatorID: user.ID, + Content: "memo with no relations", + Visibility: store.Public, + }) + require.NoError(t, err) + + // List relations for memo with none + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &memo.ID, + }) + require.NoError(t, err) + require.Len(t, relations, 0) + + ts.Close() +} + +func TestMemoRelationBidirectional(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + memoA, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-a", + CreatorID: user.ID, + Content: "memo A content", + Visibility: store.Public, + }) + require.NoError(t, err) + + memoB, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "memo-b", + CreatorID: user.ID, + Content: "memo B content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create relation A -> B + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoA.ID, + RelatedMemoID: memoB.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Create relation B -> A (reverse direction) + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memoB.ID, + RelatedMemoID: memoA.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + // Verify A -> B exists + relationsFromA, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &memoA.ID, + }) + require.NoError(t, err) + require.Len(t, relationsFromA, 1) + require.Equal(t, memoB.ID, relationsFromA[0].RelatedMemoID) + + // Verify B -> A exists + relationsFromB, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &memoB.ID, + }) + require.NoError(t, err) + require.Len(t, relationsFromB, 1) + require.Equal(t, memoA.ID, relationsFromB[0].RelatedMemoID) + + ts.Close() +} + +func TestMemoRelationMultipleRelationsToSameMemo(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + mainMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "main-memo", + CreatorID: user.ID, + Content: "main memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + // Create multiple memos that all relate to the main memo + for i := 1; i <= 5; i++ { + relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "related-memo-" + string(rune('0'+i)), + CreatorID: user.ID, + Content: "related memo content", + Visibility: store.Public, + }) + require.NoError(t, err) + + _, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: mainMemo.ID, + RelatedMemoID: relatedMemo.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + } + + // Verify all 5 relations exist + relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ + MemoID: &mainMemo.ID, + }) + require.NoError(t, err) + require.Len(t, relations, 5) + + ts.Close() +} diff --git a/store/test/memo_test.go b/store/test/memo_test.go index bcc368a73..78d9ddce0 100644 --- a/store/test/memo_test.go +++ b/store/test/memo_test.go @@ -13,6 +13,7 @@ import ( ) func TestMemoStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -64,6 +65,7 @@ func TestMemoStore(t *testing.T) { } func TestMemoListByTags(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -96,6 +98,7 @@ func TestMemoListByTags(t *testing.T) { } func TestDeleteMemoStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -117,6 +120,7 @@ func TestDeleteMemoStore(t *testing.T) { } func TestMemoGetByID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -147,6 +151,7 @@ func TestMemoGetByID(t *testing.T) { } func TestMemoGetByUID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -177,6 +182,7 @@ func TestMemoGetByUID(t *testing.T) { } func TestMemoListByVisibility(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -239,6 +245,7 @@ func TestMemoListByVisibility(t *testing.T) { } func TestMemoListWithPagination(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -274,6 +281,7 @@ func TestMemoListWithPagination(t *testing.T) { } func TestMemoUpdatePinned(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -317,6 +325,7 @@ func TestMemoUpdatePinned(t *testing.T) { } func TestMemoUpdateVisibility(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -359,6 +368,7 @@ func TestMemoUpdateVisibility(t *testing.T) { } func TestMemoInvalidUID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -378,6 +388,7 @@ func TestMemoInvalidUID(t *testing.T) { } func TestMemoWithPayload(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) diff --git a/store/test/migrator_test.go b/store/test/migrator_test.go index a76f27ee1..3eb541381 100644 --- a/store/test/migrator_test.go +++ b/store/test/migrator_test.go @@ -2,16 +2,199 @@ package test import ( "context" + "fmt" + "os" "testing" + "time" "github.com/stretchr/testify/require" + + "github.com/usememos/memos/store" ) -func TestGetCurrentSchemaVersion(t *testing.T) { +// TestFreshInstall verifies that LATEST.sql applies correctly on a fresh database. +// This is essentially what NewTestingStore already does, but we make it explicit. +func TestFreshInstall(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // NewTestingStore creates a fresh database and runs Migrate() + // which applies LATEST.sql for uninitialized databases + ts := NewTestingStore(ctx, t) + + // Verify migration completed successfully + currentSchemaVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + require.NotEmpty(t, currentSchemaVersion, "schema version should be set after fresh install") + + // Verify we can read instance settings (basic sanity check) + instanceSetting, err := ts.GetInstanceBasicSetting(ctx) + require.NoError(t, err) + require.Equal(t, currentSchemaVersion, instanceSetting.SchemaVersion) +} + +// TestMigrationReRun verifies that re-running the migration on an already +// migrated database does not fail or cause issues. This simulates a +// scenario where the server is restarted. +func TestMigrationReRun(t *testing.T) { + t.Parallel() + + ctx := context.Background() + // Use the shared testing store which already runs migrations on init + ts := NewTestingStore(ctx, t) + + // Get current version + initialVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + + // Manually trigger migration again + err = ts.Migrate(ctx) + require.NoError(t, err, "re-running migration should not fail") + + // Verify version hasn't changed (or at least is valid) + finalVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + require.Equal(t, initialVersion, finalVersion, "version should match after re-run") +} + +// TestMigrationWithData verifies that migration preserves data integrity. +// Creates data, then re-runs migration and verifies data is still accessible. +func TestMigrationWithData(t *testing.T) { + t.Parallel() + ctx := context.Background() ts := NewTestingStore(ctx, t) - currentSchemaVersion, err := ts.GetCurrentSchemaVersion() - require.NoError(t, err) - require.Equal(t, "0.25.1", currentSchemaVersion) + // Create a user and memo before re-running migration + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err, "should create user") + + originalMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "migration-data-test", + CreatorID: user.ID, + Content: "Data before migration re-run", + Visibility: store.Public, + }) + require.NoError(t, err, "should create memo") + + // Re-run migration + err = ts.Migrate(ctx) + require.NoError(t, err, "re-running migration should not fail") + + // Verify data is still accessible + memo, err := ts.GetMemo(ctx, &store.FindMemo{UID: &originalMemo.UID}) + require.NoError(t, err, "should retrieve memo after migration") + require.Equal(t, "Data before migration re-run", memo.Content, "memo content should be preserved") +} + +// TestMigrationMultipleReRuns verifies that migration is idempotent +// even when run multiple times in succession. +func TestMigrationMultipleReRuns(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + // Get initial version + initialVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + + // Run migration multiple times + for i := 0; i < 3; i++ { + err = ts.Migrate(ctx) + require.NoError(t, err, "migration run %d should not fail", i+1) + } + + // Verify version is still correct + finalVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + require.Equal(t, initialVersion, finalVersion, "version should remain unchanged after multiple re-runs") +} + +// TestMigrationFromStableVersion verifies that upgrading from a stable Memos version +// to the current version works correctly. This is the critical upgrade path test. +// +// Test flow: +// 1. Start a stable Memos container to create a database with the old schema +// 2. Stop the container and wait for cleanup +// 3. Use the store directly to run migration with current code +// 4. Verify the migration succeeded and data can be written +// +// Note: This test is skipped when running with -race flag because testcontainers +// has known race conditions in its reaper code that are outside our control. +func TestMigrationFromStableVersion(t *testing.T) { + // Skip for non-SQLite drivers (simplifies the test) + if getDriverFromEnv() != "sqlite" { + t.Skip("skipping upgrade test for non-sqlite driver") + } + + // Skip if explicitly disabled (e.g., in environments without Docker) + if os.Getenv("SKIP_CONTAINER_TESTS") == "1" { + t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)") + } + + ctx := context.Background() + dataDir := t.TempDir() + + // 1. Start stable Memos container to create database with old schema + cfg := MemosContainerConfig{ + Driver: "sqlite", + DataDir: dataDir, + Version: StableMemosVersion, + } + + t.Logf("Starting Memos %s container to create old-schema database...", cfg.Version) + container, err := StartMemosContainer(ctx, cfg) + require.NoError(t, err, "failed to start stable memos container") + + // Wait for the container to fully initialize the database + time.Sleep(10 * time.Second) + + // Stop the container gracefully + t.Log("Stopping stable Memos container...") + err = container.Terminate(ctx) + require.NoError(t, err, "failed to stop memos container") + + // Wait for file handles to be released + time.Sleep(2 * time.Second) + + // 2. Connect to the database directly and run migration with current code + dsn := fmt.Sprintf("%s/memos_prod.db", dataDir) + t.Logf("Connecting to database at %s...", dsn) + + ts := NewTestingStoreWithDSN(ctx, t, "sqlite", dsn) + + // Get the schema version before migration + oldSetting, err := ts.GetInstanceBasicSetting(ctx) + require.NoError(t, err) + t.Logf("Old schema version: %s", oldSetting.SchemaVersion) + + // 3. Run migration with current code + t.Log("Running migration with current code...") + err = ts.Migrate(ctx) + require.NoError(t, err, "migration from stable version should succeed") + + // 4. Verify migration succeeded + newVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + t.Logf("New schema version: %s", newVersion) + + newSetting, err := ts.GetInstanceBasicSetting(ctx) + require.NoError(t, err) + require.Equal(t, newVersion, newSetting.SchemaVersion, "schema version should be updated") + + // Verify we can write data to the migrated database + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err, "should create user after migration") + + memo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "post-upgrade-memo", + CreatorID: user.ID, + Content: "Content after upgrade from stable", + Visibility: store.Public, + }) + require.NoError(t, err, "should create memo after migration") + require.Equal(t, "Content after upgrade from stable", memo.Content) + + t.Logf("Migration successful: %s -> %s", oldSetting.SchemaVersion, newVersion) } diff --git a/store/test/reaction_test.go b/store/test/reaction_test.go index 986eed9b2..6f8b220ac 100644 --- a/store/test/reaction_test.go +++ b/store/test/reaction_test.go @@ -10,6 +10,7 @@ import ( ) func TestReactionStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -67,6 +68,7 @@ func TestReactionStore(t *testing.T) { } func TestReactionListByCreatorID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -113,6 +115,7 @@ func TestReactionListByCreatorID(t *testing.T) { } func TestReactionMultipleContentIDs(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -148,6 +151,7 @@ func TestReactionMultipleContentIDs(t *testing.T) { } func TestReactionUpsertDifferentTypes(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) diff --git a/store/test/store.go b/store/test/store.go index 3c7abf20a..c6a9b32db 100644 --- a/store/test/store.go +++ b/store/test/store.go @@ -37,6 +37,27 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { return store } +// NewTestingStoreWithDSN creates a testing store connected to a specific DSN. +// This is useful for testing migrations on existing data. +func NewTestingStoreWithDSN(_ context.Context, t *testing.T, driver, dsn string) *store.Store { + profile := &profile.Profile{ + Port: getUnusedPort(), + Data: t.TempDir(), // Dummy dir, DSN matters + DSN: dsn, + Driver: driver, + Version: version.GetCurrentVersion(), + } + dbDriver, err := db.NewDBDriver(profile) + if err != nil { + t.Fatalf("failed to create db driver: %v", err) + } + + store := store.New(dbDriver, profile) + // Do not run Migrate() automatically, as we might be testing pre-migration state + // or want to run it manually. + return store +} + func getUnusedPort() int { // Get a random unused port listener, err := net.Listen("tcp", "localhost:0") @@ -73,12 +94,11 @@ func getTestingProfileForDriver(t *testing.T, driver string) *profile.Profile { } return &profile.Profile{ - Mode: mode, Port: port, Data: dir, DSN: dsn, Driver: driver, - Version: version.GetCurrentVersion(mode), + Version: version.GetCurrentVersion(), } } diff --git a/store/test/user_setting_test.go b/store/test/user_setting_test.go index 99a40f775..cf66afa22 100644 --- a/store/test/user_setting_test.go +++ b/store/test/user_setting_test.go @@ -2,15 +2,18 @@ package test import ( "context" + "strings" "testing" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) func TestUserSettingStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -28,6 +31,7 @@ func TestUserSettingStore(t *testing.T) { } func TestUserSettingGetByUserID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -62,6 +66,7 @@ func TestUserSettingGetByUserID(t *testing.T) { } func TestUserSettingUpsertUpdate(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -100,6 +105,7 @@ func TestUserSettingUpsertUpdate(t *testing.T) { } func TestUserSettingRefreshTokens(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -160,6 +166,7 @@ func TestUserSettingRefreshTokens(t *testing.T) { } func TestUserSettingPersonalAccessTokens(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -211,6 +218,7 @@ func TestUserSettingPersonalAccessTokens(t *testing.T) { } func TestUserSettingWebhooks(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -277,6 +285,7 @@ func TestUserSettingWebhooks(t *testing.T) { } func TestUserSettingShortcuts(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -308,3 +317,559 @@ func TestUserSettingShortcuts(t *testing.T) { ts.Close() } + +func TestUserSettingGetUserByPATHash(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create a PAT with a known hash + patHash := "test-pat-hash-12345" + pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-test-1", + TokenHash: patHash, + Description: "Test PAT for lookup", + } + err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat) + require.NoError(t, err) + + // Lookup user by PAT hash + result, err := ts.GetUserByPATHash(ctx, patHash) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.NotNil(t, result.User) + require.Equal(t, user.Username, result.User.Username) + require.NotNil(t, result.PAT) + require.Equal(t, "pat-test-1", result.PAT.TokenId) + require.Equal(t, "Test PAT for lookup", result.PAT.Description) + + ts.Close() +} + +func TestUserSettingGetUserByPATHashNotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + _, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Lookup non-existent PAT hash + result, err := ts.GetUserByPATHash(ctx, "non-existent-hash") + require.Error(t, err) + require.Nil(t, result) + + ts.Close() +} + +func TestUserSettingGetUserByPATHashMultipleUsers(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user1, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser) + require.NoError(t, err) + + // Create PATs for both users + pat1Hash := "user1-pat-hash" + err = ts.AddUserPersonalAccessToken(ctx, user1.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-user1", + TokenHash: pat1Hash, + Description: "User 1 PAT", + }) + require.NoError(t, err) + + pat2Hash := "user2-pat-hash" + err = ts.AddUserPersonalAccessToken(ctx, user2.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-user2", + TokenHash: pat2Hash, + Description: "User 2 PAT", + }) + require.NoError(t, err) + + // Lookup user1's PAT + result1, err := ts.GetUserByPATHash(ctx, pat1Hash) + require.NoError(t, err) + require.Equal(t, user1.ID, result1.UserID) + require.Equal(t, user1.Username, result1.User.Username) + + // Lookup user2's PAT + result2, err := ts.GetUserByPATHash(ctx, pat2Hash) + require.NoError(t, err) + require.Equal(t, user2.ID, result2.UserID) + require.Equal(t, user2.Username, result2.User.Username) + + ts.Close() +} + +func TestUserSettingGetUserByPATHashMultiplePATsSameUser(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create multiple PATs for the same user + pat1Hash := "first-pat-hash" + err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-1", + TokenHash: pat1Hash, + Description: "First PAT", + }) + require.NoError(t, err) + + pat2Hash := "second-pat-hash" + err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-2", + TokenHash: pat2Hash, + Description: "Second PAT", + }) + require.NoError(t, err) + + // Both PATs should resolve to the same user + result1, err := ts.GetUserByPATHash(ctx, pat1Hash) + require.NoError(t, err) + require.Equal(t, user.ID, result1.UserID) + require.Equal(t, "pat-1", result1.PAT.TokenId) + + result2, err := ts.GetUserByPATHash(ctx, pat2Hash) + require.NoError(t, err) + require.Equal(t, user.ID, result2.UserID) + require.Equal(t, "pat-2", result2.PAT.TokenId) + + ts.Close() +} + +func TestUserSettingUpdatePATLastUsed(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create a PAT + patHash := "pat-hash-for-update" + err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-update-test", + TokenHash: patHash, + Description: "PAT for update test", + }) + require.NoError(t, err) + + // Update last used timestamp + now := timestamppb.Now() + err = ts.UpdatePATLastUsed(ctx, user.ID, "pat-update-test", now) + require.NoError(t, err) + + // Verify the update + pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID) + require.NoError(t, err) + require.Len(t, pats, 1) + require.NotNil(t, pats[0].LastUsedAt) + + ts.Close() +} + +func TestUserSettingGetUserByPATHashWithExpiredToken(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create a PAT with expiration info + patHash := "pat-hash-with-expiry" + expiresAt := timestamppb.Now() + pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-expiry-test", + TokenHash: patHash, + Description: "PAT with expiry", + ExpiresAt: expiresAt, + } + err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat) + require.NoError(t, err) + + // Should still be able to look up by hash (expiry check is done at auth level) + result, err := ts.GetUserByPATHash(ctx, patHash) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.NotNil(t, result.PAT.ExpiresAt) + + ts.Close() +} + +func TestUserSettingGetUserByPATHashAfterRemoval(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create a PAT + patHash := "pat-hash-to-remove" + err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-remove-test", + TokenHash: patHash, + Description: "PAT to be removed", + }) + require.NoError(t, err) + + // Verify it exists + result, err := ts.GetUserByPATHash(ctx, patHash) + require.NoError(t, err) + require.NotNil(t, result) + + // Remove the PAT + err = ts.RemoveUserPersonalAccessToken(ctx, user.ID, "pat-remove-test") + require.NoError(t, err) + + // Should no longer be found + result, err = ts.GetUserByPATHash(ctx, patHash) + require.Error(t, err) + require.Nil(t, result) + + ts.Close() +} + +func TestUserSettingGetUserByPATHashSpecialCharacters(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create PATs with special characters in hash (simulating real hash values) + testCases := []struct { + tokenID string + tokenHash string + }{ + {"pat-special-1", "abc123+/=XYZ"}, + {"pat-special-2", "sha256:abcdef1234567890"}, + {"pat-special-3", "$2a$10$N9qo8uLOickgx2ZMRZoMy"}, + } + + for _, tc := range testCases { + err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: tc.tokenID, + TokenHash: tc.tokenHash, + Description: "PAT with special chars", + }) + require.NoError(t, err) + + // Verify lookup works with special characters + result, err := ts.GetUserByPATHash(ctx, tc.tokenHash) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tc.tokenID, result.PAT.TokenId) + } + + ts.Close() +} + +func TestUserSettingGetUserByPATHashLargeTokenCount(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create many PATs for the same user + tokenCount := 10 + hashes := make([]string, tokenCount) + for i := 0; i < tokenCount; i++ { + hashes[i] = "pat-hash-" + string(rune('A'+i)) + "-large-test" + err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-large-" + string(rune('A'+i)), + TokenHash: hashes[i], + Description: "PAT for large count test", + }) + require.NoError(t, err) + } + + // Verify each hash can be looked up + for i, hash := range hashes { + result, err := ts.GetUserByPATHash(ctx, hash) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.Equal(t, "pat-large-"+string(rune('A'+i)), result.PAT.TokenId) + } + + ts.Close() +} + +func TestUserSettingMultipleSettingTypes(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Create GENERAL setting + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_GENERAL, + Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "ja"}}, + }) + require.NoError(t, err) + + // Create SHORTCUTS setting + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_SHORTCUTS, + Value: &storepb.UserSetting_Shortcuts{Shortcuts: &storepb.ShortcutsUserSetting{ + Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{ + {Id: "s1", Title: "Shortcut 1"}, + }, + }}, + }) + require.NoError(t, err) + + // Add a PAT + err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-multi", + TokenHash: "hash-multi", + }) + require.NoError(t, err) + + // List all settings for user + settings, err := ts.ListUserSettings(ctx, &store.FindUserSetting{UserID: &user.ID}) + require.NoError(t, err) + require.Len(t, settings, 3) + + // Verify each setting type + generalSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_GENERAL}) + require.NoError(t, err) + require.Equal(t, "ja", generalSetting.GetGeneral().Locale) + + shortcutsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_SHORTCUTS}) + require.NoError(t, err) + require.Len(t, shortcutsSetting.GetShortcuts().Shortcuts, 1) + + patsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS}) + require.NoError(t, err) + require.Len(t, patsSetting.GetPersonalAccessTokens().Tokens, 1) + + ts.Close() +} + +func TestUserSettingShortcutsEdgeCases(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Case 1: Special characters in Filter and Title + // Includes quotes, backslashes, newlines, and other JSON-sensitive characters + specialCharsFilter := `tag in ["work", "project"] && content.contains("urgent")` + specialCharsTitle := `Work "Urgent" \ Notes` + shortcuts := &storepb.ShortcutsUserSetting{ + Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{ + {Id: "s1", Title: specialCharsTitle, Filter: specialCharsFilter}, + }, + } + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_SHORTCUTS, + Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts}, + }) + require.NoError(t, err) + + setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{ + UserID: &user.ID, + Key: storepb.UserSetting_SHORTCUTS, + }) + require.NoError(t, err) + require.NotNil(t, setting) + require.Len(t, setting.GetShortcuts().Shortcuts, 1) + require.Equal(t, specialCharsTitle, setting.GetShortcuts().Shortcuts[0].Title) + require.Equal(t, specialCharsFilter, setting.GetShortcuts().Shortcuts[0].Filter) + + // Case 2: Unicode characters + unicodeFilter := `tag in ["你好", "世界"]` + unicodeTitle := `My 🚀 Shortcuts` + shortcuts = &storepb.ShortcutsUserSetting{ + Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{ + {Id: "s2", Title: unicodeTitle, Filter: unicodeFilter}, + }, + } + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_SHORTCUTS, + Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts}, + }) + require.NoError(t, err) + + setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{ + UserID: &user.ID, + Key: storepb.UserSetting_SHORTCUTS, + }) + require.NoError(t, err) + require.NotNil(t, setting) + require.Len(t, setting.GetShortcuts().Shortcuts, 1) + require.Equal(t, unicodeTitle, setting.GetShortcuts().Shortcuts[0].Title) + require.Equal(t, unicodeFilter, setting.GetShortcuts().Shortcuts[0].Filter) + + // Case 3: Empty shortcuts list + // Should allow saving an empty list (clearing shortcuts) + shortcuts = &storepb.ShortcutsUserSetting{ + Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{}, + } + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_SHORTCUTS, + Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts}, + }) + require.NoError(t, err) + + setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{ + UserID: &user.ID, + Key: storepb.UserSetting_SHORTCUTS, + }) + require.NoError(t, err) + require.NotNil(t, setting) + require.NotNil(t, setting.GetShortcuts()) + require.Len(t, setting.GetShortcuts().Shortcuts, 0) + + // Case 4: Large filter string + // Test reasonable large string handling (e.g. 4KB) + largeFilter := strings.Repeat("tag:long_tag_name ", 200) + shortcuts = &storepb.ShortcutsUserSetting{ + Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{ + {Id: "s3", Title: "Large Filter", Filter: largeFilter}, + }, + } + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_SHORTCUTS, + Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts}, + }) + require.NoError(t, err) + + setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{ + UserID: &user.ID, + Key: storepb.UserSetting_SHORTCUTS, + }) + require.NoError(t, err) + require.NotNil(t, setting) + require.Equal(t, largeFilter, setting.GetShortcuts().Shortcuts[0].Filter) + + ts.Close() +} + +func TestUserSettingShortcutsPartialUpdate(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Initial set + shortcuts := &storepb.ShortcutsUserSetting{ + Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{ + {Id: "s1", Title: "Note 1", Filter: "tag:1"}, + {Id: "s2", Title: "Note 2", Filter: "tag:2"}, + }, + } + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_SHORTCUTS, + Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts}, + }) + require.NoError(t, err) + + // Update by replacing the whole list (Store Upsert replaces the value for the key) + // We want to verify that we can "update" a single item by sending the modified list + updatedShortcuts := &storepb.ShortcutsUserSetting{ + Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{ + {Id: "s1", Title: "Note 1 Updated", Filter: "tag:1_updated"}, + {Id: "s2", Title: "Note 2", Filter: "tag:2"}, + {Id: "s3", Title: "Note 3", Filter: "tag:3"}, // Add new one + }, + } + _, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSetting_SHORTCUTS, + Value: &storepb.UserSetting_Shortcuts{Shortcuts: updatedShortcuts}, + }) + require.NoError(t, err) + + setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{ + UserID: &user.ID, + Key: storepb.UserSetting_SHORTCUTS, + }) + require.NoError(t, err) + require.NotNil(t, setting) + require.Len(t, setting.GetShortcuts().Shortcuts, 3) + + // Verify updates + for _, s := range setting.GetShortcuts().Shortcuts { + if s.Id == "s1" { + require.Equal(t, "Note 1 Updated", s.Title) + require.Equal(t, "tag:1_updated", s.Filter) + } else if s.Id == "s2" { + require.Equal(t, "Note 2", s.Title) + } else if s.Id == "s3" { + require.Equal(t, "Note 3", s.Title) + } + } + + ts.Close() +} + +func TestUserSettingJSONFieldsEdgeCases(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + // Case 1: Webhook with special characters and Unicode in Title and URL + specialWebhook := &storepb.WebhooksUserSetting_Webhook{ + Id: "wh-special", + Title: `My "Special" & 🚀`, + Url: "https://example.com/hook?query=你好¶m=\"value\"", + } + err = ts.AddUserWebhook(ctx, user.ID, specialWebhook) + require.NoError(t, err) + + webhooks, err := ts.GetUserWebhooks(ctx, user.ID) + require.NoError(t, err) + require.Len(t, webhooks, 1) + require.Equal(t, specialWebhook.Title, webhooks[0].Title) + require.Equal(t, specialWebhook.Url, webhooks[0].Url) + + // Case 2: PAT with special description + specialPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{ + TokenId: "pat-special", + TokenHash: "hash-special", + Description: "Token for 'CLI' \n & \"API\" \t with unicode 🔑", + } + err = ts.AddUserPersonalAccessToken(ctx, user.ID, specialPAT) + require.NoError(t, err) + + pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID) + require.NoError(t, err) + require.Len(t, pats, 1) + require.Equal(t, specialPAT.Description, pats[0].Description) + + // Case 3: Refresh Token with special description + specialRefreshToken := &storepb.RefreshTokensUserSetting_RefreshToken{ + TokenId: "rt-special", + Description: "Browser: Firefox (Nightly) / OS: Linux 🐧", + } + err = ts.AddUserRefreshToken(ctx, user.ID, specialRefreshToken) + require.NoError(t, err) + + tokens, err := ts.GetUserRefreshTokens(ctx, user.ID) + require.NoError(t, err) + require.Len(t, tokens, 1) + require.Equal(t, specialRefreshToken.Description, tokens[0].Description) + + ts.Close() +} diff --git a/store/test/user_test.go b/store/test/user_test.go index 2ffa202fd..1b71af72e 100644 --- a/store/test/user_test.go +++ b/store/test/user_test.go @@ -12,6 +12,7 @@ import ( ) func TestUserStore(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) @@ -19,7 +20,7 @@ func TestUserStore(t *testing.T) { users, err := ts.ListUsers(ctx, &store.FindUser{}) require.NoError(t, err) require.Equal(t, 1, len(users)) - require.Equal(t, store.RoleHost, users[0].Role) + require.Equal(t, store.RoleAdmin, users[0].Role) require.Equal(t, user, users[0]) userPatchNickname := "test_nickname_2" userPatch := &store.UpdateUser{ @@ -40,6 +41,7 @@ func TestUserStore(t *testing.T) { } func TestUserGetByID(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -71,6 +73,7 @@ func TestUserGetByID(t *testing.T) { } func TestUserGetByUsername(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -93,6 +96,7 @@ func TestUserGetByUsername(t *testing.T) { } func TestUserListByRole(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -100,7 +104,7 @@ func TestUserListByRole(t *testing.T) { _, err := createTestingHostUser(ctx, ts) require.NoError(t, err) - adminUser, err := createTestingUserWithRole(ctx, ts, "admin_user", store.RoleAdmin) + _, err = createTestingUserWithRole(ctx, ts, "admin_user", store.RoleAdmin) require.NoError(t, err) regularUser, err := createTestingUserWithRole(ctx, ts, "regular_user", store.RoleUser) @@ -111,19 +115,11 @@ func TestUserListByRole(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, len(allUsers)) - // List only HOST users - hostRole := store.RoleHost - hostUsers, err := ts.ListUsers(ctx, &store.FindUser{Role: &hostRole}) - require.NoError(t, err) - require.Equal(t, 1, len(hostUsers)) - require.Equal(t, store.RoleHost, hostUsers[0].Role) - // List only ADMIN users adminRole := store.RoleAdmin - adminUsers, err := ts.ListUsers(ctx, &store.FindUser{Role: &adminRole}) + adminOnlyUsers, err := ts.ListUsers(ctx, &store.FindUser{Role: &adminRole}) require.NoError(t, err) - require.Equal(t, 1, len(adminUsers)) - require.Equal(t, adminUser.ID, adminUsers[0].ID) + require.Equal(t, 2, len(adminOnlyUsers)) // List only USER role users userRole := store.RoleUser @@ -136,6 +132,7 @@ func TestUserListByRole(t *testing.T) { } func TestUserUpdateRowStatus(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -170,6 +167,7 @@ func TestUserUpdateRowStatus(t *testing.T) { } func TestUserUpdateAllFields(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -213,6 +211,7 @@ func TestUserUpdateAllFields(t *testing.T) { } func TestUserListWithLimit(t *testing.T) { + t.Parallel() ctx := context.Background() ts := NewTestingStore(ctx, t) @@ -220,7 +219,7 @@ func TestUserListWithLimit(t *testing.T) { for i := 0; i < 5; i++ { role := store.RoleUser if i == 0 { - role = store.RoleHost + role = store.RoleAdmin } _, err := createTestingUserWithRole(ctx, ts, fmt.Sprintf("user%d", i), role) require.NoError(t, err) @@ -236,7 +235,7 @@ func TestUserListWithLimit(t *testing.T) { } func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) { - return createTestingUserWithRole(ctx, ts, "test", store.RoleHost) + return createTestingUserWithRole(ctx, ts, "test", store.RoleAdmin) } func createTestingUserWithRole(ctx context.Context, ts *store.Store, username string, role store.Role) (*store.User, error) { diff --git a/store/user.go b/store/user.go index c07c5c3ee..8fb149539 100644 --- a/store/user.go +++ b/store/user.go @@ -8,8 +8,6 @@ import ( type Role string const ( - // RoleHost is the HOST role. - RoleHost Role = "HOST" // RoleAdmin is the ADMIN role. RoleAdmin Role = "ADMIN" // RoleUser is the USER role. @@ -18,8 +16,6 @@ const ( func (e Role) String() string { switch e { - case RoleHost: - return "HOST" case RoleAdmin: return "ADMIN" default: diff --git a/web/src/App.tsx b/web/src/App.tsx index d8e0f949b..0acba201d 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -20,12 +20,12 @@ const App = () => { cleanupExpiredOAuthState(); }, []); - // Redirect to sign up page if no instance owner + // Redirect to sign up page if instance not initialized (no admin account exists yet) useEffect(() => { - if (!instanceProfile.owner) { + if (!instanceProfile.admin) { navigateTo("/auth/signup"); } - }, [instanceProfile.owner, navigateTo]); + }, [instanceProfile.admin, navigateTo]); useEffect(() => { if (instanceGeneralSetting.additionalStyle) { diff --git a/web/src/components/ActivityCalendar/CalendarCell.tsx b/web/src/components/ActivityCalendar/CalendarCell.tsx index 48026e460..c1551b584 100644 --- a/web/src/components/ActivityCalendar/CalendarCell.tsx +++ b/web/src/components/ActivityCalendar/CalendarCell.tsx @@ -26,7 +26,7 @@ export const CalendarCell = memo((props: CalendarCellProps) => { const smallExtraClasses = size === "small" ? `${SMALL_CELL_SIZE.dimensions} min-h-0` : ""; const baseClasses = cn( - "aspect-square w-full flex items-center justify-center text-center transition-all duration-200 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring/60 focus-visible:ring-offset-2 select-none", + "aspect-square w-full flex items-center justify-center text-center transition-all duration-150 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring/40 focus-visible:ring-offset-2 select-none border border-border/10 bg-muted/20", sizeConfig.font, sizeConfig.borderRadius, smallExtraClasses, @@ -35,7 +35,7 @@ export const CalendarCell = memo((props: CalendarCellProps) => { const ariaLabel = day.isSelected ? `${tooltipText} (selected)` : tooltipText; if (!day.isCurrentMonth) { - return
{day.label}
; + return
{day.label}
; } const intensityClass = getCellIntensityClass(day, maxCount); @@ -45,7 +45,7 @@ export const CalendarCell = memo((props: CalendarCellProps) => { intensityClass, day.isToday && "ring-2 ring-primary/30 ring-offset-1 font-semibold z-10", day.isSelected && "ring-2 ring-primary ring-offset-1 font-bold z-10", - isInteractive ? "cursor-pointer hover:scale-110 hover:shadow-md hover:z-20" : "cursor-default", + isInteractive ? "cursor-pointer hover:bg-muted/40 hover:border-border/30" : "cursor-default", ); const button = ( diff --git a/web/src/components/ActivityCalendar/MonthCalendar.tsx b/web/src/components/ActivityCalendar/MonthCalendar.tsx index b1a9a973d..e57a6a127 100644 --- a/web/src/components/ActivityCalendar/MonthCalendar.tsx +++ b/web/src/components/ActivityCalendar/MonthCalendar.tsx @@ -1,21 +1,43 @@ -import { memo } from "react"; +import { memo, useMemo } from "react"; import { useInstance } from "@/contexts/InstanceContext"; import { cn } from "@/lib/utils"; import { useTranslate } from "@/utils/i18n"; import { CalendarCell } from "./CalendarCell"; -import { DEFAULT_CELL_SIZE, SMALL_CELL_SIZE } from "./constants"; import { useTodayDate, useWeekdayLabels } from "./hooks"; -import type { MonthCalendarProps } from "./types"; +import type { CalendarSize, MonthCalendarProps } from "./types"; import { useCalendarMatrix } from "./useCalendar"; import { getTooltipText } from "./utils"; +const GRID_STYLES: Record = { + small: { gap: "gap-1.5", headerText: "text-[10px]" }, + default: { gap: "gap-2", headerText: "text-xs" }, +}; + +interface WeekdayHeaderProps { + weekDays: string[]; + size: CalendarSize; +} + +const WeekdayHeader = memo(({ weekDays, size }: WeekdayHeaderProps) => ( +
+ {weekDays.map((label, index) => ( +
+ {label} +
+ ))} +
+)); +WeekdayHeader.displayName = "WeekdayHeader"; + export const MonthCalendar = memo((props: MonthCalendarProps) => { const { month, data, maxCount, size = "default", onClick, className } = props; const t = useTranslate(); const { generalSetting } = useInstance(); - - const weekStartDayOffset = generalSetting.weekStartDayOffset; - const today = useTodayDate(); const weekDays = useWeekdayLabels(); @@ -23,41 +45,29 @@ export const MonthCalendar = memo((props: MonthCalendarProps) => { month, data, weekDays, - weekStartDayOffset, + weekStartDayOffset: generalSetting.weekStartDayOffset, today, selectedDate: "", }); - const sizeConfig = size === "small" ? SMALL_CELL_SIZE : DEFAULT_CELL_SIZE; + const flatDays = useMemo(() => weeks.flatMap((week) => week.days), [weeks]); return ( -
-
- {rotatedWeekDays.map((label, index) => ( -
- {label} -
+
+ + +
+ {flatDays.map((day) => ( + ))}
- -
- {weeks.map((week, weekIndex) => - week.days.map((day, dayIndex) => { - const tooltipText = getTooltipText(day.count, day.date, t); - - return ( - - ); - }), - )} -
); }); diff --git a/web/src/components/ActivityCalendar/YearCalendar.tsx b/web/src/components/ActivityCalendar/YearCalendar.tsx index aa875aa82..38ace4f48 100644 --- a/web/src/components/ActivityCalendar/YearCalendar.tsx +++ b/web/src/components/ActivityCalendar/YearCalendar.tsx @@ -1,21 +1,90 @@ import { ChevronLeftIcon, ChevronRightIcon } from "lucide-react"; -import { useMemo } from "react"; -import { - calculateYearMaxCount, - filterDataByYear, - generateMonthsForYear, - getMonthLabel, - MonthCalendar, -} from "@/components/ActivityCalendar"; +import { memo, useMemo } from "react"; import { Button } from "@/components/ui/button"; import { TooltipProvider } from "@/components/ui/tooltip"; import { cn } from "@/lib/utils"; import { useTranslate } from "@/utils/i18n"; import { getMaxYear, MIN_YEAR } from "./constants"; +import { MonthCalendar } from "./MonthCalendar"; import type { YearCalendarProps } from "./types"; +import { calculateYearMaxCount, filterDataByYear, generateMonthsForYear, getMonthLabel } from "./utils"; -export const YearCalendar = ({ selectedYear, data, onYearChange, onDateClick, className }: YearCalendarProps) => { +interface YearNavigationProps { + selectedYear: number; + currentYear: number; + onPrev: () => void; + onNext: () => void; + onToday: () => void; + canGoPrev: boolean; + canGoNext: boolean; +} + +const YearNavigation = memo(({ selectedYear, currentYear, onPrev, onNext, onToday, canGoPrev, canGoNext }: YearNavigationProps) => { const t = useTranslate(); + const isCurrentYear = selectedYear === currentYear; + + return ( +
+

{selectedYear}

+ + +
+ ); +}); +YearNavigation.displayName = "YearNavigation"; + +interface MonthCardProps { + month: string; + data: Record; + maxCount: number; + onDateClick: (date: string) => void; +} + +const MonthCard = memo(({ month, data, maxCount, onDateClick }: MonthCardProps) => ( +
+
{getMonthLabel(month)}
+ +
+)); +MonthCard.displayName = "MonthCard"; + +export const YearCalendar = memo(({ selectedYear, data, onYearChange, onDateClick, className }: YearCalendarProps) => { const currentYear = useMemo(() => new Date().getFullYear(), []); const yearData = useMemo(() => filterDataByYear(data, selectedYear), [data, selectedYear]); const months = useMemo(() => generateMonthsForYear(selectedYear), [selectedYear]); @@ -23,71 +92,28 @@ export const YearCalendar = ({ selectedYear, data, onYearChange, onDateClick, cl const canGoPrev = selectedYear > MIN_YEAR; const canGoNext = selectedYear < getMaxYear(); - const isCurrentYear = selectedYear === currentYear; - - const handlePrevYear = () => canGoPrev && onYearChange(selectedYear - 1); - const handleNextYear = () => canGoNext && onYearChange(selectedYear + 1); - const handleToday = () => onYearChange(currentYear); return ( -
-
-

{selectedYear}

- -
- - - - - -
-
+
+ canGoPrev && onYearChange(selectedYear - 1)} + onNext={() => canGoNext && onYearChange(selectedYear + 1)} + onToday={() => onYearChange(currentYear)} + canGoPrev={canGoPrev} + canGoNext={canGoNext} + /> -
-
- {months.map((month) => ( -
-
{getMonthLabel(month)}
- -
- ))} -
+
+ {months.map((month) => ( + + ))}
-
+
); -}; +}); + +YearCalendar.displayName = "YearCalendar"; diff --git a/web/src/components/ActivityCalendar/constants.ts b/web/src/components/ActivityCalendar/constants.ts index e5cab3f08..fdba8f806 100644 --- a/web/src/components/ActivityCalendar/constants.ts +++ b/web/src/components/ActivityCalendar/constants.ts @@ -14,22 +14,22 @@ export const INTENSITY_THRESHOLDS = { } as const; export const CELL_STYLES = { - HIGH: "bg-primary text-primary-foreground shadow-sm", - MEDIUM: "bg-primary/80 text-primary-foreground shadow-sm", - LOW: "bg-primary/60 text-primary-foreground shadow-sm", - MINIMAL: "bg-primary/40 text-foreground", - EMPTY: "bg-secondary/30 text-muted-foreground hover:bg-secondary/50", + HIGH: "bg-primary text-primary-foreground shadow-sm border-transparent", + MEDIUM: "bg-primary/85 text-primary-foreground shadow-sm border-transparent", + LOW: "bg-primary/70 text-primary-foreground border-transparent", + MINIMAL: "bg-primary/50 text-foreground border-transparent", + EMPTY: "bg-muted/20 text-muted-foreground hover:bg-muted/30 border-border/10", } as const; export const SMALL_CELL_SIZE = { - font: "text-xs", - dimensions: "w-8 h-8 mx-auto", - borderRadius: "rounded-md", - gap: "gap-1", + font: "text-[11px]", + dimensions: "w-full h-full", + borderRadius: "rounded-lg", + gap: "gap-1.5", } as const; export const DEFAULT_CELL_SIZE = { font: "text-xs", - borderRadius: "rounded-md", - gap: "gap-1.5", + borderRadius: "rounded-lg", + gap: "gap-2", } as const; diff --git a/web/src/components/EditableTimestamp.tsx b/web/src/components/EditableTimestamp.tsx new file mode 100644 index 000000000..73b840247 --- /dev/null +++ b/web/src/components/EditableTimestamp.tsx @@ -0,0 +1,104 @@ +import { Timestamp, timestampDate } from "@bufbuild/protobuf/wkt"; +import { PencilIcon } from "lucide-react"; +import { useEffect, useRef, useState } from "react"; +import toast from "react-hot-toast"; +import { cn } from "@/lib/utils"; + +interface Props { + timestamp: Timestamp | undefined; + onChange: (date: Date) => void; + className?: string; +} + +const EditableTimestamp = ({ timestamp, onChange, className }: Props) => { + const [isEditing, setIsEditing] = useState(false); + const [inputValue, setInputValue] = useState(""); + const inputRef = useRef(null); + + const date = timestamp ? timestampDate(timestamp) : new Date(); + const displayValue = date.toLocaleString(); + + // Format date for datetime-local input (YYYY-MM-DDTHH:mm) + const formatForInput = (d: Date): string => { + const year = d.getFullYear(); + const month = String(d.getMonth() + 1).padStart(2, "0"); + const day = String(d.getDate()).padStart(2, "0"); + const hours = String(d.getHours()).padStart(2, "0"); + const minutes = String(d.getMinutes()).padStart(2, "0"); + return `${year}-${month}-${day}T${hours}:${minutes}`; + }; + + useEffect(() => { + if (isEditing && inputRef.current) { + inputRef.current.focus(); + inputRef.current.showPicker?.(); // Open datetime picker if available + } + }, [isEditing]); + + const handleEdit = () => { + setInputValue(formatForInput(date)); + setIsEditing(true); + }; + + const handleSave = () => { + if (!inputValue) { + setIsEditing(false); + return; + } + + const newDate = new Date(inputValue); + if (isNaN(newDate.getTime())) { + toast.error("Invalid date format"); + return; + } + + onChange(newDate); + setIsEditing(false); + }; + + const handleCancel = () => { + setIsEditing(false); + setInputValue(""); + }; + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + handleSave(); + } else if (e.key === "Escape") { + handleCancel(); + } + }; + + if (isEditing) { + return ( + setInputValue(e.target.value)} + onBlur={handleSave} + onKeyDown={handleKeyDown} + className={cn( + "w-full px-2 py-1.5 text-sm text-foreground bg-background rounded-md border border-border outline-none transition-all focus:border-ring focus:ring-1 focus:ring-ring/20", + className, + )} + /> + ); + } + + return ( + + ); +}; + +export default EditableTimestamp; diff --git a/web/src/components/LocaleSelect.tsx b/web/src/components/LocaleSelect.tsx index a5aa48c8e..55b52b214 100644 --- a/web/src/components/LocaleSelect.tsx +++ b/web/src/components/LocaleSelect.tsx @@ -2,7 +2,7 @@ import { GlobeIcon } from "lucide-react"; import { FC } from "react"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; import { locales } from "@/i18n"; -import { getLocaleDisplayName } from "@/utils/i18n"; +import { getLocaleDisplayName, loadLocale } from "@/utils/i18n"; interface Props { value: Locale; @@ -13,6 +13,9 @@ const LocaleSelect: FC = (props: Props) => { const { onChange, value } = props; const handleSelectChange = async (locale: Locale) => { + // Apply locale globally immediately + loadLocale(locale); + // Also notify parent component onChange(locale); }; diff --git a/web/src/components/MasonryView/MasonryView.tsx b/web/src/components/MasonryView/MasonryView.tsx index dd87c68a4..f20067b16 100644 --- a/web/src/components/MasonryView/MasonryView.tsx +++ b/web/src/components/MasonryView/MasonryView.tsx @@ -10,10 +10,10 @@ const MasonryView = ({ memoList, renderer, prefixElement, listMode = false }: Ma const { columns, distribution, handleHeightChange } = useMasonryLayout(memoList, listMode, containerRef, prefixElementRef); - // Create render context: automatically enable compact mode when multiple columns + // Create render context: always enable compact mode for list views const renderContext: MemoRenderContext = useMemo( () => ({ - compact: columns > 1, + compact: true, columns, }), [columns], diff --git a/web/src/components/MemoActionMenu/MemoActionMenu.tsx b/web/src/components/MemoActionMenu/MemoActionMenu.tsx index afe2c0ce2..aaa6d386d 100644 --- a/web/src/components/MemoActionMenu/MemoActionMenu.tsx +++ b/web/src/components/MemoActionMenu/MemoActionMenu.tsx @@ -8,7 +8,6 @@ import { FileTextIcon, LinkIcon, MoreVerticalIcon, - SquareCheckIcon, TrashIcon, } from "lucide-react"; import { useState } from "react"; @@ -25,7 +24,6 @@ import { } from "@/components/ui/dropdown-menu"; import { State } from "@/types/proto/api/v1/common_pb"; import { useTranslate } from "@/utils/i18n"; -import { hasCompletedTasks } from "@/utils/markdown-manipulation"; import { useMemoActionHandlers } from "./hooks"; import type { MemoActionMenuProps } from "./types"; @@ -35,10 +33,8 @@ const MemoActionMenu = (props: MemoActionMenuProps) => { // Dialog state const [deleteDialogOpen, setDeleteDialogOpen] = useState(false); - const [removeTasksDialogOpen, setRemoveTasksDialogOpen] = useState(false); // Derived state - const hasCompletedTaskList = hasCompletedTasks(memo.content); const isComment = Boolean(memo.parent); const isArchived = memo.state === State.ARCHIVED; @@ -51,13 +47,10 @@ const MemoActionMenu = (props: MemoActionMenuProps) => { handleCopyContent, handleDeleteMemoClick, confirmDeleteMemo, - handleRemoveCompletedTaskListItemsClick, - confirmRemoveCompletedTaskListItems, } = useMemoActionHandlers({ memo, onEdit: props.onEdit, setDeleteDialogOpen, - setRemoveTasksDialogOpen, }); return ( @@ -107,14 +100,6 @@ const MemoActionMenu = (props: MemoActionMenuProps) => { {/* Write actions (non-readonly) */} {!readonly && ( <> - {/* Remove completed tasks (non-archived, non-comment, has completed tasks) */} - {!isArchived && !isComment && hasCompletedTaskList && ( - - - {t("memo.remove-completed-task-list-items")} - - )} - {/* Archive/Restore (non-comment) */} {!isComment && ( @@ -143,17 +128,6 @@ const MemoActionMenu = (props: MemoActionMenuProps) => { onConfirm={confirmDeleteMemo} confirmVariant="destructive" /> - - {/* Remove completed tasks confirmation */} - ); }; diff --git a/web/src/components/MemoActionMenu/hooks.ts b/web/src/components/MemoActionMenu/hooks.ts index 18db5b28d..67a889f36 100644 --- a/web/src/components/MemoActionMenu/hooks.ts +++ b/web/src/components/MemoActionMenu/hooks.ts @@ -11,16 +11,14 @@ import { handleError } from "@/lib/error"; import { State } from "@/types/proto/api/v1/common_pb"; import type { Memo } from "@/types/proto/api/v1/memo_service_pb"; import { useTranslate } from "@/utils/i18n"; -import { removeCompletedTasks } from "@/utils/markdown-manipulation"; interface UseMemoActionHandlersOptions { memo: Memo; onEdit?: () => void; setDeleteDialogOpen: (open: boolean) => void; - setRemoveTasksDialogOpen: (open: boolean) => void; } -export const useMemoActionHandlers = ({ memo, onEdit, setDeleteDialogOpen, setRemoveTasksDialogOpen }: UseMemoActionHandlersOptions) => { +export const useMemoActionHandlers = ({ memo, onEdit, setDeleteDialogOpen }: UseMemoActionHandlersOptions) => { const t = useTranslate(); const location = useLocation(); const navigateTo = useNavigateTo(); @@ -108,23 +106,6 @@ export const useMemoActionHandlers = ({ memo, onEdit, setDeleteDialogOpen, setRe memoUpdatedCallback(); }, [memo.name, t, isInMemoDetailPage, navigateTo, memoUpdatedCallback, deleteMemo]); - const handleRemoveCompletedTaskListItemsClick = useCallback(() => { - setRemoveTasksDialogOpen(true); - }, [setRemoveTasksDialogOpen]); - - const confirmRemoveCompletedTaskListItems = useCallback(async () => { - const newContent = removeCompletedTasks(memo.content); - await updateMemo({ - update: { - name: memo.name, - content: newContent, - }, - updateMask: ["content"], - }); - toast.success(t("message.remove-completed-task-list-items-successfully")); - memoUpdatedCallback(); - }, [memo.name, memo.content, t, memoUpdatedCallback, updateMemo]); - return { handleTogglePinMemoBtnClick, handleEditMemoClick, @@ -133,7 +114,5 @@ export const useMemoActionHandlers = ({ memo, onEdit, setDeleteDialogOpen, setRe handleCopyContent, handleDeleteMemoClick, confirmDeleteMemo, - handleRemoveCompletedTaskListItemsClick, - confirmRemoveCompletedTaskListItems, }; }; diff --git a/web/src/components/MemoContent/CodeBlock.tsx b/web/src/components/MemoContent/CodeBlock.tsx index 86bb4a1ef..8dfa91a30 100644 --- a/web/src/components/MemoContent/CodeBlock.tsx +++ b/web/src/components/MemoContent/CodeBlock.tsx @@ -6,14 +6,15 @@ import { useAuth } from "@/contexts/AuthContext"; import { cn } from "@/lib/utils"; import { getThemeWithFallback, resolveTheme } from "@/utils/theme"; import { MermaidBlock } from "./MermaidBlock"; +import type { ReactMarkdownProps } from "./markdown/types"; import { extractCodeContent, extractLanguage } from "./utils"; -interface CodeBlockProps { +interface CodeBlockProps extends ReactMarkdownProps { children?: React.ReactNode; className?: string; } -export const CodeBlock = ({ children, className, ...props }: CodeBlockProps) => { +export const CodeBlock = ({ children, className, node: _node, ...props }: CodeBlockProps) => { const { userGeneralSetting } = useAuth(); const [copied, setCopied] = useState(false); @@ -114,20 +115,41 @@ export const CodeBlock = ({ children, className, ...props }: CodeBlockProps) => }; return ( -
-      
- {language} +
+      {/* Header with language label and copy button */}
+      
+ {language || "text"}
-
- + + {/* Code content */} +
+
); diff --git a/web/src/components/MemoContent/MermaidBlock.tsx b/web/src/components/MemoContent/MermaidBlock.tsx index 48ec20bb1..d7511e65e 100644 --- a/web/src/components/MemoContent/MermaidBlock.tsx +++ b/web/src/components/MemoContent/MermaidBlock.tsx @@ -86,7 +86,7 @@ export const MermaidBlock = ({ children, className }: MermaidBlockProps) => { return (
); diff --git a/web/src/components/MemoContent/Table.tsx b/web/src/components/MemoContent/Table.tsx new file mode 100644 index 000000000..45d0cee93 --- /dev/null +++ b/web/src/components/MemoContent/Table.tsx @@ -0,0 +1,83 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./markdown/types"; + +interface TableProps extends React.HTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +export const Table = ({ children, className, node: _node, ...props }: TableProps) => { + return ( +
+ + {children} +
+
+ ); +}; + +interface TableHeadProps extends React.HTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +export const TableHead = ({ children, className, node: _node, ...props }: TableHeadProps) => { + return ( + + {children} + + ); +}; + +interface TableBodyProps extends React.HTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +export const TableBody = ({ children, className, node: _node, ...props }: TableBodyProps) => { + return ( + + {children} + + ); +}; + +interface TableRowProps extends React.HTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +export const TableRow = ({ children, className, node: _node, ...props }: TableRowProps) => { + return ( + + {children} + + ); +}; + +interface TableHeaderCellProps extends React.ThHTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +export const TableHeaderCell = ({ children, className, node: _node, ...props }: TableHeaderCellProps) => { + return ( + + {children} + + ); +}; + +interface TableCellProps extends React.TdHTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +export const TableCell = ({ children, className, node: _node, ...props }: TableCellProps) => { + return ( + + {children} + + ); +}; diff --git a/web/src/components/MemoContent/Tag.tsx b/web/src/components/MemoContent/Tag.tsx index 5fa87d753..966a90418 100644 --- a/web/src/components/MemoContent/Tag.tsx +++ b/web/src/components/MemoContent/Tag.tsx @@ -48,7 +48,7 @@ export const Tag: React.FC = ({ "data-tag": dataTag, children, classNa return ( { - node?: Element; // AST node from react-markdown +interface TaskListItemProps extends React.InputHTMLAttributes, ReactMarkdownProps { checked?: boolean; } -export const TaskListItem: React.FC = ({ checked, ...props }) => { +export const TaskListItem: React.FC = ({ checked, node: _node, ...props }) => { const { memo } = useMemoViewContext(); const { readonly } = useMemoViewDerived(); const checkboxRef = useRef(null); @@ -35,14 +35,19 @@ export const TaskListItem: React.FC = ({ checked, ...props }) if (taskIndexStr !== null) { taskIndex = parseInt(taskIndexStr); } else { - // Fallback: Calculate index by counting ALL task list items in the memo - // Find the markdown-content container by traversing up from the list item - const container = listItem.closest(".markdown-content"); - if (!container) { - return; + // Fallback: Calculate index by counting task list items + // Walk up to find the parent element with all task items + let searchRoot = listItem.parentElement; + while (searchRoot && !searchRoot.classList.contains(TASK_LIST_CLASS)) { + searchRoot = searchRoot.parentElement; } - const allTaskItems = container.querySelectorAll("li.task-list-item"); + // If not found, search from the document root + if (!searchRoot) { + searchRoot = document.body; + } + + const allTaskItems = searchRoot.querySelectorAll(`li.${TASK_LIST_ITEM_CLASS}`); for (let i = 0; i < allTaskItems.length; i++) { if (allTaskItems[i] === listItem) { taskIndex = i; diff --git a/web/src/components/MemoContent/constants.ts b/web/src/components/MemoContent/constants.ts index 2a7c8293d..237b1e6b8 100644 --- a/web/src/components/MemoContent/constants.ts +++ b/web/src/components/MemoContent/constants.ts @@ -1,6 +1,16 @@ import { defaultSchema } from "rehype-sanitize"; -export const MAX_DISPLAY_HEIGHT = 256; +// Class names added by remark-gfm for task lists +export const TASK_LIST_CLASS = "contains-task-list"; +export const TASK_LIST_ITEM_CLASS = "task-list-item"; + +// Compact mode display settings +export const COMPACT_MODE_CONFIG = { + maxHeightVh: 60, // 60% of viewport height + gradientHeight: "h-24", // Tailwind class for gradient overlay +} as const; + +export const getMaxDisplayHeight = () => window.innerHeight * (COMPACT_MODE_CONFIG.maxHeightVh / 100); export const COMPACT_STATES: Record<"ALL" | "SNIPPET", { textKey: string; next: "ALL" | "SNIPPET" }> = { ALL: { textKey: "memo.show-more", next: "SNIPPET" }, diff --git a/web/src/components/MemoContent/hooks.ts b/web/src/components/MemoContent/hooks.ts index 2f7c82ca1..bc8187472 100644 --- a/web/src/components/MemoContent/hooks.ts +++ b/web/src/components/MemoContent/hooks.ts @@ -1,5 +1,5 @@ import { useCallback, useEffect, useRef, useState } from "react"; -import { COMPACT_STATES, MAX_DISPLAY_HEIGHT } from "./constants"; +import { COMPACT_STATES, getMaxDisplayHeight } from "./constants"; import type { ContentCompactView } from "./types"; export const useCompactMode = (enabled: boolean) => { @@ -8,7 +8,8 @@ export const useCompactMode = (enabled: boolean) => { useEffect(() => { if (!enabled || !containerRef.current) return; - if (containerRef.current.getBoundingClientRect().height > MAX_DISPLAY_HEIGHT) { + const maxHeight = getMaxDisplayHeight(); + if (containerRef.current.getBoundingClientRect().height > maxHeight) { setMode("ALL"); } }, [enabled]); diff --git a/web/src/components/MemoContent/index.tsx b/web/src/components/MemoContent/index.tsx index f0114909b..b60ffffe6 100644 --- a/web/src/components/MemoContent/index.tsx +++ b/web/src/components/MemoContent/index.tsx @@ -1,4 +1,5 @@ import type { Element } from "hast"; +import { ChevronDown, ChevronUp } from "lucide-react"; import { memo } from "react"; import ReactMarkdown from "react-markdown"; import rehypeKatex from "rehype-katex"; @@ -14,8 +15,10 @@ import { remarkPreserveType } from "@/utils/remark-plugins/remark-preserve-type" import { remarkTag } from "@/utils/remark-plugins/remark-tag"; import { CodeBlock } from "./CodeBlock"; import { isTagNode, isTaskListItemNode } from "./ConditionalComponent"; -import { SANITIZE_SCHEMA } from "./constants"; +import { COMPACT_MODE_CONFIG, SANITIZE_SCHEMA } from "./constants"; import { useCompactLabel, useCompactMode } from "./hooks"; +import { Blockquote, Heading, HorizontalRule, Image, InlineCode, Link, List, ListItem, Paragraph } from "./markdown"; +import { Table, TableBody, TableCell, TableHead, TableHeaderCell, TableRow } from "./Table"; import { Tag } from "./Tag"; import { TaskListItem } from "./TaskListItem"; import type { MemoContentProps } from "./types"; @@ -36,16 +39,18 @@ const MemoContent = (props: MemoContentProps) => {
*:last-child]:mb-0", + showCompactMode === "ALL" && "overflow-hidden", contentClassName, )} + style={showCompactMode === "ALL" ? { maxHeight: `${COMPACT_MODE_CONFIG.maxHeightVh}vh` } : undefined} onMouseUp={onClick} onDoubleClick={onDoubleClick} > & { node?: Element }) => { @@ -61,28 +66,61 @@ const MemoContent = (props: MemoContentProps) => { } return ; }) as React.ComponentType>, - pre: CodeBlock, - a: ({ href, children, ...aProps }) => ( - + // Headings + h1: ({ children }) => {children}, + h2: ({ children }) => {children}, + h3: ({ children }) => {children}, + h4: ({ children }) => {children}, + h5: ({ children }) => {children}, + h6: ({ children }) => {children}, + // Block elements + p: ({ children }) => {children}, + blockquote: ({ children }) =>
{children}
, + hr: () => , + // Lists + ul: ({ children, ...props }) => {children}, + ol: ({ children, ...props }) => ( + {children} -
+ ), + li: ({ children, ...props }) => {children}, + // Inline elements + a: ({ children, ...props }) => {children}, + code: ({ children }) => {children}, + img: ({ ...props }) => , + // Code blocks + pre: CodeBlock, + // Tables + table: ({ children }) => {children}
, + thead: ({ children }) => {children}, + tbody: ({ children }) => {children}, + tr: ({ children }) => {children}, + th: ({ children, ...props }) => {children}, + td: ({ children, ...props }) => {children}, }} > {content}
+ {showCompactMode === "ALL" && ( +
+ )}
- {showCompactMode === "ALL" && ( -
- )} {showCompactMode !== undefined && ( -
+
)} diff --git a/web/src/components/MemoContent/markdown/Blockquote.tsx b/web/src/components/MemoContent/markdown/Blockquote.tsx new file mode 100644 index 000000000..c8ed5fc31 --- /dev/null +++ b/web/src/components/MemoContent/markdown/Blockquote.tsx @@ -0,0 +1,17 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface BlockquoteProps extends React.BlockquoteHTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +/** + * Blockquote component with left border accent + */ +export const Blockquote = ({ children, className, node: _node, ...props }: BlockquoteProps) => { + return ( +
+ {children} +
+ ); +}; diff --git a/web/src/components/MemoContent/markdown/Heading.tsx b/web/src/components/MemoContent/markdown/Heading.tsx new file mode 100644 index 000000000..000589abc --- /dev/null +++ b/web/src/components/MemoContent/markdown/Heading.tsx @@ -0,0 +1,30 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface HeadingProps extends React.HTMLAttributes, ReactMarkdownProps { + level: 1 | 2 | 3 | 4 | 5 | 6; + children: React.ReactNode; +} + +/** + * Heading component for h1-h6 elements + * Renders semantic heading levels with consistent styling + */ +export const Heading = ({ level, children, className, node: _node, ...props }: HeadingProps) => { + const Component = `h${level}` as const; + + const levelClasses = { + 1: "text-3xl font-bold border-b border-border pb-2", + 2: "text-2xl font-semibold border-b border-border pb-1.5", + 3: "text-xl font-semibold", + 4: "text-lg font-semibold", + 5: "text-base font-semibold", + 6: "text-base font-medium text-muted-foreground", + }; + + return ( + + {children} + + ); +}; diff --git a/web/src/components/MemoContent/markdown/HorizontalRule.tsx b/web/src/components/MemoContent/markdown/HorizontalRule.tsx new file mode 100644 index 000000000..dc798b778 --- /dev/null +++ b/web/src/components/MemoContent/markdown/HorizontalRule.tsx @@ -0,0 +1,11 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface HorizontalRuleProps extends React.HTMLAttributes, ReactMarkdownProps {} + +/** + * Horizontal rule separator + */ +export const HorizontalRule = ({ className, node: _node, ...props }: HorizontalRuleProps) => { + return
; +}; diff --git a/web/src/components/MemoContent/markdown/Image.tsx b/web/src/components/MemoContent/markdown/Image.tsx new file mode 100644 index 000000000..05def40f7 --- /dev/null +++ b/web/src/components/MemoContent/markdown/Image.tsx @@ -0,0 +1,12 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface ImageProps extends React.ImgHTMLAttributes, ReactMarkdownProps {} + +/** + * Image component for markdown images + * Responsive with rounded corners + */ +export const Image = ({ className, alt, node: _node, ...props }: ImageProps) => { + return {alt}; +}; diff --git a/web/src/components/MemoContent/markdown/InlineCode.tsx b/web/src/components/MemoContent/markdown/InlineCode.tsx new file mode 100644 index 000000000..945dc1c03 --- /dev/null +++ b/web/src/components/MemoContent/markdown/InlineCode.tsx @@ -0,0 +1,17 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface InlineCodeProps extends React.HTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +/** + * Inline code component with background and monospace font + */ +export const InlineCode = ({ children, className, node: _node, ...props }: InlineCodeProps) => { + return ( + + {children} + + ); +}; diff --git a/web/src/components/MemoContent/markdown/Link.tsx b/web/src/components/MemoContent/markdown/Link.tsx new file mode 100644 index 000000000..d305fdaf7 --- /dev/null +++ b/web/src/components/MemoContent/markdown/Link.tsx @@ -0,0 +1,27 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface LinkProps extends React.AnchorHTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +/** + * Link component for external links + * Opens in new tab with security attributes + */ +export const Link = ({ children, className, href, node: _node, ...props }: LinkProps) => { + return ( + + {children} + + ); +}; diff --git a/web/src/components/MemoContent/markdown/List.tsx b/web/src/components/MemoContent/markdown/List.tsx new file mode 100644 index 000000000..b9969bfcf --- /dev/null +++ b/web/src/components/MemoContent/markdown/List.tsx @@ -0,0 +1,67 @@ +import { cn } from "@/lib/utils"; +import { TASK_LIST_CLASS, TASK_LIST_ITEM_CLASS } from "../constants"; +import type { ReactMarkdownProps } from "./types"; + +interface ListProps extends React.HTMLAttributes, ReactMarkdownProps { + ordered?: boolean; + children: React.ReactNode; +} + +/** + * List component for both regular and task lists (GFM) + * Detects task lists via the "contains-task-list" class added by remark-gfm + */ +export const List = ({ ordered, children, className, node: _node, ...domProps }: ListProps) => { + const Component = ordered ? "ol" : "ul"; + const isTaskList = className?.includes(TASK_LIST_CLASS); + + return ( + + {children} + + ); +}; + +interface ListItemProps extends React.LiHTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +/** + * List item component for both regular and task list items + * Detects task items via the "task-list-item" class added by remark-gfm + * Applies specialized styling for task checkboxes + */ +export const ListItem = ({ children, className, node: _node, ...domProps }: ListItemProps) => { + const isTaskListItem = className?.includes(TASK_LIST_ITEM_CLASS); + + if (isTaskListItem) { + return ( +
  • button]:mr-2 [&>button]:align-middle", + "[&>p]:inline [&>p]:m-0", + `[&>.${TASK_LIST_CLASS}]:pl-6`, + className, + )} + {...domProps} + > + {children} +
  • + ); + } + + return ( +
  • + {children} +
  • + ); +}; diff --git a/web/src/components/MemoContent/markdown/Paragraph.tsx b/web/src/components/MemoContent/markdown/Paragraph.tsx new file mode 100644 index 000000000..ecf5e67e6 --- /dev/null +++ b/web/src/components/MemoContent/markdown/Paragraph.tsx @@ -0,0 +1,17 @@ +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface ParagraphProps extends React.HTMLAttributes, ReactMarkdownProps { + children: React.ReactNode; +} + +/** + * Paragraph component with compact spacing + */ +export const Paragraph = ({ children, className, node: _node, ...props }: ParagraphProps) => { + return ( +

    + {children} +

    + ); +}; diff --git a/web/src/components/MemoContent/markdown/README.md b/web/src/components/MemoContent/markdown/README.md new file mode 100644 index 000000000..a6ba08cc7 --- /dev/null +++ b/web/src/components/MemoContent/markdown/README.md @@ -0,0 +1,97 @@ +# Markdown Components + +Modern, type-safe React components for rendering markdown content via react-markdown. + +## Architecture + +### Component-Based Rendering +Following patterns from popular AI chat apps (ChatGPT, Claude, Perplexity), we use React components instead of CSS selectors for markdown rendering. This provides: + +- **Type Safety**: Full TypeScript support with proper prop types +- **Maintainability**: Components are easier to test, modify, and understand +- **Performance**: No CSS specificity conflicts, cleaner DOM +- **Modularity**: Each element is independently styled and documented + +### Type System + +All components extend `ReactMarkdownProps` which includes the AST `node` prop passed by react-markdown. This is explicitly destructured as `node: _node` to: +1. Filter it from DOM props (avoids `node="[object Object]"` in HTML) +2. Keep it available for advanced use cases (e.g., detecting task lists) +3. Maintain type safety without `as any` casts + +### GFM Task Lists + +Task lists (from remark-gfm) are handled by: +- **Detection**: `contains-task-list` and `task-list-item` classes from remark-gfm +- **Styling**: Tailwind utilities with arbitrary variants for nested elements +- **Checkboxes**: Custom `TaskListItem` component with Radix UI checkbox +- **Interactivity**: Updates memo content via `toggleTaskAtIndex` utility + +### Component Patterns + +Each component follows this structure: +```tsx +import { cn } from "@/lib/utils"; +import type { ReactMarkdownProps } from "./types"; + +interface ComponentProps extends React.HTMLAttributes, ReactMarkdownProps { + children?: React.ReactNode; + // component-specific props +} + +/** + * JSDoc description + */ +export const Component = ({ children, className, node: _node, ...props }: ComponentProps) => { + return ( + + {children} + + ); +}; +``` + +## Components + +| Component | Element | Purpose | +|-----------|---------|---------| +| `Heading` | h1-h6 | Semantic headings with level-based styling | +| `Paragraph` | p | Compact paragraphs with consistent spacing | +| `Link` | a | External links with security attributes | +| `List` | ul/ol | Regular and GFM task lists | +| `ListItem` | li | List items with task checkbox support | +| `Blockquote` | blockquote | Quotes with left border accent | +| `InlineCode` | code | Inline code with background | +| `Image` | img | Responsive images with rounded corners | +| `HorizontalRule` | hr | Section separators | + +## Styling Approach + +- **Tailwind CSS**: All styling uses Tailwind utilities +- **Design Tokens**: Colors use CSS variables (e.g., `--primary`, `--muted-foreground`) +- **Responsive**: Max-width constraints, responsive images +- **Accessibility**: Semantic HTML, proper ARIA attributes via Radix UI + +## Integration + +Components are mapped to HTML elements in `MemoContent/index.tsx`: + +```tsx + {children}, + p: ({ children, ...props }) => {children}, + // ... more mappings + }} +> + {content} + +``` + +## Future Enhancements + +- [ ] Syntax highlighting themes for code blocks +- [ ] Table sorting/filtering interactions +- [ ] Image lightbox/zoom functionality +- [ ] Collapsible sections for long content +- [ ] Copy button for code blocks diff --git a/web/src/components/MemoContent/markdown/index.ts b/web/src/components/MemoContent/markdown/index.ts new file mode 100644 index 000000000..e395d51eb --- /dev/null +++ b/web/src/components/MemoContent/markdown/index.ts @@ -0,0 +1,8 @@ +export { Blockquote } from "./Blockquote"; +export { Heading } from "./Heading"; +export { HorizontalRule } from "./HorizontalRule"; +export { Image } from "./Image"; +export { InlineCode } from "./InlineCode"; +export { Link } from "./Link"; +export { List, ListItem } from "./List"; +export { Paragraph } from "./Paragraph"; diff --git a/web/src/components/MemoContent/markdown/types.ts b/web/src/components/MemoContent/markdown/types.ts new file mode 100644 index 000000000..b50d4fcc3 --- /dev/null +++ b/web/src/components/MemoContent/markdown/types.ts @@ -0,0 +1,9 @@ +import type { Element } from "hast"; + +/** + * Props passed by react-markdown to custom components + * Includes the AST node for advanced use cases + */ +export interface ReactMarkdownProps { + node?: Element; +} diff --git a/web/src/components/MemoDetailSidebar/MemoDetailSidebar.tsx b/web/src/components/MemoDetailSidebar/MemoDetailSidebar.tsx index f236d0e84..4a3822afa 100644 --- a/web/src/components/MemoDetailSidebar/MemoDetailSidebar.tsx +++ b/web/src/components/MemoDetailSidebar/MemoDetailSidebar.tsx @@ -1,7 +1,10 @@ import { create } from "@bufbuild/protobuf"; -import { timestampDate } from "@bufbuild/protobuf/wkt"; +import { timestampFromDate } from "@bufbuild/protobuf/wkt"; import { isEqual } from "lodash-es"; import { CheckCircleIcon, Code2Icon, HashIcon, LinkIcon } from "lucide-react"; +import toast from "react-hot-toast"; +import EditableTimestamp from "@/components/EditableTimestamp"; +import { useUpdateMemo } from "@/hooks/useMemoQueries"; import { cn } from "@/lib/utils"; import { Memo, Memo_PropertySchema, MemoRelation_Type } from "@/types/proto/api/v1/memo_service_pb"; import { useTranslate } from "@/utils/i18n"; @@ -15,87 +18,96 @@ interface Props { const MemoDetailSidebar = ({ memo, className, parentPage }: Props) => { const t = useTranslate(); + const { mutate: updateMemo } = useUpdateMemo(); const property = create(Memo_PropertySchema, memo.property || {}); - const hasSpecialProperty = property.hasLink || property.hasTaskList || property.hasCode || property.hasIncompleteTasks; - const shouldShowRelationGraph = memo.relations.filter((r) => r.type === MemoRelation_Type.REFERENCE).length > 0; + const hasSpecialProperty = property.hasLink || property.hasTaskList || property.hasCode; + const hasReferenceRelations = memo.relations.some((r) => r.type === MemoRelation_Type.REFERENCE); + + const handleUpdateTimestamp = (field: "createTime" | "updateTime", date: Date) => { + const currentTimestamp = memo[field]; + const newTimestamp = timestampFromDate(date); + if (isEqual(currentTimestamp, newTimestamp)) { + return; + } + updateMemo( + { + update: { name: memo.name, [field]: newTimestamp }, + updateMask: [field === "createTime" ? "create_time" : "update_time"], + }, + { + onSuccess: () => toast.success("Updated successfully"), + onError: (error) => toast.error(error.message), + }, + ); + }; return (